diff options
author | Andres Noetzli <andres.noetzli@gmail.com> | 2021-07-28 14:11:55 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-28 14:11:55 -0700 |
commit | 1377cede4b223a5b6a68d7d9194b7e3346a2d51a (patch) | |
tree | face9b6f6b46f663a38115c6ca11fb7415acbd10 /src/theory/quantifiers | |
parent | 5067dee413caf5f5bda4e666d877841f936d74b0 (diff) | |
parent | e6747735d2074fc2651c5edc11fa8170fc13663e (diff) |
Merge branch 'master' into docsLinkdocsLink
Diffstat (limited to 'src/theory/quantifiers')
31 files changed, 499 insertions, 176 deletions
diff --git a/src/theory/quantifiers/candidate_rewrite_database.cpp b/src/theory/quantifiers/candidate_rewrite_database.cpp index 789a723b9..c2ee563e3 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.cpp +++ b/src/theory/quantifiers/candidate_rewrite_database.cpp @@ -20,6 +20,7 @@ #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" #include "smt/smt_statistics_registry.h" +#include "theory/datatypes/sygus_datatype_utils.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" @@ -244,8 +245,8 @@ Node CandidateRewriteDatabase::addTerm(Node sol, // wish to enumerate any term that contains sol (resp. eq_sol) // as a subterm. Node exc_sol = sol; - unsigned sz = d_tds->getSygusTermSize(sol); - unsigned eqsz = d_tds->getSygusTermSize(eq_sol); + unsigned sz = datatypes::utils::getSygusTermSize(sol); + unsigned eqsz = datatypes::utils::getSygusTermSize(eq_sol); if (eqsz > sz) { sz = eqsz; diff --git a/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp b/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp index f059767a6..f65828d2f 100644 --- a/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp +++ b/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp @@ -491,7 +491,7 @@ bool InstStrategyCegqi::doAddInstantiation( std::vector< Node >& subs ) { else if (inst->addInstantiation(d_curr_quant, subs, InferenceId::QUANTIFIERS_INST_CEGQI, - false, + Node::null(), false, usedVts)) { diff --git a/src/theory/quantifiers/ematching/trigger.cpp b/src/theory/quantifiers/ematching/trigger.cpp index 62558e2c6..529125978 100644 --- a/src/theory/quantifiers/ematching/trigger.cpp +++ b/src/theory/quantifiers/ematching/trigger.cpp @@ -59,6 +59,7 @@ Trigger::Trigger(QuantifiersState& qs, Node np = ensureGroundTermPreprocessed(val, n, d_groundTerms); d_nodes.push_back(np); } + d_trNode = NodeManager::currentNM()->mkNode(SEXPR, d_nodes); if (Trace.isOn("trigger")) { QuantAttributes& qa = d_qreg.getQuantAttributes(); @@ -163,7 +164,7 @@ uint64_t Trigger::addInstantiations() bool Trigger::sendInstantiation(std::vector<Node>& m, InferenceId id) { - return d_qim.getInstantiate()->addInstantiation(d_quant, m, id); + return d_qim.getInstantiate()->addInstantiation(d_quant, m, id, d_trNode); } bool Trigger::sendInstantiation(InstMatch& m, InferenceId id) diff --git a/src/theory/quantifiers/ematching/trigger.h b/src/theory/quantifiers/ematching/trigger.h index 172e93c12..944a082c0 100644 --- a/src/theory/quantifiers/ematching/trigger.h +++ b/src/theory/quantifiers/ematching/trigger.h @@ -181,6 +181,8 @@ class Trigger { std::vector<Node>& gts); /** The nodes comprising this trigger. */ std::vector<Node> d_nodes; + /** The nodes as a single s-expression */ + Node d_trNode; /** * The preprocessed ground terms in the nodes of the trigger, which as an * optimization omits variables and constant subterms. These terms are diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index aa7e183bb..58a78b4aa 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -48,7 +48,7 @@ ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr) d_false = NodeManager::currentNM()->mkConst(false); } -void ExtendedRewriter::setCache(Node n, Node ret) +void ExtendedRewriter::setCache(Node n, Node ret) const { if (d_aggr) { @@ -62,7 +62,7 @@ void ExtendedRewriter::setCache(Node n, Node ret) } } -Node ExtendedRewriter::getCache(Node n) +Node ExtendedRewriter::getCache(Node n) const { if (d_aggr) { @@ -83,7 +83,7 @@ Node ExtendedRewriter::getCache(Node n) bool ExtendedRewriter::addToChildren(Node nc, std::vector<Node>& children, - bool dropDup) + bool dropDup) const { // If the operator is non-additive, do not consider duplicates if (dropDup @@ -95,7 +95,7 @@ bool ExtendedRewriter::addToChildren(Node nc, return true; } -Node ExtendedRewriter::extendedRewrite(Node n) +Node ExtendedRewriter::extendedRewrite(Node n) const { n = Rewriter::rewrite(n); @@ -280,7 +280,7 @@ Node ExtendedRewriter::extendedRewrite(Node n) return ret; } -Node ExtendedRewriter::extendedRewriteAggr(Node n) +Node ExtendedRewriter::extendedRewriteAggr(Node n) const { Node new_ret; Trace("q-ext-rewrite-debug2") @@ -341,7 +341,7 @@ Node ExtendedRewriter::extendedRewriteAggr(Node n) return new_ret; } -Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) +Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const { Assert(n.getKind() == itek); Assert(n[1] != n[2]); @@ -561,7 +561,7 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) return new_ret; } -Node ExtendedRewriter::extendedRewriteAndOr(Node n) +Node ExtendedRewriter::extendedRewriteAndOr(Node n) const { // all the below rewrites are aggressive if (!d_aggr) @@ -592,7 +592,7 @@ Node ExtendedRewriter::extendedRewriteAndOr(Node n) return new_ret; } -Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) +Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) const { Assert(n.getKind() != ITE); if (n.isClosure()) @@ -715,7 +715,7 @@ Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) return Node::null(); } -Node ExtendedRewriter::extendedRewriteNnf(Node ret) +Node ExtendedRewriter::extendedRewriteNnf(Node ret) const { Assert(ret.getKind() == NOT); @@ -761,8 +761,11 @@ Node ExtendedRewriter::extendedRewriteNnf(Node ret) return NodeManager::currentNM()->mkNode(nk, new_children); } -Node ExtendedRewriter::extendedRewriteBcp( - Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node ret) +Node ExtendedRewriter::extendedRewriteBcp(Kind andk, + Kind ork, + Kind notk, + std::map<Kind, bool>& bcp_kinds, + Node ret) const { Kind k = ret.getKind(); Assert(k == andk || k == ork); @@ -926,7 +929,7 @@ Node ExtendedRewriter::extendedRewriteBcp( Node ExtendedRewriter::extendedRewriteFactoring(Kind andk, Kind ork, Kind notk, - Node n) + Node n) const { Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl; NodeManager* nm = NodeManager::currentNM(); @@ -1019,7 +1022,7 @@ Node ExtendedRewriter::extendedRewriteEqRes(Kind andk, Kind notk, std::map<Kind, bool>& bcp_kinds, Node n, - bool isXor) + bool isXor) const { Assert(n.getKind() == andk || n.getKind() == ork); Trace("ext-rew-eqres") << "Eq res: **** INPUT: " << n << std::endl; @@ -1166,7 +1169,7 @@ class SimpSubsumeTrie }; Node ExtendedRewriter::extendedRewriteEqChain( - Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor) + Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor) const { Assert(ret.getKind() == eqk); @@ -1527,9 +1530,10 @@ Node ExtendedRewriter::extendedRewriteEqChain( return Node::null(); } -Node ExtendedRewriter::partialSubstitute(Node n, - const std::map<Node, Node>& assign, - const std::map<Kind, bool>& rkinds) +Node ExtendedRewriter::partialSubstitute( + Node n, + const std::map<Node, Node>& assign, + const std::map<Kind, bool>& rkinds) const { std::unordered_map<TNode, Node> visited; std::unordered_map<TNode, Node>::iterator it; @@ -1601,10 +1605,11 @@ Node ExtendedRewriter::partialSubstitute(Node n, return visited[n]; } -Node ExtendedRewriter::partialSubstitute(Node n, - const std::vector<Node>& vars, - const std::vector<Node>& subs, - const std::map<Kind, bool>& rkinds) +Node ExtendedRewriter::partialSubstitute( + Node n, + const std::vector<Node>& vars, + const std::vector<Node>& subs, + const std::map<Kind, bool>& rkinds) const { Assert(vars.size() == subs.size()); std::map<Node, Node> assign; @@ -1615,7 +1620,7 @@ Node ExtendedRewriter::partialSubstitute(Node n, return partialSubstitute(n, assign, rkinds); } -Node ExtendedRewriter::solveEquality(Node n) +Node ExtendedRewriter::solveEquality(Node n) const { // TODO (#1706) : implement Assert(n.getKind() == EQUAL); @@ -1626,7 +1631,7 @@ Node ExtendedRewriter::solveEquality(Node n) bool ExtendedRewriter::inferSubstitution(Node n, std::vector<Node>& vars, std::vector<Node>& subs, - bool usePred) + bool usePred) const { if (n.getKind() == AND) { @@ -1696,7 +1701,7 @@ bool ExtendedRewriter::inferSubstitution(Node n, return false; } -Node ExtendedRewriter::extendedRewriteStrings(Node ret) +Node ExtendedRewriter::extendedRewriteStrings(Node ret) const { Node new_ret; Trace("q-ext-rewrite-debug") diff --git a/src/theory/quantifiers/extended_rewrite.h b/src/theory/quantifiers/extended_rewrite.h index 047318e86..8996fc441 100644 --- a/src/theory/quantifiers/extended_rewrite.h +++ b/src/theory/quantifiers/extended_rewrite.h @@ -51,7 +51,7 @@ class ExtendedRewriter ExtendedRewriter(bool aggr = true); ~ExtendedRewriter() {} /** return the extended rewritten form of n */ - Node extendedRewrite(Node n); + Node extendedRewrite(Node n) const; private: /** @@ -69,16 +69,16 @@ class ExtendedRewriter Node d_true; Node d_false; /** cache that the extended rewritten form of n is ret */ - void setCache(Node n, Node ret); + void setCache(Node n, Node ret) const; /** get the cache for n */ - Node getCache(Node n); + Node getCache(Node n) const; /** add to children * * Adds nc to the vector of children, if dropDup is true, we do not add * nc if it already occurs in children. This method returns false in this * case, otherwise it returns true. */ - bool addToChildren(Node nc, std::vector<Node>& children, bool dropDup); + bool addToChildren(Node nc, std::vector<Node>& children, bool dropDup) const; //--------------------------------------generic utilities /** Rewrite ITE, for example: @@ -92,13 +92,13 @@ class ExtendedRewriter * take. If full is false, then we do only perform rewrites that * strictly decrease the term size of n. */ - Node extendedRewriteIte(Kind itek, Node n, bool full = true); + Node extendedRewriteIte(Kind itek, Node n, bool full = true) const; /** Rewrite AND/OR * * This implements BCP, factoring, and equality resolution for the Boolean * term n whose top symbolic is AND/OR. */ - Node extendedRewriteAndOr(Node n); + Node extendedRewriteAndOr(Node n) const; /** Pull ITE, for example: * * D=C2 ---> false @@ -111,7 +111,7 @@ class ExtendedRewriter * * If this function returns a non-null node ret, then n ---> ret. */ - Node extendedRewritePullIte(Kind itek, Node n); + Node extendedRewritePullIte(Kind itek, Node n) const; /** Negation Normal Form (NNF), for example: * * ~( A & B ) ---> ( ~ A | ~B ) @@ -119,7 +119,7 @@ class ExtendedRewriter * * If this function returns a non-null node ret, then n ---> ret. */ - Node extendedRewriteNnf(Node n); + Node extendedRewriteNnf(Node n) const; /** (type-independent) Boolean constraint propagation, for example: * * ~A & ( B V A ) ---> ~A & B @@ -137,8 +137,11 @@ class ExtendedRewriter * * If this function returns a non-null node ret, then n ---> ret. */ - Node extendedRewriteBcp( - Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node n); + Node extendedRewriteBcp(Kind andk, + Kind ork, + Kind notk, + std::map<Kind, bool>& bcp_kinds, + Node n) const; /** (type-independent) factoring, for example: * * ( A V B ) ^ ( A V C ) ----> A V ( B ^ C ) @@ -147,7 +150,7 @@ class ExtendedRewriter * This function takes as arguments the kinds that specify AND, OR, NOT. * We assume that the children of n do not contain duplicates. */ - Node extendedRewriteFactoring(Kind andk, Kind ork, Kind notk, Node n); + Node extendedRewriteFactoring(Kind andk, Kind ork, Kind notk, Node n) const; /** (type-independent) equality resolution, for example: * * ( A V C ) & ( A = B ) ---> ( B V C ) & ( A = B ) @@ -167,7 +170,7 @@ class ExtendedRewriter Kind notk, std::map<Kind, bool>& bcp_kinds, Node n, - bool isXor = false); + bool isXor = false) const; /** (type-independent) Equality chain rewriting, for example: * * A = ( A = B ) ---> B @@ -178,26 +181,32 @@ class ExtendedRewriter * This function takes as arguments the kinds that specify EQUAL, AND, OR, * and NOT. If the flag isXor is true, the eqk is treated as XOR. */ - Node extendedRewriteEqChain( - Kind eqk, Kind andk, Kind ork, Kind notk, Node n, bool isXor = false); + Node extendedRewriteEqChain(Kind eqk, + Kind andk, + Kind ork, + Kind notk, + Node n, + bool isXor = false) const; /** extended rewrite aggressive * * All aggressive rewriting techniques (those that should be prioritized * at a lower level) go in this function. */ - Node extendedRewriteAggr(Node n); + Node extendedRewriteAggr(Node n) const; /** Decompose right associative chain * * For term f( ... f( f( base, tn ), t{n-1} ) ... t1 ), returns term base, and * appends t1...tn to children. */ - Node decomposeRightAssocChain(Kind k, Node n, std::vector<Node>& children); + Node decomposeRightAssocChain(Kind k, + Node n, + std::vector<Node>& children) const; /** Make right associative chain * * Sorts children to obtain list { tn...t1 }, and returns the term * f( ... f( f( base, tn ), t{n-1} ) ... t1 ). */ - Node mkRightAssocChain(Kind k, Node base, std::vector<Node>& children); + Node mkRightAssocChain(Kind k, Node base, std::vector<Node>& children) const; /** Partial substitute * * Applies the substitution specified by assign to n, recursing only beneath @@ -206,18 +215,18 @@ class ExtendedRewriter */ Node partialSubstitute(Node n, const std::map<Node, Node>& assign, - const std::map<Kind, bool>& rkinds); + const std::map<Kind, bool>& rkinds) const; /** same as above, with vectors */ Node partialSubstitute(Node n, const std::vector<Node>& vars, const std::vector<Node>& subs, - const std::map<Kind, bool>& rkinds); + const std::map<Kind, bool>& rkinds) const; /** solve equality * * If this function returns a non-null node n', then n' is equivalent to n * and is of the form that can be used by inferSubstitution below. */ - Node solveEquality(Node n); + Node solveEquality(Node n) const; /** infer substitution * * If n is an equality of the form x = t, where t is either: @@ -231,12 +240,12 @@ class ExtendedRewriter bool inferSubstitution(Node n, std::vector<Node>& vars, std::vector<Node>& subs, - bool usePred = false); + bool usePred = false) const; /** extended rewrite * * Prints debug information, indicating the rewrite n ---> ret was found. */ - inline void debugExtendedRewrite(Node n, Node ret, const char* c) const; + void debugExtendedRewrite(Node n, Node ret, const char* c) const; //--------------------------------------end generic utilities //--------------------------------------theory-specific top-level calls @@ -245,7 +254,7 @@ class ExtendedRewriter * If this method returns a non-null node ret', then ret is equivalent to * ret'. */ - Node extendedRewriteStrings(Node ret); + Node extendedRewriteStrings(Node ret) const; //--------------------------------------end theory-specific top-level calls }; diff --git a/src/theory/quantifiers/fmf/full_model_check.cpp b/src/theory/quantifiers/fmf/full_model_check.cpp index c3fa664d9..c4f83191b 100644 --- a/src/theory/quantifiers/fmf/full_model_check.cpp +++ b/src/theory/quantifiers/fmf/full_model_check.cpp @@ -725,8 +725,11 @@ int FullModelChecker::doExhaustiveInstantiation( FirstOrderModel * fm, Node f, i } // just add the instance d_triedLemmas++; - if (instq->addInstantiation( - f, inst, InferenceId::QUANTIFIERS_INST_FMF_FMC, true)) + if (instq->addInstantiation(f, + inst, + InferenceId::QUANTIFIERS_INST_FMF_FMC, + Node::null(), + true)) { Trace("fmc-debug-inst") << "** Added instantiation." << std::endl; d_addedLemmas++; @@ -875,8 +878,11 @@ bool FullModelChecker::exhaustiveInstantiate(FirstOrderModelFmc* fm, if (ev!=d_true) { Trace("fmc-exh-debug") << ", add!"; //add as instantiation - if (ie->addInstantiation( - f, inst, InferenceId::QUANTIFIERS_INST_FMF_FMC_EXH, true)) + if (ie->addInstantiation(f, + inst, + InferenceId::QUANTIFIERS_INST_FMF_FMC_EXH, + Node::null(), + true)) { Trace("fmc-exh-debug") << " ...success."; addedLemmas++; diff --git a/src/theory/quantifiers/fmf/model_builder.h b/src/theory/quantifiers/fmf/model_builder.h index cfccd4d93..a767af47a 100644 --- a/src/theory/quantifiers/fmf/model_builder.h +++ b/src/theory/quantifiers/fmf/model_builder.h @@ -56,8 +56,6 @@ class QModelBuilder : public TheoryEngineModelBuilder virtual int doExhaustiveInstantiation( FirstOrderModel * fm, Node f, int effort ) { return false; } //whether to construct model virtual bool optUseModel(); - /** exist instantiation ? */ - virtual bool existsInstantiation( Node f, InstMatch& m, bool modEq = true, bool modInst = false ) { return false; } //debug model void debugModel(TheoryModel* m) override; //statistics diff --git a/src/theory/quantifiers/fmf/model_engine.cpp b/src/theory/quantifiers/fmf/model_engine.cpp index 747b0621f..e58f66d0b 100644 --- a/src/theory/quantifiers/fmf/model_engine.cpp +++ b/src/theory/quantifiers/fmf/model_engine.cpp @@ -301,8 +301,11 @@ void ModelEngine::exhaustiveInstantiate( Node f, int effort ){ Debug("fmf-model-eval") << "* Add instantiation " << m << std::endl; triedLemmas++; //add as instantiation - if (inst->addInstantiation( - f, m.d_vals, InferenceId::QUANTIFIERS_INST_FMF_EXH, true)) + if (inst->addInstantiation(f, + m.d_vals, + InferenceId::QUANTIFIERS_INST_FMF_EXH, + Node::null(), + true)) { addedLemmas++; if (d_qstate.isInConflict()) diff --git a/src/theory/quantifiers/instantiate.cpp b/src/theory/quantifiers/instantiate.cpp index 268d1371f..05361eaa1 100644 --- a/src/theory/quantifiers/instantiate.cpp +++ b/src/theory/quantifiers/instantiate.cpp @@ -101,8 +101,8 @@ void Instantiate::addRewriter(InstantiationRewriter* ir) bool Instantiate::addInstantiation(Node q, std::vector<Node>& terms, InferenceId id, + Node pfArg, bool mkRep, - bool modEq, bool doVts) { // For resource-limiting (also does a time check). @@ -229,7 +229,7 @@ bool Instantiate::addInstantiation(Node q, } // record the instantiation - bool recorded = recordInstantiationInternal(q, terms, modEq); + bool recorded = recordInstantiationInternal(q, terms); if (!recorded) { Trace("inst-add-debug") << " --> Already exists (no record)." << std::endl; @@ -250,7 +250,8 @@ bool Instantiate::addInstantiation(Node q, Trace("inst-add-debug") << "Constructing instantiation..." << std::endl; Assert(d_qreg.d_vars[q].size() == terms.size()); // get the instantiation - Node body = getInstantiation(q, d_qreg.d_vars[q], terms, doVts, pfTmp.get()); + Node body = getInstantiation( + q, d_qreg.d_vars[q], terms, id, pfArg, doVts, pfTmp.get()); Node orig_body = body; // now preprocess, storing the trust node for the rewrite TrustNode tpBody = QuantifiersRewriter::preprocess(body, true); @@ -394,12 +395,12 @@ bool Instantiate::addInstantiationExpFail(Node q, std::vector<Node>& terms, std::vector<bool>& failMask, InferenceId id, + Node pfArg, bool mkRep, - bool modEq, bool doVts, bool expFull) { - if (addInstantiation(q, terms, id, mkRep, modEq, doVts)) + if (addInstantiation(q, terms, id, pfArg, mkRep, doVts)) { return true; } @@ -421,7 +422,9 @@ bool Instantiate::addInstantiationExpFail(Node q, subs[vars[i]] = terms[i]; } // get the instantiation body - Node ibody = getInstantiation(q, vars, terms, doVts); + InferenceId idNone = InferenceId::UNKNOWN; + Node nulln; + Node ibody = getInstantiation(q, vars, terms, idNone, nulln, doVts); ibody = Rewriter::rewrite(ibody); for (size_t i = 0; i < tsize; i++) { @@ -450,7 +453,7 @@ bool Instantiate::addInstantiationExpFail(Node q, // check whether the instantiation rewrites to the same thing if (!success) { - Node ibodyc = getInstantiation(q, vars, terms, doVts); + Node ibodyc = getInstantiation(q, vars, terms, idNone, nulln, doVts); ibodyc = Rewriter::rewrite(ibodyc); success = (ibodyc == ibody); Trace("inst-exp-fail") << " rewrite invariant: " << success << std::endl; @@ -521,6 +524,8 @@ bool Instantiate::existsInstantiation(Node q, Node Instantiate::getInstantiation(Node q, std::vector<Node>& vars, std::vector<Node>& terms, + InferenceId id, + Node pfArg, bool doVts, LazyCDProof* pf) { @@ -534,7 +539,19 @@ Node Instantiate::getInstantiation(Node q, // store the proof of the instantiated body, with (open) assumption q if (pf != nullptr) { - pf->addStep(body, PfRule::INSTANTIATE, {q}, terms); + // additional arguments: if the inference id is not unknown, include it, + // followed by the proof argument if non-null. The latter is used e.g. + // to track which trigger caused an instantiation. + std::vector<Node> pfTerms = terms; + if (id != InferenceId::UNKNOWN) + { + pfTerms.push_back(mkInferenceIdNode(id)); + if (!pfArg.isNull()) + { + pfTerms.push_back(pfArg); + } + } + pf->addStep(body, PfRule::INSTANTIATE, {q}, pfTerms); } // run rewriters to rewrite the instantiation in sequence. @@ -564,18 +581,16 @@ Node Instantiate::getInstantiation(Node q, Node Instantiate::getInstantiation(Node q, std::vector<Node>& terms, bool doVts) { Assert(d_qreg.d_vars.find(q) != d_qreg.d_vars.end()); - return getInstantiation(q, d_qreg.d_vars[q], terms, doVts); + return getInstantiation( + q, d_qreg.d_vars[q], terms, InferenceId::UNKNOWN, Node::null(), doVts); } -bool Instantiate::recordInstantiationInternal(Node q, - std::vector<Node>& terms, - bool modEq) +bool Instantiate::recordInstantiationInternal(Node q, std::vector<Node>& terms) { if (options::incrementalSolving()) { Trace("inst-add-debug") - << "Adding into context-dependent inst trie, modEq = " << modEq - << std::endl; + << "Adding into context-dependent inst trie" << std::endl; CDInstMatchTrie* imt; std::map<Node, CDInstMatchTrie*>::iterator it = d_c_inst_match_trie.find(q); if (it != d_c_inst_match_trie.end()) @@ -588,10 +603,10 @@ bool Instantiate::recordInstantiationInternal(Node q, d_c_inst_match_trie[q] = imt; } d_c_inst_match_trie_dom.insert(q); - return imt->addInstMatch(d_qstate, q, terms, modEq); + return imt->addInstMatch(d_qstate, q, terms); } Trace("inst-add-debug") << "Adding into inst trie" << std::endl; - return d_inst_match_trie[q].addInstMatch(d_qstate, q, terms, modEq); + return d_inst_match_trie[q].addInstMatch(d_qstate, q, terms); } bool Instantiate::removeInstantiationInternal(Node q, std::vector<Node>& terms) diff --git a/src/theory/quantifiers/instantiate.h b/src/theory/quantifiers/instantiate.h index eddc7470b..1f380350f 100644 --- a/src/theory/quantifiers/instantiate.h +++ b/src/theory/quantifiers/instantiate.h @@ -139,10 +139,10 @@ class Instantiate : public QuantifiersUtil * @param terms the terms to instantiate with * @param id the identifier of the instantiation lemma sent via the inference * manager + * @param pfArg an additional node to add to the arguments of the INSTANTIATE + * step * @param mkRep whether to take the representatives of the terms in the * range of the substitution m, - * @param modEq whether to check for duplication modulo equality in - * instantiation tries (for performance), * @param doVts whether we must apply virtual term substitution to the * instantiation lemma. * @@ -161,8 +161,8 @@ class Instantiate : public QuantifiersUtil bool addInstantiation(Node q, std::vector<Node>& terms, InferenceId id, + Node pfArg = Node::null(), bool mkRep = false, - bool modEq = false, bool doVts = false); /** * Same as above, but we also compute a vector failMask indicating which @@ -191,8 +191,8 @@ class Instantiate : public QuantifiersUtil std::vector<Node>& terms, std::vector<bool>& failMask, InferenceId id, + Node pfArg = Node::null(), bool mkRep = false, - bool modEq = false, bool doVts = false, bool expFull = true); /** record instantiation @@ -226,6 +226,8 @@ class Instantiate : public QuantifiersUtil Node getInstantiation(Node q, std::vector<Node>& vars, std::vector<Node>& terms, + InferenceId id = InferenceId::UNKNOWN, + Node pfArg = Node::null(), bool doVts = false, LazyCDProof* pf = nullptr); /** get instantiation @@ -293,14 +295,8 @@ class Instantiate : public QuantifiersUtil Statistics d_statistics; private: - /** record instantiation, return true if it was not a duplicate - * - * modEq : whether to check for duplication modulo equality in instantiation - * tries (for performance), - */ - bool recordInstantiationInternal(Node q, - std::vector<Node>& terms, - bool modEq = false); + /** record instantiation, return true if it was not a duplicate */ + bool recordInstantiationInternal(Node q, std::vector<Node>& terms); /** remove instantiation from the cache */ bool removeInstantiationInternal(Node q, std::vector<Node>& terms); /** diff --git a/src/theory/quantifiers/proof_checker.cpp b/src/theory/quantifiers/proof_checker.cpp index f44f2f291..5e02e16a5 100644 --- a/src/theory/quantifiers/proof_checker.cpp +++ b/src/theory/quantifiers/proof_checker.cpp @@ -102,15 +102,16 @@ Node QuantifiersProofRuleChecker::checkInternal( else if (id == PfRule::INSTANTIATE) { Assert(children.size() == 1); + // note we may have more arguments than just the term vector if (children[0].getKind() != FORALL - || args.size() != children[0][0].getNumChildren()) + || args.size() < children[0][0].getNumChildren()) { return Node::null(); } Node body = children[0][1]; std::vector<Node> vars; std::vector<Node> subs; - for (unsigned i = 0, nargs = args.size(); i < nargs; i++) + for (size_t i = 0, nc = children[0][0].getNumChildren(); i < nc; i++) { vars.push_back(children[0][0][i]); subs.push_back(args[i]); diff --git a/src/theory/quantifiers/sygus/cegis_unif.cpp b/src/theory/quantifiers/sygus/cegis_unif.cpp index 28788a5ea..544bdcc5c 100644 --- a/src/theory/quantifiers/sygus/cegis_unif.cpp +++ b/src/theory/quantifiers/sygus/cegis_unif.cpp @@ -19,6 +19,7 @@ #include "expr/sygus_datatype.h" #include "options/quantifiers_options.h" #include "printer/printer.h" +#include "theory/datatypes/sygus_datatype_utils.h" #include "theory/quantifiers/sygus/sygus_unif_rl.h" #include "theory/quantifiers/sygus/synth_conjecture.h" #include "theory/quantifiers/sygus/term_database_sygus.h" @@ -205,8 +206,8 @@ bool CegisUnif::getEnumValues(const std::vector<Node>& enums, if (curr_val < prev_val) { // must have the same size - unsigned prev_size = d_tds->getSygusTermSize(prev_val); - unsigned curr_size = d_tds->getSygusTermSize(curr_val); + unsigned prev_size = datatypes::utils::getSygusTermSize(prev_val); + unsigned curr_size = datatypes::utils::getSygusTermSize(curr_val); Assert(prev_size <= curr_size); if (curr_size == prev_size) { diff --git a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp index 3ae34d82c..f853ac8e8 100644 --- a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp +++ b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp @@ -24,6 +24,7 @@ #include "theory/quantifiers/sygus/term_database_sygus.h" #include <numeric> // for std::iota +#include <sstream> using namespace cvc5::kind; diff --git a/src/theory/quantifiers/sygus/enum_stream_substitution.h b/src/theory/quantifiers/sygus/enum_stream_substitution.h index ea028991b..05c693ace 100644 --- a/src/theory/quantifiers/sygus/enum_stream_substitution.h +++ b/src/theory/quantifiers/sygus/enum_stream_substitution.h @@ -19,12 +19,14 @@ #define CVC5__THEORY__QUANTIFIERS__SYGUS__ENUM_STREAM_SUBSTITUTION_H #include "expr/node.h" -#include "theory/quantifiers/sygus/synth_conjecture.h" +#include "theory/quantifiers/sygus/enum_val_generator.h" namespace cvc5 { namespace theory { namespace quantifiers { +class TermDbSygus; + /** Streamer of different values according to variable permutations * * Generates a new value (modulo rewriting) when queried in which its variables @@ -33,7 +35,7 @@ namespace quantifiers { class EnumStreamPermutation { public: - EnumStreamPermutation(quantifiers::TermDbSygus* tds); + EnumStreamPermutation(TermDbSygus* tds); ~EnumStreamPermutation() {} /** resets utility * @@ -70,7 +72,7 @@ class EnumStreamPermutation private: /** sygus term database of current quantifiers engine */ - quantifiers::TermDbSygus* d_tds; + TermDbSygus* d_tds; /** maps subclass ids to subset of d_vars with that subclass id */ std::map<unsigned, std::vector<Node>> d_var_classes; /** maps variables to subfield types with constructors for @@ -165,7 +167,7 @@ class EnumStreamPermutation class EnumStreamSubstitution { public: - EnumStreamSubstitution(quantifiers::TermDbSygus* tds); + EnumStreamSubstitution(TermDbSygus* tds); ~EnumStreamSubstitution() {} /** initializes utility * @@ -211,7 +213,7 @@ class EnumStreamSubstitution private: /** sygus term database of current quantifiers engine */ - quantifiers::TermDbSygus* d_tds; + TermDbSygus* d_tds; /** type this utility has been initialized for */ TypeNode d_tn; /** current value */ @@ -281,7 +283,7 @@ class EnumStreamSubstitution class EnumStreamConcrete : public EnumValGenerator { public: - EnumStreamConcrete(quantifiers::TermDbSygus* tds) : d_ess(tds) {} + EnumStreamConcrete(TermDbSygus* tds) : d_ess(tds) {} /** initialize this class with enumerator e */ void initialize(Node e) override; /** get that value v was enumerated */ diff --git a/src/theory/quantifiers/sygus/enum_val_generator.h b/src/theory/quantifiers/sygus/enum_val_generator.h new file mode 100644 index 000000000..64c069087 --- /dev/null +++ b/src/theory/quantifiers/sygus/enum_val_generator.h @@ -0,0 +1,62 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Base class for sygus enumerators + */ + +#include "cvc5_private.h" + +#ifndef CVC5__THEORY__QUANTIFIERS__SYGUS__ENUM_VAL_GENERATOR_H +#define CVC5__THEORY__QUANTIFIERS__SYGUS__ENUM_VAL_GENERATOR_H + +#include "expr/node.h" + +namespace cvc5 { +namespace theory { +namespace quantifiers { + +/** + * A base class for generating values for actively-generated enumerators. + * At a high level, the job of this class is to accept a stream of "abstract + * values" a1, ..., an, ..., and generate a (possibly larger) stream of + * "concrete values" c11, ..., c1{m_1}, ..., cn1, ... cn{m_n}, .... + */ +class EnumValGenerator +{ + public: + virtual ~EnumValGenerator() {} + /** initialize this class with enumerator e */ + virtual void initialize(Node e) = 0; + /** Inform this generator that abstract value v was enumerated. */ + virtual void addValue(Node v) = 0; + /** + * Increment this value generator. If this returns false, then we are out of + * values. If this returns true, getCurrent(), if non-null, returns the + * current term. + * + * Notice that increment() may return true and afterwards it may be the case + * getCurrent() is null. We do this so that increment() does not take too + * much time per call, which can be the case for grammars where it is + * difficult to find the next (non-redundant) term. Returning true with + * a null current term gives the caller the chance to interleave other + * reasoning. + */ + virtual bool increment() = 0; + /** Get the current concrete value generated by this class. */ + virtual Node getCurrent() = 0; +}; + +} // namespace quantifiers +} // namespace theory +} // namespace cvc5 + +#endif diff --git a/src/theory/quantifiers/sygus/rcons_type_info.cpp b/src/theory/quantifiers/sygus/rcons_type_info.cpp index 1c62f030d..a1ae53ad1 100644 --- a/src/theory/quantifiers/sygus/rcons_type_info.cpp +++ b/src/theory/quantifiers/sygus/rcons_type_info.cpp @@ -31,7 +31,7 @@ void RConsTypeInfo::initialize(TermDbSygus* tds, NodeManager* nm = NodeManager::currentNM(); SkolemManager* sm = nm->getSkolemManager(); - d_enumerator.reset(new SygusEnumerator(tds, nullptr, s, true)); + d_enumerator.reset(new SygusEnumerator(tds, nullptr, &s, true)); d_enumerator->initialize(sm->mkDummySkolem("sygus_rcons", stn)); d_crd.reset(new CandidateRewriteDatabase(true, false, true, false)); // since initial samples are not always useful for equivalence checks, set diff --git a/src/theory/quantifiers/sygus/sygus_enumerator.cpp b/src/theory/quantifiers/sygus/sygus_enumerator.cpp index 0cf92b373..2dfd41fb4 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator.cpp @@ -20,6 +20,7 @@ #include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "smt/logic_exception.h" +#include "theory/datatypes/sygus_datatype_utils.h" #include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/sygus/synth_engine.h" #include "theory/quantifiers/sygus/type_node_id_trie.h" @@ -33,12 +34,14 @@ namespace quantifiers { SygusEnumerator::SygusEnumerator(TermDbSygus* tds, SynthConjecture* p, - SygusStatistics& s, - bool enumShapes) + SygusStatistics* s, + bool enumShapes, + bool enumAnyConstHoles) : d_tds(tds), d_parent(p), d_stats(s), d_enumShapes(enumShapes), + d_enumAnyConstHoles(enumAnyConstHoles), d_tlEnum(nullptr), d_abortSize(-1) { @@ -54,6 +57,12 @@ void SygusEnumerator::initialize(Node e) d_tlEnum = getMasterEnumForType(d_etype); d_abortSize = options::sygusAbortSize(); + // if we don't have a term database, we don't register symmetry breaking + // lemmas + if (!d_tds) + { + return; + } // Get the statically registered symmetry breaking clauses for e, see if they // can be used for speeding up the enumeration. NodeManager* nm = NodeManager::currentNM(); @@ -141,7 +150,8 @@ Node SygusEnumerator::getCurrent() if (d_sbExcTlCons.find(ret.getOperator()) != d_sbExcTlCons.end()) { Trace("sygus-enum-exc") - << "Exclude (external) : " << d_tds->sygusToBuiltin(ret) << std::endl; + << "Exclude (external) : " << datatypes::utils::sygusToBuiltin(ret) + << std::endl; ret = Node::null(); } } @@ -330,9 +340,12 @@ bool SygusEnumerator::TermCache::addTerm(Node n) Assert(!n.isNull()); if (options::sygusSymBreakDynamic()) { - Node bn = d_tds->sygusToBuiltin(n); - Node bnr = d_tds->getExtRewriter()->extendedRewrite(bn); - ++(d_stats->d_enumTermsRewrite); + Node bn = datatypes::utils::sygusToBuiltin(n); + Node bnr = d_extr.extendedRewrite(bn); + if (d_stats != nullptr) + { + ++(d_stats->d_enumTermsRewrite); + } if (options::sygusRewVerify()) { if (bn != bnr) @@ -358,7 +371,10 @@ bool SygusEnumerator::TermCache::addTerm(Node n) // if we are doing PBE symmetry breaking if (d_eec != nullptr) { - ++(d_stats->d_enumTermsExampleEval); + if (d_stats != nullptr) + { + ++(d_stats->d_enumTermsExampleEval); + } // Is it equivalent under examples? Node bne = d_eec->addSearchVal(d_tn, bnr); if (!bne.isNull()) @@ -374,7 +390,10 @@ bool SygusEnumerator::TermCache::addTerm(Node n) } Trace("sygus-enum-terms") << "tc(" << d_tn << "): term " << bn << std::endl; } - ++(d_stats->d_enumTerms); + if (d_stats != nullptr) + { + ++(d_stats->d_enumTerms); + } d_terms.push_back(n); return true; } @@ -474,8 +493,8 @@ Node SygusEnumerator::TermEnumSlave::getCurrent() Node curr = tc.getTerm(d_index); Trace("sygus-enum-debug2") << "slave(" << d_tn - << "): current : " << d_se->d_tds->sygusToBuiltin(curr) - << ", sizes = " << d_se->d_tds->getSygusTermSize(curr) << " " + << "): current : " << datatypes::utils::sygusToBuiltin(curr) + << ", sizes = " << datatypes::utils::getSygusTermSize(curr) << " " << getCurrentSize() << std::endl; Trace("sygus-enum-debug2") << "slave(" << d_tn << "): indices : " << d_hasIndexNextEnd << " " @@ -560,7 +579,7 @@ void SygusEnumerator::initializeTermCache(TypeNode tn) { eec = d_parent->getExampleEvalCache(d_enum); } - d_tcache[tn].initialize(&d_stats, d_enum, tn, d_tds, eec); + d_tcache[tn].initialize(d_stats, d_enum, tn, d_tds, eec); } SygusEnumerator::TermEnum* SygusEnumerator::getMasterEnumForType(TypeNode tn) @@ -578,7 +597,7 @@ SygusEnumerator::TermEnum* SygusEnumerator::getMasterEnumForType(TypeNode tn) AlwaysAssert(ret); return &d_masterEnum[tn]; } - if (options::sygusRepairConst()) + if (d_enumAnyConstHoles) { std::map<TypeNode, TermEnumMasterFv>::iterator it = d_masterEnumFv.find(tn); if (it != d_masterEnumFv.end()) @@ -720,6 +739,7 @@ bool SygusEnumerator::TermEnumMaster::incrementInternal() // If we are enumerating shapes, the first enumerated term is a free variable. if (d_enumShapes && !d_enumShapesInit) { + Assert(d_tds != nullptr); Node fv = d_tds->getFreeVar(d_tn, 0); d_enumShapesInit = true; d_currTermSet = true; @@ -1083,6 +1103,7 @@ void SygusEnumerator::TermEnumMaster::childrenToShape( Node SygusEnumerator::TermEnumMaster::convertShape( Node n, std::map<TypeNode, int>& vcounter) { + Assert(d_tds != nullptr); NodeManager* nm = NodeManager::currentNM(); std::unordered_map<TNode, Node> visited; std::unordered_map<TNode, Node>::iterator it; @@ -1195,6 +1216,7 @@ bool SygusEnumerator::TermEnumMasterFv::initialize(SygusEnumerator* se, Node SygusEnumerator::TermEnumMasterFv::getCurrent() { + Assert(d_se->d_tds != nullptr); Node ret = d_se->d_tds->getFreeVar(d_tn, d_currSize); Trace("sygus-enum-debug2") << "master_fv(" << d_tn << "): mk " << ret << std::endl; diff --git a/src/theory/quantifiers/sygus/sygus_enumerator.h b/src/theory/quantifiers/sygus/sygus_enumerator.h index 355108957..39e58d5f3 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator.h +++ b/src/theory/quantifiers/sygus/sygus_enumerator.h @@ -22,6 +22,7 @@ #include <unordered_set> #include "expr/node.h" #include "expr/type_node.h" +#include "theory/quantifiers/sygus/enum_val_generator.h" #include "theory/quantifiers/sygus/synth_conjecture.h" #include "theory/quantifiers/sygus/term_database_sygus.h" @@ -56,10 +57,23 @@ class SygusPbe; class SygusEnumerator : public EnumValGenerator { public: - SygusEnumerator(TermDbSygus* tds, - SynthConjecture* p, - SygusStatistics& s, - bool enumShapes = false); + /** + * @param tds Pointer to the term database, required if enumShapes or + * enumAnyConstHoles is true, or if we want to include symmetry breaking from + * lemmas stored in the sygus term database, + * @param p Pointer to the conjecture, required if we wish to do + * conjecture-specific symmetry breaking + * @param s Pointer to the statistics + * @param enumShapes If true, this enumerator will generate terms having any + * number of free variables + * @param enumAnyConstHoles If true, this enumerator will generate terms where + * free variables are the arguments to any-constant constructors. + */ + SygusEnumerator(TermDbSygus* tds = nullptr, + SynthConjecture* p = nullptr, + SygusStatistics* s = nullptr, + bool enumShapes = false, + bool enumAnyConstHoles = false); ~SygusEnumerator() {} /** initialize this class with enumerator e */ void initialize(Node e) override; @@ -77,10 +91,13 @@ class SygusEnumerator : public EnumValGenerator TermDbSygus* d_tds; /** pointer to the synth conjecture that owns this enumerator */ SynthConjecture* d_parent; - /** reference to the statistics of parent */ - SygusStatistics& d_stats; + /** pointer to the statistics */ + SygusStatistics* d_stats; /** Whether we are enumerating shapes */ bool d_enumShapes; + /** Whether we are enumerating free variables as arguments to any-constant + * constructors */ + bool d_enumAnyConstHoles; /** Term cache * * This stores a list of terms for a given sygus type. The key features of @@ -171,6 +188,8 @@ class SygusEnumerator : public EnumValGenerator TypeNode d_tn; /** pointer to term database sygus */ TermDbSygus* d_tds; + /** extended rewriter */ + ExtendedRewriter d_extr; /** * Pointer to the example evaluation cache utility (used for symmetry * breaking). diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_basic.h b/src/theory/quantifiers/sygus/sygus_enumerator_basic.h index bae6f6327..42bce471d 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_basic.h +++ b/src/theory/quantifiers/sygus/sygus_enumerator_basic.h @@ -22,7 +22,7 @@ #include <unordered_set> #include "expr/node.h" #include "expr/type_node.h" -#include "theory/quantifiers/sygus/synth_conjecture.h" +#include "theory/quantifiers/sygus/enum_val_generator.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/type_enumerator.h" diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp b/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp new file mode 100644 index 000000000..7b3236832 --- /dev/null +++ b/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp @@ -0,0 +1,107 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds, Mathias Preiner + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * sygus_enumerator + */ + +#include "theory/quantifiers/sygus/sygus_enumerator_callback.h" + +#include "theory/datatypes/sygus_datatype_utils.h" +#include "theory/quantifiers/sygus/example_eval_cache.h" +#include "theory/quantifiers/sygus/sygus_stats.h" +#include "theory/quantifiers/sygus_sampler.h" + +namespace cvc5 { +namespace theory { +namespace quantifiers { + +SygusEnumeratorCallback::SygusEnumeratorCallback(Node e, SygusStatistics* s) + : d_enum(e), d_stats(s) +{ + d_tn = e.getType(); +} + +bool SygusEnumeratorCallback::addTerm(Node n, std::unordered_set<Node>& bterms) +{ + Node bn = datatypes::utils::sygusToBuiltin(n); + Node bnr = d_extr.extendedRewrite(bn); + if (d_stats != nullptr) + { + ++(d_stats->d_enumTermsRewrite); + } + // call the solver-specific notify term + notifyTermInternal(n, bn, bnr); + // check whether we should keep the term, which is based on the callback, + // and the builtin terms + // First, must be unique up to rewriting + if (bterms.find(bnr) != bterms.end()) + { + Trace("sygus-enum-exc") << "Exclude: " << bn << std::endl; + return false; + } + // insert to builtin term cache, regardless of whether it is redundant + // based on the callback + bterms.insert(bnr); + // callback-specific add term + if (!addTermInternal(n, bn, bnr)) + { + Trace("sygus-enum-exc") + << "Exclude: " << bn << " due to callback" << std::endl; + return false; + } + Trace("sygus-enum-terms") << "tc(" << d_tn << "): term " << bn << std::endl; + return true; +} + +SygusEnumeratorCallbackDefault::SygusEnumeratorCallbackDefault( + Node e, SygusStatistics* s, ExampleEvalCache* eec, SygusSampler* ssrv) + : SygusEnumeratorCallback(e, s), d_eec(eec), d_samplerRrV(ssrv) +{ +} +void SygusEnumeratorCallbackDefault::notifyTermInternal(Node n, + Node bn, + Node bnr) +{ + if (d_samplerRrV != nullptr) + { + d_samplerRrV->checkEquivalent(bn, bnr); + } +} + +bool SygusEnumeratorCallbackDefault::addTermInternal(Node n, Node bn, Node bnr) +{ + // if we are doing PBE symmetry breaking + if (d_eec != nullptr) + { + if (d_stats != nullptr) + { + ++(d_stats->d_enumTermsExampleEval); + } + // Is it equivalent under examples? + Node bne = d_eec->addSearchVal(d_tn, bnr); + if (!bne.isNull()) + { + if (bnr != bne) + { + Trace("sygus-enum-exc") + << "Exclude (by examples): " << bn << ", since we already have " + << bne << std::endl; + return false; + } + } + } + return true; +} + +} // namespace quantifiers +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_callback.h b/src/theory/quantifiers/sygus/sygus_enumerator_callback.h new file mode 100644 index 000000000..545440eef --- /dev/null +++ b/src/theory/quantifiers/sygus/sygus_enumerator_callback.h @@ -0,0 +1,110 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds, Mathias Preiner + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * sygus_enumerator + */ + +#include "cvc5_private.h" + +#ifndef CVC5__THEORY__QUANTIFIERS__SYGUS__SYGUS_ENUMERATOR_CALLBACK_H +#define CVC5__THEORY__QUANTIFIERS__SYGUS__SYGUS_ENUMERATOR_CALLBACK_H + +#include <unordered_set> + +#include "expr/node.h" +#include "theory/quantifiers/extended_rewrite.h" + +namespace cvc5 { +namespace theory { +namespace quantifiers { + +class ExampleEvalCache; +class SygusStatistics; +class SygusSampler; + +/** + * Base class for callbacks in the fast enumerator. This allows a user to + * provide custom criteria for whether or not enumerated values should be + * considered. + */ +class SygusEnumeratorCallback +{ + public: + SygusEnumeratorCallback(Node e, SygusStatistics* s = nullptr); + virtual ~SygusEnumeratorCallback() {} + /** + * Add term, return true if the term should be considered in the enumeration. + * Notice that returning false indicates that n should not be considered as a + * subterm of any other term in the enumeration. + * + * @param n The SyGuS term + * @param bterms The (rewritten, builtin) terms we have already enumerated + * @return true if n should be considered in the enumeration. + */ + virtual bool addTerm(Node n, std::unordered_set<Node>& bterms) = 0; + + protected: + /** + * Callback-specific notification of the above + * + * @param n The SyGuS term + * @param bn The builtin version of the enumerated term + * @param bnr The (extended) rewritten form of bn + */ + virtual void notifyTermInternal(Node n, Node bn, Node bnr) = 0; + /** + * Callback-specific add term + * + * @param n The SyGuS term + * @param bn The builtin version of the enumerated term + * @param bnr The (extended) rewritten form of bn + * @return true if the term should be considered in the enumeration. + */ + virtual bool addTermInternal(Node n, Node bn, Node bnr) = 0; + /** The enumerator */ + Node d_enum; + /** The type of enum */ + TypeNode d_tn; + /** extended rewriter */ + ExtendedRewriter d_extr; + /** pointer to the statistics */ + SygusStatistics* d_stats; +}; + +class SygusEnumeratorCallbackDefault : public SygusEnumeratorCallback +{ + public: + SygusEnumeratorCallbackDefault(Node e, + SygusStatistics* s = nullptr, + ExampleEvalCache* eec = nullptr, + SygusSampler* ssrv = nullptr); + virtual ~SygusEnumeratorCallbackDefault() {} + + protected: + /** Notify that bn / bnr is an enumerated builtin, rewritten form of a term */ + void notifyTermInternal(Node n, Node bn, Node bnr) override; + /** Add term, return true if n should be considered in the enumeration */ + bool addTermInternal(Node n, Node bn, Node bnr) override; + /** + * Pointer to the example evaluation cache utility (used for symmetry + * breaking). + */ + ExampleEvalCache* d_eec; + /** sampler (for --sygus-rr-verify) */ + SygusSampler* d_samplerRrV; +}; + +} // namespace quantifiers +} // namespace theory +} // namespace cvc5 + +#endif /* CVC5__THEORY__QUANTIFIERS__SYGUS__SYGUS_ENUMERATOR_CALLBACK_H */ diff --git a/src/theory/quantifiers/sygus/sygus_explain.cpp b/src/theory/quantifiers/sygus/sygus_explain.cpp index 395f16beb..23c315f42 100644 --- a/src/theory/quantifiers/sygus/sygus_explain.cpp +++ b/src/theory/quantifiers/sygus/sygus_explain.cpp @@ -18,6 +18,7 @@ #include "expr/dtype.h" #include "expr/dtype_cons.h" #include "smt/logic_exception.h" +#include "theory/datatypes/sygus_datatype_utils.h" #include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/sygus/sygus_invariance.h" #include "theory/quantifiers/sygus/term_database_sygus.h" @@ -220,7 +221,7 @@ void SygusExplain::getExplanationFor(TermRecBuild& trb, // we are tracking term size if positive if (sz >= 0) { - int s = d_tdb->getSygusTermSize(vn[i]); + int s = datatypes::utils::getSygusTermSize(vn[i]); sz = sz - s; } } diff --git a/src/theory/quantifiers/sygus/sygus_pbe.cpp b/src/theory/quantifiers/sygus/sygus_pbe.cpp index 86d0bbc8e..892ee6dd4 100644 --- a/src/theory/quantifiers/sygus/sygus_pbe.cpp +++ b/src/theory/quantifiers/sygus/sygus_pbe.cpp @@ -15,6 +15,7 @@ #include "theory/quantifiers/sygus/sygus_pbe.h" #include "options/quantifiers_options.h" +#include "theory/datatypes/sygus_datatype_utils.h" #include "theory/quantifiers/sygus/example_infer.h" #include "theory/quantifiers/sygus/sygus_unif_io.h" #include "theory/quantifiers/sygus/synth_conjecture.h" @@ -180,7 +181,7 @@ bool SygusPbe::constructCandidates(const std::vector<Node>& enums, Trace("sygus-pbe-enum") << std::endl; if (!enum_values[i].isNull()) { - unsigned sz = d_tds->getSygusTermSize(enum_values[i]); + unsigned sz = datatypes::utils::getSygusTermSize(enum_values[i]); szs.push_back(sz); if (i == 0 || sz < min_term_size) { diff --git a/src/theory/quantifiers/sygus/sygus_unif.cpp b/src/theory/quantifiers/sygus/sygus_unif.cpp index 16ca1f4e6..00370ffa2 100644 --- a/src/theory/quantifiers/sygus/sygus_unif.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif.cpp @@ -15,6 +15,7 @@ #include "theory/quantifiers/sygus/sygus_unif.h" +#include "theory/datatypes/sygus_datatype_utils.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" #include "util/random.h" @@ -52,7 +53,7 @@ Node SygusUnif::getMinimalTerm(const std::vector<Node>& terms) unsigned ssize = 0; if (it == d_termToSize.end()) { - ssize = d_tds->getSygusTermSize(n); + ssize = datatypes::utils::getSygusTermSize(n); d_termToSize[n] = ssize; } else diff --git a/src/theory/quantifiers/sygus/sygus_unif_io.cpp b/src/theory/quantifiers/sygus/sygus_unif_io.cpp index 8c8f5ccd4..8207a07f2 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_io.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_io.cpp @@ -16,6 +16,7 @@ #include "theory/quantifiers/sygus/sygus_unif_io.h" #include "options/quantifiers_options.h" +#include "theory/datatypes/sygus_datatype_utils.h" #include "theory/evaluator.h" #include "theory/quantifiers/sygus/example_infer.h" #include "theory/quantifiers/sygus/synth_conjecture.h" @@ -835,7 +836,8 @@ Node SygusUnifIo::constructSolutionNode(std::vector<Node>& lemmas) if (!vcc.isNull() && (d_solution.isNull() || (!d_solution.isNull() - && d_tds->getSygusTermSize(vcc) < d_sol_term_size))) + && datatypes::utils::getSygusTermSize(vcc) + < d_sol_term_size))) { if (Trace.isOn("sygus-pbe")) { @@ -846,7 +848,7 @@ Node SygusUnifIo::constructSolutionNode(std::vector<Node>& lemmas) } d_solution = vcc; newSolution = vcc; - d_sol_term_size = d_tds->getSygusTermSize(vcc); + d_sol_term_size = datatypes::utils::getSygusTermSize(vcc); Trace("sygus-pbe-sol") << "PBE solution size: " << d_sol_term_size << std::endl; // We've determined its feasible, now, enable information gain and diff --git a/src/theory/quantifiers/sygus/synth_conjecture.cpp b/src/theory/quantifiers/sygus/synth_conjecture.cpp index 1ddc2fa22..73bd6b8a4 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.cpp +++ b/src/theory/quantifiers/sygus/synth_conjecture.cpp @@ -827,7 +827,10 @@ Node SynthConjecture::getEnumeratedValue(Node e, bool& activeIncomplete) == options::SygusActiveGenMode::ENUM || options::sygusActiveGenMode() == options::SygusActiveGenMode::AUTO); - d_evg[e].reset(new SygusEnumerator(d_tds, this, d_stats)); + // if sygus repair const is enabled, we enumerate terms with free + // variables as arguments to any-constant constructors + d_evg[e].reset(new SygusEnumerator( + d_tds, this, &d_stats, false, options::sygusRepairConst())); } } Trace("sygus-active-gen") diff --git a/src/theory/quantifiers/sygus/synth_conjecture.h b/src/theory/quantifiers/sygus/synth_conjecture.h index e6645ddf2..04999da0d 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.h +++ b/src/theory/quantifiers/sygus/synth_conjecture.h @@ -26,6 +26,7 @@ #include "theory/quantifiers/sygus/cegis.h" #include "theory/quantifiers/sygus/cegis_core_connective.h" #include "theory/quantifiers/sygus/cegis_unif.h" +#include "theory/quantifiers/sygus/enum_val_generator.h" #include "theory/quantifiers/sygus/example_eval_cache.h" #include "theory/quantifiers/sygus/example_infer.h" #include "theory/quantifiers/sygus/sygus_process_conj.h" @@ -42,37 +43,6 @@ class CegGrammarConstructor; class SygusPbe; class SygusStatistics; -/** - * A base class for generating values for actively-generated enumerators. - * At a high level, the job of this class is to accept a stream of "abstract - * values" a1, ..., an, ..., and generate a (possibly larger) stream of - * "concrete values" c11, ..., c1{m_1}, ..., cn1, ... cn{m_n}, .... - */ -class EnumValGenerator -{ - public: - virtual ~EnumValGenerator() {} - /** initialize this class with enumerator e */ - virtual void initialize(Node e) = 0; - /** Inform this generator that abstract value v was enumerated. */ - virtual void addValue(Node v) = 0; - /** - * Increment this value generator. If this returns false, then we are out of - * values. If this returns true, getCurrent(), if non-null, returns the - * current term. - * - * Notice that increment() may return true and afterwards it may be the case - * getCurrent() is null. We do this so that increment() does not take too - * much time per call, which can be the case for grammars where it is - * difficult to find the next (non-redundant) term. Returning true with - * a null current term gives the caller the chance to interleave other - * reasoning. - */ - virtual bool increment() = 0; - /** Get the current concrete value generated by this class. */ - virtual Node getCurrent() = 0; -}; - /** a synthesis conjecture * This class implements approaches for a synthesis conjecture, given by data * member d_quant. diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index 826563401..3b0ea3312 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -359,23 +359,6 @@ Node TermDbSygus::sygusToBuiltin(Node n, TypeNode tn) return ret; } -unsigned TermDbSygus::getSygusTermSize( Node n ){ - if (n.getKind() != APPLY_CONSTRUCTOR) - { - return 0; - } - unsigned sum = 0; - for (unsigned i = 0; i < n.getNumChildren(); i++) - { - sum += getSygusTermSize(n[i]); - } - const DType& dt = datatypes::utils::datatypeOf(n.getOperator()); - int cindex = datatypes::utils::indexOf(n.getOperator()); - Assert(cindex >= 0 && cindex < (int)dt.getNumConstructors()); - unsigned weight = dt[cindex].getWeight(); - return weight + sum; -} - bool TermDbSygus::registerSygusType(TypeNode tn) { std::map<TypeNode, bool>::iterator it = d_registerStatus.find(tn); diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index e0a812069..80411b258 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -456,7 +456,6 @@ class TermDbSygus { Node getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs ); Node getNormalized(TypeNode t, Node prog); - unsigned getSygusTermSize( Node n ); /** involves div-by-zero */ bool involvesDivByZero( Node n ); /** get anchor */ diff --git a/src/theory/quantifiers/term_database.cpp b/src/theory/quantifiers/term_database.cpp index bedab16f1..523b84e65 100644 --- a/src/theory/quantifiers/term_database.cpp +++ b/src/theory/quantifiers/term_database.cpp @@ -480,9 +480,8 @@ void TermDb::addTermHo(Node n) Node psk; if (itp == d_ho_fun_op_purify.end()) { - psk = sm->mkDummySkolem("pfun", - curr.getType(), - "purify for function operator term indexing"); + psk = sm->mkPurifySkolem( + curr, "pfun", "purify for function operator term indexing"); d_ho_fun_op_purify[curr] = psk; // we do not add it to d_ops since it is an internal operator } @@ -1034,7 +1033,10 @@ bool TermDb::reset( Theory::Effort effort ){ eq = itpe->second; } Trace("quant-ho") << "- assert purify equality : " << eq << std::endl; - ee->assertEquality(eq, true, eq); + // Note that ee may be the central equality engine, in which case this + // equality is explained trivially with "true", since both sides of + // eq are HOL and FOL encodings of the same thing. + ee->assertEquality(eq, true, d_true); if (!ee->consistent()) { // In some rare cases, purification functions (in the domain of |