diff options
Diffstat (limited to 'src/api/cvc4cpp.cpp')
-rw-r--r-- | src/api/cvc4cpp.cpp | 120 |
1 files changed, 76 insertions, 44 deletions
diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index 5c0c4a750..d990b3e22 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -2241,39 +2241,40 @@ Grammar::Grammar(const Solver* s, : d_s(s), d_sygusVars(sygusVars), d_ntSyms(ntSymbols), - d_ntsToUnres(), - d_dtDecls(), - d_allowConst() + d_ntsToTerms(ntSymbols.size()), + d_allowConst(), + d_allowVars(), + d_isResolved(false) { for (Term ntsymbol : d_ntSyms) { - // make the datatype, which encodes terms generated by this non-terminal - d_dtDecls.emplace(ntsymbol, DatatypeDecl(d_s, ntsymbol.toString())); - // make its unresolved type, used for referencing the final version of - // the datatype - d_ntsToUnres[ntsymbol] = d_s->getExprManager()->mkSort(ntsymbol.toString()); + d_ntsToTerms.emplace(ntsymbol, std::vector<Term>()); } } void Grammar::addRule(Term ntSymbol, Term rule) { + CVC4_API_CHECK(!d_isResolved) << "Grammar cannot be modified after passing " + "it as an argument to synthFun/synthInv"; CVC4_API_ARG_CHECK_NOT_NULL(ntSymbol); CVC4_API_ARG_CHECK_NOT_NULL(rule); - CVC4_API_ARG_CHECK_EXPECTED(d_dtDecls.find(ntSymbol) != d_dtDecls.end(), - ntSymbol) + CVC4_API_ARG_CHECK_EXPECTED( + d_ntsToTerms.find(ntSymbol) != d_ntsToTerms.cend(), ntSymbol) << "ntSymbol to be one of the non-terminal symbols given in the " "predeclaration"; CVC4_API_CHECK(ntSymbol.d_expr->getType() == rule.d_expr->getType()) << "Expected ntSymbol and rule to have the same sort"; - addSygusConstructorTerm(d_dtDecls[ntSymbol], rule); + d_ntsToTerms[ntSymbol].push_back(rule); } void Grammar::addRules(Term ntSymbol, std::vector<Term> rules) { + CVC4_API_CHECK(!d_isResolved) << "Grammar cannot be modified after passing " + "it as an argument to synthFun/synthInv"; CVC4_API_ARG_CHECK_NOT_NULL(ntSymbol); - CVC4_API_ARG_CHECK_EXPECTED(d_dtDecls.find(ntSymbol) != d_dtDecls.end(), - ntSymbol) + CVC4_API_ARG_CHECK_EXPECTED( + d_ntsToTerms.find(ntSymbol) != d_ntsToTerms.cend(), ntSymbol) << "ntSymbol to be one of the non-terminal symbols given in the " "predeclaration"; @@ -2285,16 +2286,19 @@ void Grammar::addRules(Term ntSymbol, std::vector<Term> rules) CVC4_API_CHECK(ntSymbol.d_expr->getType() == rules[i].d_expr->getType()) << "Expected ntSymbol and rule at index " << i << " to have the same sort"; - - addSygusConstructorTerm(d_dtDecls[ntSymbol], rules[i]); } + + d_ntsToTerms[ntSymbol].insert( + d_ntsToTerms[ntSymbol].cend(), rules.cbegin(), rules.cend()); } void Grammar::addAnyConstant(Term ntSymbol) { + CVC4_API_CHECK(!d_isResolved) << "Grammar cannot be modified after passing " + "it as an argument to synthFun/synthInv"; CVC4_API_ARG_CHECK_NOT_NULL(ntSymbol); - CVC4_API_ARG_CHECK_EXPECTED(d_dtDecls.find(ntSymbol) != d_dtDecls.end(), - ntSymbol) + CVC4_API_ARG_CHECK_EXPECTED( + d_ntsToTerms.find(ntSymbol) != d_ntsToTerms.cend(), ntSymbol) << "ntSymbol to be one of the non-terminal symbols given in the " "predeclaration"; @@ -2303,17 +2307,21 @@ void Grammar::addAnyConstant(Term ntSymbol) void Grammar::addAnyVariable(Term ntSymbol) { + CVC4_API_CHECK(!d_isResolved) << "Grammar cannot be modified after passing " + "it as an argument to synthFun/synthInv"; CVC4_API_ARG_CHECK_NOT_NULL(ntSymbol); - CVC4_API_ARG_CHECK_EXPECTED(d_dtDecls.find(ntSymbol) != d_dtDecls.end(), - ntSymbol) + CVC4_API_ARG_CHECK_EXPECTED( + d_ntsToTerms.find(ntSymbol) != d_ntsToTerms.cend(), ntSymbol) << "ntSymbol to be one of the non-terminal symbols given in the " "predeclaration"; - addSygusConstructorVariables(d_dtDecls[ntSymbol], ntSymbol.d_expr->getType()); + d_allowVars.insert(ntSymbol); } Sort Grammar::resolve() { + d_isResolved = true; + Term bvl; if (!d_sygusVars.empty()) @@ -2322,29 +2330,48 @@ Sort Grammar::resolve() termVectorToExprs(d_sygusVars)); } - for (const Term& i : d_ntSyms) + std::unordered_map<Term, Sort, TermHashFunction> ntsToUnres(d_ntSyms.size()); + + for (Term ntsymbol : d_ntSyms) { - bool aci = d_allowConst.find(i) != d_allowConst.end(); - Type btt = i.d_expr->getType(); - d_dtDecls[i].d_dtype->setSygus(btt, *bvl.d_expr, aci, false); - // We can be in a case where the only rule specified was (Variable T) - // and there are no variables of type T, in which case this is a bogus - // grammar. This results in the error below. - CVC4_API_CHECK(d_dtDecls[i].d_dtype->getNumConstructors() != 0) - << "Grouped rule listing for " << d_dtDecls[i] - << " produced an empty rule list"; + // make the unresolved type, used for referencing the final version of + // the ntsymbol's datatype + ntsToUnres[ntsymbol] = d_s->getExprManager()->mkSort(ntsymbol.toString()); } - // now, make the sygus datatype std::vector<CVC4::Datatype> datatypes; std::set<Type> unresTypes; datatypes.reserve(d_ntSyms.size()); - for (const Term& i : d_ntSyms) + for (const Term& ntSym : d_ntSyms) { - datatypes.push_back(*d_dtDecls[i].d_dtype); - unresTypes.insert(*d_ntsToUnres[i].d_type); + // make the datatype, which encodes terms generated by this non-terminal + DatatypeDecl dtDecl(d_s, ntSym.toString()); + + for (const Term& consTerm : d_ntsToTerms[ntSym]) + { + addSygusConstructorTerm(dtDecl, consTerm, ntsToUnres); + } + + if (d_allowVars.find(ntSym) != d_allowConst.cend()) + { + addSygusConstructorVariables(dtDecl, ntSym.d_expr->getType()); + } + + bool aci = d_allowConst.find(ntSym) != d_allowConst.end(); + Type btt = ntSym.d_expr->getType(); + dtDecl.d_dtype->setSygus(btt, *bvl.d_expr, aci, false); + + // We can be in a case where the only rule specified was (Variable T) + // and there are no variables of type T, in which case this is a bogus + // grammar. This results in the error below. + CVC4_API_CHECK(dtDecl.d_dtype->getNumConstructors() != 0) + << "Grouped rule listing for " << *dtDecl.d_dtype + << " produced an empty rule list"; + + datatypes.push_back(*dtDecl.d_dtype); + unresTypes.insert(*ntsToUnres[ntSym].d_type); } std::vector<DatatypeType> datatypeTypes = @@ -2355,7 +2382,10 @@ Sort Grammar::resolve() return datatypeTypes[0]; } -void Grammar::addSygusConstructorTerm(DatatypeDecl& dt, Term term) const +void Grammar::addSygusConstructorTerm( + DatatypeDecl& dt, + Term term, + const std::unordered_map<Term, Sort, TermHashFunction>& ntsToUnres) const { // At this point, we should know that dt is well founded, and that its // builtin sygus operators are well-typed. @@ -2367,7 +2397,7 @@ void Grammar::addSygusConstructorTerm(DatatypeDecl& dt, Term term) const // this does not lead to exponential behavior with respect to input size. std::vector<Term> args; std::vector<Sort> cargs; - Term op = purifySygusGTerm(term, args, cargs); + Term op = purifySygusGTerm(term, args, cargs, ntsToUnres); std::stringstream ssCName; ssCName << op.getKind(); std::shared_ptr<SygusPrintCallback> spc; @@ -2386,13 +2416,15 @@ void Grammar::addSygusConstructorTerm(DatatypeDecl& dt, Term term) const *op.d_expr, ssCName.str(), sortVectorToTypes(cargs), spc); } -Term Grammar::purifySygusGTerm(Term term, - std::vector<Term>& args, - std::vector<Sort>& cargs) const +Term Grammar::purifySygusGTerm( + Term term, + std::vector<Term>& args, + std::vector<Sort>& cargs, + const std::unordered_map<Term, Sort, TermHashFunction>& ntsToUnres) const { std::unordered_map<Term, Sort, TermHashFunction>::const_iterator itn = - d_ntsToUnres.find(term); - if (itn != d_ntsToUnres.cend()) + ntsToUnres.find(term); + if (itn != ntsToUnres.cend()) { Term ret = d_s->getExprManager()->mkBoundVar(term.d_expr->getType()); args.push_back(ret); @@ -2403,7 +2435,7 @@ Term Grammar::purifySygusGTerm(Term term, bool childChanged = false; for (unsigned i = 0, nchild = term.d_expr->getNumChildren(); i < nchild; i++) { - Term ptermc = purifySygusGTerm((*term.d_expr)[i], args, cargs); + Term ptermc = purifySygusGTerm((*term.d_expr)[i], args, cargs, ntsToUnres); pchildren.push_back(ptermc); childChanged = childChanged || *ptermc.d_expr != (*term.d_expr)[i]; } @@ -4495,7 +4527,7 @@ Term Solver::synthFun(const std::string& symbol, Term Solver::synthFun(const std::string& symbol, const std::vector<Term>& boundVars, Sort sort, - Grammar g) const + Grammar& g) const { return synthFunHelper(symbol, boundVars, sort, false, &g); } @@ -4508,7 +4540,7 @@ Term Solver::synthInv(const std::string& symbol, Term Solver::synthInv(const std::string& symbol, const std::vector<Term>& boundVars, - Grammar g) const + Grammar& g) const { return synthFunHelper(symbol, boundVars, d_exprMgr->booleanType(), true, &g); } |