diff options
Diffstat (limited to 'src/theory/quantifiers/sygus/cegis.cpp')
-rw-r--r-- | src/theory/quantifiers/sygus/cegis.cpp | 322 |
1 files changed, 210 insertions, 112 deletions
diff --git a/src/theory/quantifiers/sygus/cegis.cpp b/src/theory/quantifiers/sygus/cegis.cpp index e8d29835a..47053080a 100644 --- a/src/theory/quantifiers/sygus/cegis.cpp +++ b/src/theory/quantifiers/sygus/cegis.cpp @@ -51,10 +51,9 @@ bool Cegis::initialize(Node n, } // initialize an enumerator for each candidate - TermDbSygus* tds = d_qe->getTermDatabaseSygus(); for (unsigned i = 0; i < candidates.size(); i++) { - tds->registerEnumerator(candidates[i], candidates[i], d_parent); + d_tds->registerEnumerator(candidates[i], candidates[i], d_parent); } return true; } @@ -68,10 +67,6 @@ void Cegis::getTermList(const std::vector<Node>& candidates, bool Cegis::addEvalLemmas(const std::vector<Node>& candidates, const std::vector<Node>& candidate_values) { - if (!options::sygusDirectEval()) - { - return false; - } NodeManager* nm = NodeManager::currentNM(); bool addedEvalLemmas = false; if (options::sygusCRefEval()) @@ -96,18 +91,21 @@ bool Cegis::addEvalLemmas(const std::vector<Node>& candidates, add the lemmas below as well, in parallel. */ } } + if (!options::sygusDirectEval()) + { + return addedEvalLemmas; + } Trace("cegqi-engine") << " *** Do direct evaluation..." << std::endl; std::vector<Node> eager_terms, eager_vals, eager_exps; - TermDbSygus* tds = d_qe->getTermDatabaseSygus(); for (unsigned i = 0, size = candidates.size(); i < size; ++i) { Trace("cegqi-debug") << " register " << candidates[i] << " -> " << candidate_values[i] << std::endl; - tds->registerModelValue(candidates[i], - candidate_values[i], - eager_terms, - eager_vals, - eager_exps); + d_tds->registerModelValue(candidates[i], + candidate_values[i], + eager_terms, + eager_vals, + eager_exps); } Trace("cegqi-debug") << "...produced " << eager_terms.size() << " eager evaluation lemmas.\n"; @@ -148,11 +146,136 @@ bool Cegis::constructCandidates(const std::vector<Node>& enums, return true; } +void Cegis::addRefinementLemma(Node lem) +{ + d_refinement_lemmas.push_back(lem); + // apply existing substitution + Node slem = lem; + if (!d_rl_eval_hds.empty()) + { + slem = lem.substitute(d_rl_eval_hds.begin(), + d_rl_eval_hds.end(), + d_rl_vals.begin(), + d_rl_vals.end()); + } + // rewrite with extended rewriter + slem = d_tds->getExtRewriter()->extendedRewrite(slem); + std::vector<Node> waiting; + waiting.push_back(lem); + unsigned wcounter = 0; + // while we are not done adding lemmas + while (wcounter < waiting.size()) + { + // add the conjunct, possibly propagating + addRefinementLemmaConjunct(wcounter, waiting); + wcounter++; + } +} + +void Cegis::addRefinementLemmaConjunct(unsigned wcounter, + std::vector<Node>& waiting) +{ + Node lem = waiting[wcounter]; + lem = Rewriter::rewrite(lem); + // apply substitution and rewrite if applicable + if (lem.isConst()) + { + if (!lem.getConst<bool>()) + { + // conjecture is infeasible + } + else + { + return; + } + } + // break into conjunctions + if (lem.getKind() == AND) + { + for (const Node& lc : lem) + { + waiting.push_back(lc); + } + return; + } + // does this correspond to a substitution? + NodeManager* nm = NodeManager::currentNM(); + TNode term; + TNode val; + if (lem.getKind() == EQUAL) + { + for (unsigned i = 0; i < 2; i++) + { + if (lem[i].isConst() && d_tds->isEvaluationPoint(lem[1 - i])) + { + term = lem[1 - i]; + val = lem[i]; + break; + } + } + } + else + { + term = lem.getKind() == NOT ? lem[0] : lem; + // predicate case: the conjunct is a (negated) evaluation point + if (d_tds->isEvaluationPoint(term)) + { + val = nm->mkConst(lem.getKind() != NOT); + } + } + if (!val.isNull()) + { + if (d_refinement_lemma_unit.find(lem) != d_refinement_lemma_unit.end()) + { + // already added + return; + } + Trace("cegis-rl") << "* cegis-rl: propagate: " << term << " -> " << val + << std::endl; + d_rl_eval_hds.push_back(term); + d_rl_vals.push_back(val); + d_refinement_lemma_unit.insert(lem); + // apply to waiting lemmas beyond this one + for (unsigned i = wcounter + 1, size = waiting.size(); i < size; i++) + { + waiting[i] = waiting[i].substitute(term, val); + } + // apply to all existing refinement lemmas + std::vector<Node> to_rem; + for (const Node& rl : d_refinement_lemma_conj) + { + Node srl = rl.substitute(term, val); + if (srl != rl) + { + Trace("cegis-rl") << "* cegis-rl: replace: " << rl << " -> " << srl + << std::endl; + waiting.push_back(srl); + to_rem.push_back(rl); + } + } + for (const Node& tr : to_rem) + { + d_refinement_lemma_conj.erase(tr); + } + } + else + { + if (Trace.isOn("cegis-rl")) + { + if (d_refinement_lemma_conj.find(lem) == d_refinement_lemma_conj.end()) + { + Trace("cegis-rl") << "cegis-rl: add: " << lem << std::endl; + } + } + d_refinement_lemma_conj.insert(lem); + } +} + void Cegis::registerRefinementLemma(const std::vector<Node>& vars, Node lem, std::vector<Node>& lems) { - d_refinement_lemmas.push_back(lem); + addRefinementLemma(lem); // Make the refinement lemma and add it to lems. // This lemma is guarded by the parent's guard, which has the semantics // "this conjecture has a solution", hence this lemma states: @@ -168,118 +291,93 @@ void Cegis::getRefinementEvalLemmas(const std::vector<Node>& vs, std::vector<Node>& lems) { Trace("sygus-cref-eval") << "Cref eval : conjecture has " - << getNumRefinementLemmas() << " refinement lemmas." + << d_refinement_lemma_unit.size() << " unit and " + << d_refinement_lemma_conj.size() + << " non-unit refinement lemma conjunctions." << std::endl; - unsigned nlemmas = getNumRefinementLemmas(); - if (nlemmas > 0 || options::cegisSample() != CEGIS_SAMPLE_NONE) - { - Assert(vs.size() == ms.size()); + Assert(vs.size() == ms.size()); - TermDbSygus* tds = d_qe->getTermDatabaseSygus(); - NodeManager* nm = NodeManager::currentNM(); + NodeManager* nm = NodeManager::currentNM(); - Node nfalse = nm->mkConst(false); - Node neg_guard = d_parent->getGuard().negate(); - for (unsigned i = 0; i <= nlemmas; i++) + Node nfalse = nm->mkConst(false); + Node neg_guard = d_parent->getGuard().negate(); + for (unsigned r = 0; r < 2; r++) + { + std::unordered_set<Node, NodeHashFunction>& rlemmas = + r == 0 ? d_refinement_lemma_unit : d_refinement_lemma_conj; + for (const Node& lem : rlemmas) { - if (i == nlemmas) - { - bool addedSample = false; - // find a new one by sampling, if applicable - if (options::cegisSample() != CEGIS_SAMPLE_NONE) - { - addedSample = sampleAddRefinementLemma(vs, ms, lems); - } - if (!addedSample) - { - return; - } - } - Node lem; + Assert(!lem.isNull()); std::map<Node, Node> visited; std::map<Node, std::vector<Node> > exp; - lem = getRefinementLemma(i); - if (!lem.isNull()) + EvalSygusInvarianceTest vsit; + Trace("sygus-cref-eval") << "Check refinement lemma conjunct " << lem + << " against current model." << std::endl; + Trace("sygus-cref-eval2") << "Check refinement lemma conjunct " << lem + << " against current model." << std::endl; + Node cre_lem; + Node lemcs = lem.substitute(vs.begin(), vs.end(), ms.begin(), ms.end()); + Trace("sygus-cref-eval2") + << "...under substitution it is : " << lemcs << std::endl; + Node lemcsu = vsit.doEvaluateWithUnfolding(d_tds, lemcs); + Trace("sygus-cref-eval2") + << "...after unfolding is : " << lemcsu << std::endl; + if (lemcsu.isConst() && !lemcsu.getConst<bool>()) { - std::vector<Node> lem_conj; - // break into conjunctions - if (lem.getKind() == kind::AND) + std::vector<Node> msu; + std::vector<Node> mexp; + msu.insert(msu.end(), ms.begin(), ms.end()); + std::map<TypeNode, int> var_count; + for (unsigned k = 0; k < vs.size(); k++) { - for (unsigned i = 0; i < lem.getNumChildren(); i++) - { - lem_conj.push_back(lem[i]); - } + vsit.setUpdatedTerm(msu[k]); + msu[k] = vs[k]; + // substitute for everything except this + Node sconj = + lem.substitute(vs.begin(), vs.end(), msu.begin(), msu.end()); + vsit.init(sconj, vs[k], nfalse); + // get minimal explanation for this + Node ut = vsit.getUpdatedTerm(); + Trace("sygus-cref-eval2-debug") + << " compute min explain of : " << vs[k] << " = " << ut + << std::endl; + d_tds->getExplain()->getExplanationFor( + vs[k], ut, mexp, vsit, var_count, false); + Trace("sygus-cref-eval2-debug") << "exp now: " << mexp << std::endl; + msu[k] = vsit.getUpdatedTerm(); + Trace("sygus-cref-eval2-debug") + << "updated term : " << msu[k] << std::endl; } - else + if (!mexp.empty()) { - lem_conj.push_back(lem); + Node en = mexp.size() == 1 ? mexp[0] : nm->mkNode(kind::AND, mexp); + cre_lem = nm->mkNode(kind::OR, en.negate(), neg_guard); } - EvalSygusInvarianceTest vsit; - for (unsigned j = 0; j < lem_conj.size(); j++) + else { - Node lemc = lem_conj[j]; - Trace("sygus-cref-eval") << "Check refinement lemma conjunct " << lemc - << " against current model." << std::endl; - Trace("sygus-cref-eval2") << "Check refinement lemma conjunct " - << lemc << " against current model." - << std::endl; - Node cre_lem; - Node lemcs = - lemc.substitute(vs.begin(), vs.end(), ms.begin(), ms.end()); - Trace("sygus-cref-eval2") << "...under substitution it is : " << lemcs - << std::endl; - Node lemcsu = vsit.doEvaluateWithUnfolding(tds, lemcs); - Trace("sygus-cref-eval2") << "...after unfolding is : " << lemcsu - << std::endl; - if (lemcsu.isConst() && !lemcsu.getConst<bool>()) - { - std::vector<Node> msu; - std::vector<Node> mexp; - msu.insert(msu.end(), ms.begin(), ms.end()); - std::map<TypeNode, int> var_count; - for (unsigned k = 0; k < vs.size(); k++) - { - vsit.setUpdatedTerm(msu[k]); - msu[k] = vs[k]; - // substitute for everything except this - Node sconj = - lemc.substitute(vs.begin(), vs.end(), msu.begin(), msu.end()); - vsit.init(sconj, vs[k], nfalse); - // get minimal explanation for this - Node ut = vsit.getUpdatedTerm(); - Trace("sygus-cref-eval2-debug") - << " compute min explain of : " << vs[k] << " = " << ut - << std::endl; - tds->getExplain()->getExplanationFor( - vs[k], ut, mexp, vsit, var_count, false); - Trace("sygus-cref-eval2-debug") - << "exp now: " << mexp << std::endl; - msu[k] = vsit.getUpdatedTerm(); - Trace("sygus-cref-eval2-debug") - << "updated term : " << msu[k] << std::endl; - } - if (!mexp.empty()) - { - Node en = - mexp.size() == 1 ? mexp[0] : nm->mkNode(kind::AND, mexp); - cre_lem = nm->mkNode(kind::OR, en.negate(), neg_guard); - } - else - { - cre_lem = neg_guard; - } - } - if (!cre_lem.isNull()) - { - if (std::find(lems.begin(), lems.end(), cre_lem) == lems.end()) - { - Trace("sygus-cref-eval") << "...produced lemma : " << cre_lem - << std::endl; - lems.push_back(cre_lem); - } - } + cre_lem = neg_guard; } } + if (!cre_lem.isNull() + && std::find(lems.begin(), lems.end(), cre_lem) == lems.end()) + { + Trace("sygus-cref-eval") + << "...produced lemma : " << cre_lem << std::endl; + lems.push_back(cre_lem); + } + } + if (!lems.empty()) + { + break; + } + } + // if we didn't add a lemma, trying sampling to add one + if (options::cegisSample() != CEGIS_SAMPLE_NONE && lems.empty()) + { + if (sampleAddRefinementLemma(vs, ms, lems)) + { + // restart (should be guaranteed to add evaluation lemmas + getRefinementEvalLemmas(vs, ms, lems); } } } @@ -344,7 +442,7 @@ bool Cegis::sampleAddRefinementLemma(const std::vector<Node>& candidates, Trace("cegis-sample") << std::endl; } Trace("cegqi-engine") << " *** Refine by sampling" << std::endl; - d_refinement_lemmas.push_back(rlem); + addRefinementLemma(rlem); // if trust, we are not interested in sending out refinement lemmas if (options::cegisSample() != CEGIS_SAMPLE_TRUST) { |