diff options
-rw-r--r-- | src/api/cvc4cpp.cpp | 3 | ||||
-rw-r--r-- | src/parser/smt2/Smt2.g | 33 | ||||
-rw-r--r-- | src/parser/smt2/smt2.cpp | 44 | ||||
-rw-r--r-- | src/parser/smt2/smt2.h | 43 | ||||
-rw-r--r-- | src/smt/command.cpp | 33 | ||||
-rw-r--r-- | src/smt/smt_engine.cpp | 2 | ||||
-rw-r--r-- | src/theory/quantifiers/sygus/sygus_interpol.cpp | 17 | ||||
-rw-r--r-- | test/unit/api/solver_black.h | 2 |
8 files changed, 38 insertions, 139 deletions
diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index bda88c539..3cfeaf6cd 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -5591,9 +5591,6 @@ Term Solver::synthFunHelper(const std::string& symbol, CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_NOT_NULL(sort); - CVC4_API_ARG_CHECK_EXPECTED(sort.d_type->isFirstClass(), sort) - << "first-class codomain sort for function"; - std::vector<Type> varTypes; for (size_t i = 0, n = boundVars.size(); i < n; ++i) { diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 232723fc0..7c1c5dc3e 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -537,12 +537,12 @@ command [std::unique_ptr<CVC4::Command>* cmd] sygusCommand returns [std::unique_ptr<CVC4::Command> cmd] @declarations { - CVC4::api::Term expr, expr2; + CVC4::api::Term expr, expr2, fun; CVC4::api::Sort t, range; std::vector<std::string> names; std::vector<std::pair<std::string, CVC4::api::Sort> > sortedVarNames; - std::unique_ptr<Smt2::SynthFunFactory> synthFunFactory; - std::string name, fun; + std::vector<CVC4::api::Term> sygusVars; + std::string name; bool isInv; CVC4::api::Grammar* grammar = nullptr; } @@ -552,7 +552,8 @@ sygusCommand returns [std::unique_ptr<CVC4::Command> cmd] { PARSER_STATE->checkUserSymbol(name); } sortSymbol[t,CHECK_DECLARED] { - api::Term var = PARSER_STATE->bindBoundVar(name, t); + api::Term var = SOLVER->mkSygusVar(t, name); + PARSER_STATE->defineVar(name, var); cmd.reset(new DeclareSygusVarCommand(name, var, t)); } | /* synth-fun */ @@ -560,22 +561,36 @@ sygusCommand returns [std::unique_ptr<CVC4::Command> cmd] | SYNTH_INV_TOK { isInv = true; range = SOLVER->getBooleanSort(); } ) { PARSER_STATE->checkThatLogicIsSet(); } - symbol[fun,CHECK_UNDECLARED,SYM_VARIABLE] + symbol[name,CHECK_UNDECLARED,SYM_VARIABLE] LPAREN_TOK sortedVarList[sortedVarNames] RPAREN_TOK ( sortSymbol[range,CHECK_DECLARED] )? { - synthFunFactory.reset(new Smt2::SynthFunFactory( - PARSER_STATE, fun, isInv, range, sortedVarNames)); + PARSER_STATE->pushScope(true); + sygusVars = PARSER_STATE->bindBoundVars(sortedVarNames); } ( // optionally, read the sygus grammar // // `grammar` specifies the required grammar for the function to // synthesize, expressed as a type - sygusGrammar[grammar, synthFunFactory->getSygusVars(), fun] + sygusGrammar[grammar, sygusVars, name] )? { - cmd = synthFunFactory->mkCommand(grammar); + Debug("parser-sygus") << "Define synth fun : " << name << std::endl; + + fun = isInv ? (grammar == nullptr + ? SOLVER->synthInv(name, sygusVars) + : SOLVER->synthInv(name, sygusVars, *grammar)) + : (grammar == nullptr + ? SOLVER->synthFun(name, sygusVars, range) + : SOLVER->synthFun(name, sygusVars, range, *grammar)); + + Debug("parser-sygus") << "...read synth fun " << name << std::endl; + PARSER_STATE->popScope(); + // we do not allow overloading for synth fun + PARSER_STATE->defineVar(name, fun); + cmd = std::unique_ptr<Command>( + new SynthFunCommand(name, fun, sygusVars, range, isInv, grammar)); } | /* constraint */ CONSTRAINT_TOK { diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 81a4bd4a6..629164593 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -482,50 +482,6 @@ void Smt2::resetAssertions() { pushScope(true); } -Smt2::SynthFunFactory::SynthFunFactory( - Smt2* smt2, - const std::string& id, - bool isInv, - api::Sort range, - std::vector<std::pair<std::string, api::Sort>>& sortedVarNames) - : d_smt2(smt2), d_id(id), d_sort(range), d_isInv(isInv) -{ - if (range.isNull()) - { - smt2->parseError("Must supply return type for synth-fun."); - } - if (range.isFunction()) - { - smt2->parseError("Cannot use synth-fun with function return type."); - } - - std::vector<api::Sort> varSorts; - for (const std::pair<std::string, api::Sort>& p : sortedVarNames) - { - varSorts.push_back(p.second); - } - - api::Sort funSort = varSorts.empty() - ? range - : d_smt2->d_solver->mkFunctionSort(varSorts, range); - - // we do not allow overloading for synth fun - d_fun = d_smt2->bindBoundVar(id, funSort); - - Debug("parser-sygus") << "Define synth fun : " << id << std::endl; - - d_smt2->pushScope(true); - d_sygusVars = d_smt2->bindBoundVars(sortedVarNames); -} - -std::unique_ptr<Command> Smt2::SynthFunFactory::mkCommand(api::Grammar* grammar) -{ - Debug("parser-sygus") << "...read synth fun " << d_id << std::endl; - d_smt2->popScope(); - return std::unique_ptr<Command>( - new SynthFunCommand(d_id, d_fun, d_sygusVars, d_sort, d_isInv, grammar)); -} - std::unique_ptr<Command> Smt2::invConstraint( const std::vector<std::string>& names) { diff --git a/src/parser/smt2/smt2.h b/src/parser/smt2/smt2.h index 5fcf49637..1aa0ebac7 100644 --- a/src/parser/smt2/smt2.h +++ b/src/parser/smt2/smt2.h @@ -195,49 +195,6 @@ class Smt2 : public Parser void resetAssertions(); /** - * Class for creating instances of `SynthFunCommand`s. Creating an instance - * of this class pushes the scope, destroying it pops the scope. - */ - class SynthFunFactory - { - public: - /** - * Creates an instance of `SynthFunFactory`. - * - * @param smt2 Pointer to the parser state - * @param id Name of the function to synthesize - * @param isInv True if the goal is to synthesize an invariant, false - * otherwise - * @param range The return type of the function-to-synthesize - * @param sortedVarNames The parameters of the function-to-synthesize - */ - SynthFunFactory( - Smt2* smt2, - const std::string& id, - bool isInv, - api::Sort range, - std::vector<std::pair<std::string, api::Sort>>& sortedVarNames); - - const std::vector<api::Term>& getSygusVars() const { return d_sygusVars; } - - /** - * Create an instance of `SynthFunCommand`. - * - * @param grammar Optional grammar associated with the synth-fun command - * @return The instance of `SynthFunCommand` - */ - std::unique_ptr<Command> mkCommand(api::Grammar* grammar); - - private: - Smt2* d_smt2; - std::string d_id; - api::Term d_fun; - api::Sort d_sort; - bool d_isInv; - std::vector<api::Term> d_sygusVars; - }; - - /** * Creates a command that adds an invariant constraint. * * @param names Name of four symbols corresponding to the diff --git a/src/smt/command.cpp b/src/smt/command.cpp index 9c45c0b19..eb03edf4f 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -581,16 +581,7 @@ api::Sort DeclareSygusVarCommand::getSort() const { return d_sort; } void DeclareSygusVarCommand::invoke(api::Solver* solver) { - try - { - solver->getSmtEngine()->declareSygusVar( - d_symbol, d_var.getNode(), TypeNode::fromType(d_sort.getType())); - d_commandStatus = CommandSuccess::instance(); - } - catch (exception& e) - { - d_commandStatus = new CommandFailure(e.what()); - } + d_commandStatus = CommandSuccess::instance(); } Command* DeclareSygusVarCommand::clone() const @@ -646,27 +637,7 @@ const api::Grammar* SynthFunCommand::getGrammar() const { return d_grammar; } void SynthFunCommand::invoke(api::Solver* solver) { - try - { - std::vector<Node> vns; - for (const api::Term& t : d_vars) - { - vns.push_back(Node::fromExpr(t.getExpr())); - } - solver->getSmtEngine()->declareSynthFun( - d_symbol, - Node::fromExpr(d_fun.getExpr()), - TypeNode::fromType(d_grammar == nullptr - ? d_sort.getType() - : d_grammar->resolve().getType()), - d_isInv, - vns); - d_commandStatus = CommandSuccess::instance(); - } - catch (exception& e) - { - d_commandStatus = new CommandFailure(e.what()); - } + d_commandStatus = CommandSuccess::instance(); } Command* SynthFunCommand::clone() const diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index f345bee2e..d0906ce98 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -1065,7 +1065,6 @@ Result SmtEngine::assertFormula(const Node& formula, bool inUnsatCore) void SmtEngine::declareSygusVar(const std::string& id, Node var, TypeNode type) { SmtScope smts(this); - finishInit(); d_sygusSolver->declareSygusVar(id, var, type); if (Dump.isOn("raw-benchmark")) { @@ -1082,7 +1081,6 @@ void SmtEngine::declareSynthFun(const std::string& id, const std::vector<Node>& vars) { SmtScope smts(this); - finishInit(); d_state->doPendingPops(); d_sygusSolver->declareSynthFun(id, func, sygusType, isInv, vars); diff --git a/src/theory/quantifiers/sygus/sygus_interpol.cpp b/src/theory/quantifiers/sygus/sygus_interpol.cpp index e4e7a02c7..d5ab0e51f 100644 --- a/src/theory/quantifiers/sygus/sygus_interpol.cpp +++ b/src/theory/quantifiers/sygus/sygus_interpol.cpp @@ -319,6 +319,18 @@ bool SygusInterpol::solveInterpolation(const std::string& name, const TypeNode& itpGType, Node& interpol) { + // Some instructions in setSynthGrammar and mkSygusConjecture need a fully + // initialized solver to work properly. Notice, however, that the sub-solver + // created below is not fully initialized by the time those two methods are + // needed. Therefore, we call them while the current parent solver is in scope + // (i.e., before creating the sub-solver). + collectSymbols(axioms, conj); + createVariables(itpGType.isNull()); + TypeNode grammarType = setSynthGrammar(itpGType, axioms, conj); + + Node itp = mkPredicate(name); + mkSygusConjecture(itp, axioms, conj); + std::unique_ptr<SmtEngine> subSolver; initializeSubsolver(subSolver); // get the logic @@ -327,17 +339,12 @@ bool SygusInterpol::solveInterpolation(const std::string& name, l.enableSygus(); subSolver->setLogic(l); - collectSymbols(axioms, conj); - createVariables(itpGType.isNull()); for (Node var : d_vars) { subSolver->declareSygusVar(name, var, var.getType()); } std::vector<Node> vars_empty; - TypeNode grammarType = setSynthGrammar(itpGType, axioms, conj); - Node itp = mkPredicate(name); subSolver->declareSynthFun(name, itp, grammarType, false, vars_empty); - mkSygusConjecture(itp, axioms, conj); Trace("sygus-interpol") << "SmtEngine::getInterpol: made conjecture : " << d_sygusConj << ", solving for " << d_sygusConj[0][0] << std::endl; diff --git a/test/unit/api/solver_black.h b/test/unit/api/solver_black.h index aa4289ef3..8b8c6dd58 100644 --- a/test/unit/api/solver_black.h +++ b/test/unit/api/solver_black.h @@ -2268,7 +2268,6 @@ void SolverBlack::testSynthFun() Sort null = d_solver->getNullSort(); Sort boolean = d_solver->getBooleanSort(); Sort integer = d_solver->getIntegerSort(); - Sort boolToBool = d_solver->mkFunctionSort(boolean, boolean); Term nullTerm; Term x = d_solver->mkVar(boolean); @@ -2289,7 +2288,6 @@ void SolverBlack::testSynthFun() TS_ASSERT_THROWS(d_solver->synthFun("f3", {nullTerm}, boolean), CVC4ApiException&); TS_ASSERT_THROWS(d_solver->synthFun("f4", {}, null), CVC4ApiException&); - TS_ASSERT_THROWS(d_solver->synthFun("f5", {}, boolToBool), CVC4ApiException&); TS_ASSERT_THROWS(d_solver->synthFun("f6", {x}, boolean, g2), CVC4ApiException&); Solver slv; |