summaryrefslogtreecommitdiff
path: root/src/theory/quantifiers/sygus/cegis.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/quantifiers/sygus/cegis.cpp')
-rw-r--r--src/theory/quantifiers/sygus/cegis.cpp322
1 files changed, 210 insertions, 112 deletions
diff --git a/src/theory/quantifiers/sygus/cegis.cpp b/src/theory/quantifiers/sygus/cegis.cpp
index e8d29835a..47053080a 100644
--- a/src/theory/quantifiers/sygus/cegis.cpp
+++ b/src/theory/quantifiers/sygus/cegis.cpp
@@ -51,10 +51,9 @@ bool Cegis::initialize(Node n,
}
// initialize an enumerator for each candidate
- TermDbSygus* tds = d_qe->getTermDatabaseSygus();
for (unsigned i = 0; i < candidates.size(); i++)
{
- tds->registerEnumerator(candidates[i], candidates[i], d_parent);
+ d_tds->registerEnumerator(candidates[i], candidates[i], d_parent);
}
return true;
}
@@ -68,10 +67,6 @@ void Cegis::getTermList(const std::vector<Node>& candidates,
bool Cegis::addEvalLemmas(const std::vector<Node>& candidates,
const std::vector<Node>& candidate_values)
{
- if (!options::sygusDirectEval())
- {
- return false;
- }
NodeManager* nm = NodeManager::currentNM();
bool addedEvalLemmas = false;
if (options::sygusCRefEval())
@@ -96,18 +91,21 @@ bool Cegis::addEvalLemmas(const std::vector<Node>& candidates,
add the lemmas below as well, in parallel. */
}
}
+ if (!options::sygusDirectEval())
+ {
+ return addedEvalLemmas;
+ }
Trace("cegqi-engine") << " *** Do direct evaluation..." << std::endl;
std::vector<Node> eager_terms, eager_vals, eager_exps;
- TermDbSygus* tds = d_qe->getTermDatabaseSygus();
for (unsigned i = 0, size = candidates.size(); i < size; ++i)
{
Trace("cegqi-debug") << " register " << candidates[i] << " -> "
<< candidate_values[i] << std::endl;
- tds->registerModelValue(candidates[i],
- candidate_values[i],
- eager_terms,
- eager_vals,
- eager_exps);
+ d_tds->registerModelValue(candidates[i],
+ candidate_values[i],
+ eager_terms,
+ eager_vals,
+ eager_exps);
}
Trace("cegqi-debug") << "...produced " << eager_terms.size()
<< " eager evaluation lemmas.\n";
@@ -148,11 +146,136 @@ bool Cegis::constructCandidates(const std::vector<Node>& enums,
return true;
}
+void Cegis::addRefinementLemma(Node lem)
+{
+ d_refinement_lemmas.push_back(lem);
+ // apply existing substitution
+ Node slem = lem;
+ if (!d_rl_eval_hds.empty())
+ {
+ slem = lem.substitute(d_rl_eval_hds.begin(),
+ d_rl_eval_hds.end(),
+ d_rl_vals.begin(),
+ d_rl_vals.end());
+ }
+ // rewrite with extended rewriter
+ slem = d_tds->getExtRewriter()->extendedRewrite(slem);
+ std::vector<Node> waiting;
+ waiting.push_back(lem);
+ unsigned wcounter = 0;
+ // while we are not done adding lemmas
+ while (wcounter < waiting.size())
+ {
+ // add the conjunct, possibly propagating
+ addRefinementLemmaConjunct(wcounter, waiting);
+ wcounter++;
+ }
+}
+
+void Cegis::addRefinementLemmaConjunct(unsigned wcounter,
+ std::vector<Node>& waiting)
+{
+ Node lem = waiting[wcounter];
+ lem = Rewriter::rewrite(lem);
+ // apply substitution and rewrite if applicable
+ if (lem.isConst())
+ {
+ if (!lem.getConst<bool>())
+ {
+ // conjecture is infeasible
+ }
+ else
+ {
+ return;
+ }
+ }
+ // break into conjunctions
+ if (lem.getKind() == AND)
+ {
+ for (const Node& lc : lem)
+ {
+ waiting.push_back(lc);
+ }
+ return;
+ }
+ // does this correspond to a substitution?
+ NodeManager* nm = NodeManager::currentNM();
+ TNode term;
+ TNode val;
+ if (lem.getKind() == EQUAL)
+ {
+ for (unsigned i = 0; i < 2; i++)
+ {
+ if (lem[i].isConst() && d_tds->isEvaluationPoint(lem[1 - i]))
+ {
+ term = lem[1 - i];
+ val = lem[i];
+ break;
+ }
+ }
+ }
+ else
+ {
+ term = lem.getKind() == NOT ? lem[0] : lem;
+ // predicate case: the conjunct is a (negated) evaluation point
+ if (d_tds->isEvaluationPoint(term))
+ {
+ val = nm->mkConst(lem.getKind() != NOT);
+ }
+ }
+ if (!val.isNull())
+ {
+ if (d_refinement_lemma_unit.find(lem) != d_refinement_lemma_unit.end())
+ {
+ // already added
+ return;
+ }
+ Trace("cegis-rl") << "* cegis-rl: propagate: " << term << " -> " << val
+ << std::endl;
+ d_rl_eval_hds.push_back(term);
+ d_rl_vals.push_back(val);
+ d_refinement_lemma_unit.insert(lem);
+ // apply to waiting lemmas beyond this one
+ for (unsigned i = wcounter + 1, size = waiting.size(); i < size; i++)
+ {
+ waiting[i] = waiting[i].substitute(term, val);
+ }
+ // apply to all existing refinement lemmas
+ std::vector<Node> to_rem;
+ for (const Node& rl : d_refinement_lemma_conj)
+ {
+ Node srl = rl.substitute(term, val);
+ if (srl != rl)
+ {
+ Trace("cegis-rl") << "* cegis-rl: replace: " << rl << " -> " << srl
+ << std::endl;
+ waiting.push_back(srl);
+ to_rem.push_back(rl);
+ }
+ }
+ for (const Node& tr : to_rem)
+ {
+ d_refinement_lemma_conj.erase(tr);
+ }
+ }
+ else
+ {
+ if (Trace.isOn("cegis-rl"))
+ {
+ if (d_refinement_lemma_conj.find(lem) == d_refinement_lemma_conj.end())
+ {
+ Trace("cegis-rl") << "cegis-rl: add: " << lem << std::endl;
+ }
+ }
+ d_refinement_lemma_conj.insert(lem);
+ }
+}
+
void Cegis::registerRefinementLemma(const std::vector<Node>& vars,
Node lem,
std::vector<Node>& lems)
{
- d_refinement_lemmas.push_back(lem);
+ addRefinementLemma(lem);
// Make the refinement lemma and add it to lems.
// This lemma is guarded by the parent's guard, which has the semantics
// "this conjecture has a solution", hence this lemma states:
@@ -168,118 +291,93 @@ void Cegis::getRefinementEvalLemmas(const std::vector<Node>& vs,
std::vector<Node>& lems)
{
Trace("sygus-cref-eval") << "Cref eval : conjecture has "
- << getNumRefinementLemmas() << " refinement lemmas."
+ << d_refinement_lemma_unit.size() << " unit and "
+ << d_refinement_lemma_conj.size()
+ << " non-unit refinement lemma conjunctions."
<< std::endl;
- unsigned nlemmas = getNumRefinementLemmas();
- if (nlemmas > 0 || options::cegisSample() != CEGIS_SAMPLE_NONE)
- {
- Assert(vs.size() == ms.size());
+ Assert(vs.size() == ms.size());
- TermDbSygus* tds = d_qe->getTermDatabaseSygus();
- NodeManager* nm = NodeManager::currentNM();
+ NodeManager* nm = NodeManager::currentNM();
- Node nfalse = nm->mkConst(false);
- Node neg_guard = d_parent->getGuard().negate();
- for (unsigned i = 0; i <= nlemmas; i++)
+ Node nfalse = nm->mkConst(false);
+ Node neg_guard = d_parent->getGuard().negate();
+ for (unsigned r = 0; r < 2; r++)
+ {
+ std::unordered_set<Node, NodeHashFunction>& rlemmas =
+ r == 0 ? d_refinement_lemma_unit : d_refinement_lemma_conj;
+ for (const Node& lem : rlemmas)
{
- if (i == nlemmas)
- {
- bool addedSample = false;
- // find a new one by sampling, if applicable
- if (options::cegisSample() != CEGIS_SAMPLE_NONE)
- {
- addedSample = sampleAddRefinementLemma(vs, ms, lems);
- }
- if (!addedSample)
- {
- return;
- }
- }
- Node lem;
+ Assert(!lem.isNull());
std::map<Node, Node> visited;
std::map<Node, std::vector<Node> > exp;
- lem = getRefinementLemma(i);
- if (!lem.isNull())
+ EvalSygusInvarianceTest vsit;
+ Trace("sygus-cref-eval") << "Check refinement lemma conjunct " << lem
+ << " against current model." << std::endl;
+ Trace("sygus-cref-eval2") << "Check refinement lemma conjunct " << lem
+ << " against current model." << std::endl;
+ Node cre_lem;
+ Node lemcs = lem.substitute(vs.begin(), vs.end(), ms.begin(), ms.end());
+ Trace("sygus-cref-eval2")
+ << "...under substitution it is : " << lemcs << std::endl;
+ Node lemcsu = vsit.doEvaluateWithUnfolding(d_tds, lemcs);
+ Trace("sygus-cref-eval2")
+ << "...after unfolding is : " << lemcsu << std::endl;
+ if (lemcsu.isConst() && !lemcsu.getConst<bool>())
{
- std::vector<Node> lem_conj;
- // break into conjunctions
- if (lem.getKind() == kind::AND)
+ std::vector<Node> msu;
+ std::vector<Node> mexp;
+ msu.insert(msu.end(), ms.begin(), ms.end());
+ std::map<TypeNode, int> var_count;
+ for (unsigned k = 0; k < vs.size(); k++)
{
- for (unsigned i = 0; i < lem.getNumChildren(); i++)
- {
- lem_conj.push_back(lem[i]);
- }
+ vsit.setUpdatedTerm(msu[k]);
+ msu[k] = vs[k];
+ // substitute for everything except this
+ Node sconj =
+ lem.substitute(vs.begin(), vs.end(), msu.begin(), msu.end());
+ vsit.init(sconj, vs[k], nfalse);
+ // get minimal explanation for this
+ Node ut = vsit.getUpdatedTerm();
+ Trace("sygus-cref-eval2-debug")
+ << " compute min explain of : " << vs[k] << " = " << ut
+ << std::endl;
+ d_tds->getExplain()->getExplanationFor(
+ vs[k], ut, mexp, vsit, var_count, false);
+ Trace("sygus-cref-eval2-debug") << "exp now: " << mexp << std::endl;
+ msu[k] = vsit.getUpdatedTerm();
+ Trace("sygus-cref-eval2-debug")
+ << "updated term : " << msu[k] << std::endl;
}
- else
+ if (!mexp.empty())
{
- lem_conj.push_back(lem);
+ Node en = mexp.size() == 1 ? mexp[0] : nm->mkNode(kind::AND, mexp);
+ cre_lem = nm->mkNode(kind::OR, en.negate(), neg_guard);
}
- EvalSygusInvarianceTest vsit;
- for (unsigned j = 0; j < lem_conj.size(); j++)
+ else
{
- Node lemc = lem_conj[j];
- Trace("sygus-cref-eval") << "Check refinement lemma conjunct " << lemc
- << " against current model." << std::endl;
- Trace("sygus-cref-eval2") << "Check refinement lemma conjunct "
- << lemc << " against current model."
- << std::endl;
- Node cre_lem;
- Node lemcs =
- lemc.substitute(vs.begin(), vs.end(), ms.begin(), ms.end());
- Trace("sygus-cref-eval2") << "...under substitution it is : " << lemcs
- << std::endl;
- Node lemcsu = vsit.doEvaluateWithUnfolding(tds, lemcs);
- Trace("sygus-cref-eval2") << "...after unfolding is : " << lemcsu
- << std::endl;
- if (lemcsu.isConst() && !lemcsu.getConst<bool>())
- {
- std::vector<Node> msu;
- std::vector<Node> mexp;
- msu.insert(msu.end(), ms.begin(), ms.end());
- std::map<TypeNode, int> var_count;
- for (unsigned k = 0; k < vs.size(); k++)
- {
- vsit.setUpdatedTerm(msu[k]);
- msu[k] = vs[k];
- // substitute for everything except this
- Node sconj =
- lemc.substitute(vs.begin(), vs.end(), msu.begin(), msu.end());
- vsit.init(sconj, vs[k], nfalse);
- // get minimal explanation for this
- Node ut = vsit.getUpdatedTerm();
- Trace("sygus-cref-eval2-debug")
- << " compute min explain of : " << vs[k] << " = " << ut
- << std::endl;
- tds->getExplain()->getExplanationFor(
- vs[k], ut, mexp, vsit, var_count, false);
- Trace("sygus-cref-eval2-debug")
- << "exp now: " << mexp << std::endl;
- msu[k] = vsit.getUpdatedTerm();
- Trace("sygus-cref-eval2-debug")
- << "updated term : " << msu[k] << std::endl;
- }
- if (!mexp.empty())
- {
- Node en =
- mexp.size() == 1 ? mexp[0] : nm->mkNode(kind::AND, mexp);
- cre_lem = nm->mkNode(kind::OR, en.negate(), neg_guard);
- }
- else
- {
- cre_lem = neg_guard;
- }
- }
- if (!cre_lem.isNull())
- {
- if (std::find(lems.begin(), lems.end(), cre_lem) == lems.end())
- {
- Trace("sygus-cref-eval") << "...produced lemma : " << cre_lem
- << std::endl;
- lems.push_back(cre_lem);
- }
- }
+ cre_lem = neg_guard;
}
}
+ if (!cre_lem.isNull()
+ && std::find(lems.begin(), lems.end(), cre_lem) == lems.end())
+ {
+ Trace("sygus-cref-eval")
+ << "...produced lemma : " << cre_lem << std::endl;
+ lems.push_back(cre_lem);
+ }
+ }
+ if (!lems.empty())
+ {
+ break;
+ }
+ }
+ // if we didn't add a lemma, trying sampling to add one
+ if (options::cegisSample() != CEGIS_SAMPLE_NONE && lems.empty())
+ {
+ if (sampleAddRefinementLemma(vs, ms, lems))
+ {
+ // restart (should be guaranteed to add evaluation lemmas
+ getRefinementEvalLemmas(vs, ms, lems);
}
}
}
@@ -344,7 +442,7 @@ bool Cegis::sampleAddRefinementLemma(const std::vector<Node>& candidates,
Trace("cegis-sample") << std::endl;
}
Trace("cegqi-engine") << " *** Refine by sampling" << std::endl;
- d_refinement_lemmas.push_back(rlem);
+ addRefinementLemma(rlem);
// if trust, we are not interested in sending out refinement lemmas
if (options::cegisSample() != CEGIS_SAMPLE_TRUST)
{
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback