From d6b3329e3f2b6e29e5f4af6cf09fd32e26c47e15 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 8 Sep 2021 12:30:31 -0500 Subject: Towards standard usage of ExtendedRewriter (#7145) This PR: Adds extendedRewrite to EnvObj and Rewriter. Eliminates static calls to Rewriter::rewrite from within the extended rewriter. Instead, the use of extended rewriter is always through Rewriter, which passes itself to the ExtendedRewriter. Make most uses of extended rewriter non-static. I've added a placeholder method Rewriter::callExtendedRewrite for places in the code that call the extended rewriter are currently difficult to eliminate. --- .../quantifiers/candidate_rewrite_database.cpp | 16 ++++++++-------- src/theory/quantifiers/candidate_rewrite_database.h | 8 ++++---- src/theory/quantifiers/expr_miner_manager.cpp | 2 +- src/theory/quantifiers/expr_miner_manager.h | 6 ------ src/theory/quantifiers/extended_rewrite.cpp | 21 +++++++++++---------- src/theory/quantifiers/extended_rewrite.h | 7 ++++++- src/theory/quantifiers/quantifiers_modules.cpp | 5 +++-- src/theory/quantifiers/quantifiers_modules.h | 3 ++- src/theory/quantifiers/quantifiers_rewriter.cpp | 3 +-- .../quantifiers/sygus/ce_guided_single_inv.cpp | 4 ++-- src/theory/quantifiers/sygus/cegis.cpp | 2 +- .../quantifiers/sygus/enum_stream_substitution.cpp | 18 ++++++++---------- src/theory/quantifiers/sygus/enum_value_manager.cpp | 6 ++++-- src/theory/quantifiers/sygus/enum_value_manager.h | 6 ++++-- .../quantifiers/sygus/sygus_enumerator_basic.cpp | 3 ++- .../quantifiers/sygus/sygus_enumerator_callback.cpp | 3 ++- .../quantifiers/sygus/sygus_enumerator_callback.h | 2 -- src/theory/quantifiers/sygus/sygus_grammar_red.cpp | 3 ++- src/theory/quantifiers/sygus/sygus_invariance.cpp | 6 +++--- src/theory/quantifiers/sygus/sygus_pbe.cpp | 2 +- src/theory/quantifiers/sygus/sygus_unif_io.cpp | 2 +- src/theory/quantifiers/sygus/synth_conjecture.cpp | 19 ++++++++++--------- src/theory/quantifiers/sygus/synth_conjecture.h | 6 ++++-- src/theory/quantifiers/sygus/synth_engine.cpp | 9 +++++---- src/theory/quantifiers/sygus/synth_engine.h | 3 ++- .../quantifiers/sygus/term_database_sygus.cpp | 8 ++++---- src/theory/quantifiers/sygus/term_database_sygus.h | 11 ++++------- src/theory/quantifiers/term_registry.cpp | 6 ++++-- src/theory/quantifiers/term_registry.h | 3 +-- src/theory/quantifiers/theory_quantifiers.cpp | 4 ++-- 30 files changed, 102 insertions(+), 95 deletions(-) (limited to 'src/theory/quantifiers') diff --git a/src/theory/quantifiers/candidate_rewrite_database.cpp b/src/theory/quantifiers/candidate_rewrite_database.cpp index 0fd0eebd6..475df0b43 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.cpp +++ b/src/theory/quantifiers/candidate_rewrite_database.cpp @@ -37,7 +37,7 @@ CandidateRewriteDatabase::CandidateRewriteDatabase( Env& env, bool doCheck, bool rewAccel, bool silent, bool filterPairs) : ExprMiner(env), d_tds(nullptr), - d_ext_rewrite(nullptr), + d_useExtRewriter(false), d_doCheck(doCheck), d_rewAccel(rewAccel), d_silent(silent), @@ -52,7 +52,7 @@ void CandidateRewriteDatabase::initialize(const std::vector& vars, d_candidate = Node::null(); d_using_sygus = false; d_tds = nullptr; - d_ext_rewrite = nullptr; + d_useExtRewriter = false; if (d_filterPairs) { d_crewrite_filter.initialize(ss, nullptr, false); @@ -69,7 +69,7 @@ void CandidateRewriteDatabase::initializeSygus(const std::vector& vars, d_candidate = f; d_using_sygus = true; d_tds = tds; - d_ext_rewrite = nullptr; + d_useExtRewriter = false; if (d_filterPairs) { d_crewrite_filter.initialize(ss, d_tds, d_using_sygus); @@ -121,10 +121,10 @@ Node CandidateRewriteDatabase::addTerm(Node sol, // get the rewritten form Node solbr; Node eq_solr; - if (d_ext_rewrite != nullptr) + if (d_useExtRewriter) { - solbr = d_ext_rewrite->extendedRewrite(solb); - eq_solr = d_ext_rewrite->extendedRewrite(eq_solb); + solbr = extendedRewrite(solb); + eq_solr = extendedRewrite(eq_solb); } else { @@ -289,9 +289,9 @@ bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out) void CandidateRewriteDatabase::setSilent(bool flag) { d_silent = flag; } -void CandidateRewriteDatabase::setExtendedRewriter(ExtendedRewriter* er) +void CandidateRewriteDatabase::enableExtendedRewriter() { - d_ext_rewrite = er; + d_useExtRewriter = true; } } // namespace quantifiers diff --git a/src/theory/quantifiers/candidate_rewrite_database.h b/src/theory/quantifiers/candidate_rewrite_database.h index 71ae5649f..c0e783fc1 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.h +++ b/src/theory/quantifiers/candidate_rewrite_database.h @@ -100,14 +100,14 @@ class CandidateRewriteDatabase : public ExprMiner bool addTerm(Node sol, std::ostream& out) override; /** sets whether this class should output candidate rewrites it finds */ void setSilent(bool flag); - /** set the (extended) rewriter used by this class */ - void setExtendedRewriter(ExtendedRewriter* er); + /** Enable the (extended) rewriter for this class */ + void enableExtendedRewriter(); private: /** (required) pointer to the sygus term database of d_qe */ TermDbSygus* d_tds; - /** an extended rewriter object */ - ExtendedRewriter* d_ext_rewrite; + /** Whether we use the extended rewriter */ + bool d_useExtRewriter; /** the function-to-synthesize we are testing (if sygus) */ Node d_candidate; /** whether we are checking equivalence using subsolver */ diff --git a/src/theory/quantifiers/expr_miner_manager.cpp b/src/theory/quantifiers/expr_miner_manager.cpp index ae20d4909..8af456ea8 100644 --- a/src/theory/quantifiers/expr_miner_manager.cpp +++ b/src/theory/quantifiers/expr_miner_manager.cpp @@ -87,7 +87,7 @@ void ExpressionMinerManager::enableRewriteRuleSynth() { d_crd.initialize(vars, &d_sampler); } - d_crd.setExtendedRewriter(&d_ext_rew); + d_crd.enableExtendedRewriter(); d_crd.setSilent(false); } diff --git a/src/theory/quantifiers/expr_miner_manager.h b/src/theory/quantifiers/expr_miner_manager.h index 92450b3ba..43a615c97 100644 --- a/src/theory/quantifiers/expr_miner_manager.h +++ b/src/theory/quantifiers/expr_miner_manager.h @@ -21,15 +21,11 @@ #include "expr/node.h" #include "smt/env_obj.h" #include "theory/quantifiers/candidate_rewrite_database.h" -#include "theory/quantifiers/extended_rewrite.h" #include "theory/quantifiers/query_generator.h" #include "theory/quantifiers/solution_filter.h" #include "theory/quantifiers/sygus_sampler.h" namespace cvc5 { - -class Env; - namespace theory { namespace quantifiers { @@ -114,8 +110,6 @@ class ExpressionMinerManager : protected EnvObj SolutionFilterStrength d_sols; /** sygus sampler object */ SygusSampler d_sampler; - /** extended rewriter object */ - ExtendedRewriter d_ext_rew; }; } // namespace quantifiers diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index 58a78b4aa..40e28eb78 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -42,7 +42,8 @@ struct ExtRewriteAggAttributeId }; typedef expr::Attribute ExtRewriteAggAttribute; -ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr) +ExtendedRewriter::ExtendedRewriter(Rewriter& rew, bool aggr) + : d_rew(rew), d_aggr(aggr) { d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); @@ -97,7 +98,7 @@ bool ExtendedRewriter::addToChildren(Node nc, Node ExtendedRewriter::extendedRewrite(Node n) const { - n = Rewriter::rewrite(n); + n = d_rew.rewrite(n); // has it already been computed? Node ncache = getCache(n); @@ -204,7 +205,7 @@ Node ExtendedRewriter::extendedRewrite(Node n) const } } } - ret = Rewriter::rewrite(ret); + ret = d_rew.rewrite(ret); //--------------------end rewrite children // now, do extended rewrite @@ -496,7 +497,7 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); if (nn != t2) { - nn = Rewriter::rewrite(nn); + nn = d_rew.rewrite(nn); if (nn == t1) { new_ret = t2; @@ -508,7 +509,7 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const // must use partial substitute here, to avoid substitution into witness std::map rkinds; nn = partialSubstitute(t1, vars, subs, rkinds); - nn = Rewriter::rewrite(nn); + nn = d_rew.rewrite(nn); if (nn != t1) { // If full=false, then we've duplicated a term u in the children of n. @@ -537,7 +538,7 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const Node nn = partialSubstitute(t2, assign, rkinds); if (nn != t2) { - nn = Rewriter::rewrite(nn); + nn = d_rew.rewrite(nn); if (nn == t1) { new_ret = nn; @@ -625,7 +626,7 @@ Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) const { children[ii] = n[i][j + 1]; Node pull = nm->mkNode(n.getKind(), children); - Node pullr = Rewriter::rewrite(pull); + Node pullr = d_rew.rewrite(pull); children[ii] = n[i]; ite_c[i][j] = pullr; } @@ -688,7 +689,7 @@ Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) const Assert(nite.getKind() == itek); // now, simply pull the ITE and try ITE rewrites Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]); - pull_ite = Rewriter::rewrite(pull_ite); + pull_ite = d_rew.rewrite(pull_ite); if (pull_ite.getKind() == ITE) { Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false); @@ -887,7 +888,7 @@ Node ExtendedRewriter::extendedRewriteBcp(Kind andk, ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs); Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs << std::endl; - ccs = Rewriter::rewrite(ccs); + ccs = d_rew.rewrite(ccs); Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl; to_process.push_back(ccs); // store this as a node that propagation touched. This marks c so that @@ -1522,7 +1523,7 @@ Node ExtendedRewriter::extendedRewriteEqChain( index--; new_ret = nm->mkNode(eqk, children[index], new_ret); } - new_ret = Rewriter::rewrite(new_ret); + new_ret = d_rew.rewrite(new_ret); if (new_ret != ret) { return new_ret; diff --git a/src/theory/quantifiers/extended_rewrite.h b/src/theory/quantifiers/extended_rewrite.h index b1b08657d..b4dcab041 100644 --- a/src/theory/quantifiers/extended_rewrite.h +++ b/src/theory/quantifiers/extended_rewrite.h @@ -24,6 +24,9 @@ namespace cvc5 { namespace theory { + +class Rewriter; + namespace quantifiers { /** Extended rewriter @@ -48,12 +51,14 @@ namespace quantifiers { class ExtendedRewriter { public: - ExtendedRewriter(bool aggr = true); + ExtendedRewriter(Rewriter& rew, bool aggr = true); ~ExtendedRewriter() {} /** return the extended rewritten form of n */ Node extendedRewrite(Node n) const; private: + /** The underlying rewriter that we are extending */ + Rewriter& d_rew; /** cache that the extended rewritten form of n is ret */ void setCache(Node n, Node ret) const; /** get the cache for n */ diff --git a/src/theory/quantifiers/quantifiers_modules.cpp b/src/theory/quantifiers/quantifiers_modules.cpp index 27ec187a9..6cfc48fb9 100644 --- a/src/theory/quantifiers/quantifiers_modules.cpp +++ b/src/theory/quantifiers/quantifiers_modules.cpp @@ -41,7 +41,8 @@ QuantifiersModules::QuantifiersModules() { } QuantifiersModules::~QuantifiersModules() {} -void QuantifiersModules::initialize(QuantifiersState& qs, +void QuantifiersModules::initialize(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr, @@ -72,7 +73,7 @@ void QuantifiersModules::initialize(QuantifiersState& qs, } if (options::sygus()) { - d_synth_e.reset(new SynthEngine(qs, qim, qr, tr)); + d_synth_e.reset(new SynthEngine(env, qs, qim, qr, tr)); modules.push_back(d_synth_e.get()); } // bounded integer instantiation is used when the user requests it via diff --git a/src/theory/quantifiers/quantifiers_modules.h b/src/theory/quantifiers/quantifiers_modules.h index f41e81f34..9878e79ae 100644 --- a/src/theory/quantifiers/quantifiers_modules.h +++ b/src/theory/quantifiers/quantifiers_modules.h @@ -57,7 +57,8 @@ class QuantifiersModules * This constructs the above modules based on the current options. It adds * a pointer to each module it constructs to modules. */ - void initialize(QuantifiersState& qs, + void initialize(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr, diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp index 6d8570287..e5662cdc6 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.cpp +++ b/src/theory/quantifiers/quantifiers_rewriter.cpp @@ -548,8 +548,7 @@ Node QuantifiersRewriter::computeExtendedRewrite(Node q) { Node body = q[1]; // apply extended rewriter - ExtendedRewriter er; - Node bodyr = er.extendedRewrite(body); + Node bodyr = Rewriter::callExtendedRewrite(body); if (body != bodyr) { std::vector children; diff --git a/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp b/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp index d2c616238..80f4af984 100644 --- a/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp +++ b/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp @@ -400,7 +400,7 @@ Node CegSingleInv::getSolutionFromInst(size_t index) } //simplify the solution using the extended rewriter Trace("csi-sol") << "Solution (pre-simplification): " << s << std::endl; - s = d_treg.getTermDatabaseSygus()->getExtRewriter()->extendedRewrite(s); + s = extendedRewrite(s); Trace("csi-sol") << "Solution (post-simplification): " << s << std::endl; // wrap into lambda, as needed return SygusUtils::wrapSolutionForSynthFun(prog, s); @@ -467,7 +467,7 @@ Node CegSingleInv::reconstructToSyntax(Node s, { Trace("csi-sol") << "Post-process solution..." << std::endl; Node prev = sol; - sol = d_treg.getTermDatabaseSygus()->getExtRewriter()->extendedRewrite(sol); + sol = extendedRewrite(sol); if (prev != sol) { Trace("csi-sol") << "Solution (after post process) : " << sol diff --git a/src/theory/quantifiers/sygus/cegis.cpp b/src/theory/quantifiers/sygus/cegis.cpp index 57b763044..8d1bfd9b6 100644 --- a/src/theory/quantifiers/sygus/cegis.cpp +++ b/src/theory/quantifiers/sygus/cegis.cpp @@ -345,7 +345,7 @@ void Cegis::addRefinementLemma(Node lem) d_rl_vals.end()); } // rewrite with extended rewriter - slem = d_tds->getExtRewriter()->extendedRewrite(slem); + slem = extendedRewrite(slem); // collect all variables in slem expr::getSymbols(slem, d_refinement_lemma_vars); std::vector waiting; diff --git a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp index f853ac8e8..a5be4ebd6 100644 --- a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp +++ b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp @@ -16,15 +16,16 @@ #include "theory/quantifiers/sygus/enum_stream_substitution.h" +#include // for std::iota +#include + #include "expr/dtype_cons.h" #include "options/base_options.h" #include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "printer/printer.h" #include "theory/quantifiers/sygus/term_database_sygus.h" - -#include // for std::iota -#include +#include "theory/rewriter.h" using namespace cvc5::kind; @@ -32,7 +33,7 @@ namespace cvc5 { namespace theory { namespace quantifiers { -EnumStreamPermutation::EnumStreamPermutation(quantifiers::TermDbSygus* tds) +EnumStreamPermutation::EnumStreamPermutation(TermDbSygus* tds) : d_tds(tds), d_first(true), d_curr_ind(0) { } @@ -124,8 +125,7 @@ Node EnumStreamPermutation::getNext() { d_first = false; Node bultin_value = d_tds->sygusToBuiltin(d_value, d_value.getType()); - d_perm_values.insert( - d_tds->getExtRewriter()->extendedRewrite(bultin_value)); + d_perm_values.insert(Rewriter::callExtendedRewrite(bultin_value)); return d_value; } unsigned n_classes = d_perm_state_class.size(); @@ -194,8 +194,7 @@ Node EnumStreamPermutation::getNext() << " ......perm builtin is " << bultin_perm_value; if (options::sygusSymBreakDynamic()) { - bultin_perm_value = - d_tds->getExtRewriter()->extendedRewrite(bultin_perm_value); + bultin_perm_value = Rewriter::callExtendedRewrite(bultin_perm_value); Trace("synth-stream-concrete-debug") << " and rewrites to " << bultin_perm_value; } @@ -515,8 +514,7 @@ Node EnumStreamSubstitution::getNext() d_tds->sygusToBuiltin(comb_value, comb_value.getType()); if (options::sygusSymBreakDynamic()) { - builtin_comb_value = - d_tds->getExtRewriter()->extendedRewrite(builtin_comb_value); + builtin_comb_value = Rewriter::callExtendedRewrite(builtin_comb_value); } if (Trace.isOn("synth-stream-concrete")) { diff --git a/src/theory/quantifiers/sygus/enum_value_manager.cpp b/src/theory/quantifiers/sygus/enum_value_manager.cpp index 8a2d70bfa..1d0ba5bee 100644 --- a/src/theory/quantifiers/sygus/enum_value_manager.cpp +++ b/src/theory/quantifiers/sygus/enum_value_manager.cpp @@ -33,13 +33,15 @@ namespace cvc5 { namespace theory { namespace quantifiers { -EnumValueManager::EnumValueManager(Node e, +EnumValueManager::EnumValueManager(Env& env, QuantifiersState& qs, QuantifiersInferenceManager& qim, TermRegistry& tr, SygusStatistics& s, + Node e, bool hasExamples) - : d_enum(e), + : EnvObj(env), + d_enum(e), d_qstate(qs), d_qim(qim), d_treg(tr), diff --git a/src/theory/quantifiers/sygus/enum_value_manager.h b/src/theory/quantifiers/sygus/enum_value_manager.h index c786bb6f1..23fdc7391 100644 --- a/src/theory/quantifiers/sygus/enum_value_manager.h +++ b/src/theory/quantifiers/sygus/enum_value_manager.h @@ -19,6 +19,7 @@ #define CVC5__THEORY__QUANTIFIERS__SYGUS__ENUM_VALUE_MANAGER_H #include "expr/node.h" +#include "smt/env_obj.h" #include "theory/quantifiers/sygus/enum_val_generator.h" #include "theory/quantifiers/sygus/example_eval_cache.h" #include "theory/quantifiers/sygus/sygus_enumerator_callback.h" @@ -38,14 +39,15 @@ class SygusStatistics; * not actively generated, or may be determined by the (fast) enumerator * when it is actively generated. */ -class EnumValueManager +class EnumValueManager : protected EnvObj { public: - EnumValueManager(Node e, + EnumValueManager(Env& env, QuantifiersState& qs, QuantifiersInferenceManager& qim, TermRegistry& tr, SygusStatistics& s, + Node e, bool hasExamples); ~EnumValueManager(); /** diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp b/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp index f45b976ec..743f67cec 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp @@ -15,6 +15,7 @@ #include "theory/quantifiers/sygus/sygus_enumerator_basic.h" #include "options/datatypes_options.h" +#include "theory/rewriter.h" using namespace cvc5::kind; using namespace std; @@ -40,7 +41,7 @@ bool EnumValGeneratorBasic::increment() if (options::sygusSymBreakDynamic()) { Node nextb = d_tds->sygusToBuiltin(d_currTerm); - nextb = d_tds->getExtRewriter()->extendedRewrite(nextb); + nextb = Rewriter::callExtendedRewrite(nextb); if (d_cache.find(nextb) == d_cache.end()) { d_cache.insert(nextb); diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp b/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp index 3b536695f..1b5b3f5af 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp @@ -19,6 +19,7 @@ #include "theory/quantifiers/sygus/example_eval_cache.h" #include "theory/quantifiers/sygus/sygus_stats.h" #include "theory/quantifiers/sygus_sampler.h" +#include "theory/rewriter.h" namespace cvc5 { namespace theory { @@ -33,7 +34,7 @@ SygusEnumeratorCallback::SygusEnumeratorCallback(Node e, SygusStatistics* s) bool SygusEnumeratorCallback::addTerm(Node n, std::unordered_set& bterms) { Node bn = datatypes::utils::sygusToBuiltin(n); - Node bnr = d_extr.extendedRewrite(bn); + Node bnr = Rewriter::callExtendedRewrite(bn); if (d_stats != nullptr) { ++(d_stats->d_enumTermsRewrite); diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_callback.h b/src/theory/quantifiers/sygus/sygus_enumerator_callback.h index 5ed28b309..8689d876f 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_callback.h +++ b/src/theory/quantifiers/sygus/sygus_enumerator_callback.h @@ -74,8 +74,6 @@ class SygusEnumeratorCallback Node d_enum; /** The type of enum */ TypeNode d_tn; - /** extended rewriter */ - ExtendedRewriter d_extr; /** pointer to the statistics */ SygusStatistics* d_stats; }; diff --git a/src/theory/quantifiers/sygus/sygus_grammar_red.cpp b/src/theory/quantifiers/sygus/sygus_grammar_red.cpp index a51fcce25..fd84f0c0a 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_red.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_red.cpp @@ -21,6 +21,7 @@ #include "options/quantifiers_options.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" +#include "theory/rewriter.h" using namespace std; using namespace cvc5::kind; @@ -147,7 +148,7 @@ void SygusRedundantCons::getGenericList(TermDbSygus* tds, if (index == dt[c].getNumArgs()) { Node gt = tds->mkGeneric(dt, c, pre); - gt = tds->getExtRewriter()->extendedRewrite(gt); + gt = Rewriter::callExtendedRewrite(gt); terms.push_back(gt); return; } diff --git a/src/theory/quantifiers/sygus/sygus_invariance.cpp b/src/theory/quantifiers/sygus/sygus_invariance.cpp index cb7e2b84e..29557fe5c 100644 --- a/src/theory/quantifiers/sygus/sygus_invariance.cpp +++ b/src/theory/quantifiers/sygus/sygus_invariance.cpp @@ -111,7 +111,7 @@ bool EquivSygusInvarianceTest::invariant(TermDbSygus* tds, Node nvn, Node x) { TypeNode tn = nvn.getType(); Node nbv = tds->sygusToBuiltin(nvn, tn); - Node nbvr = tds->getExtRewriter()->extendedRewrite(nbv); + Node nbvr = Rewriter::callExtendedRewrite(nbv); Trace("sygus-sb-mexp-debug") << " min-exp check : " << nbv << " -> " << nbvr << std::endl; bool exc_arg = false; @@ -181,7 +181,7 @@ bool DivByZeroSygusInvarianceTest::invariant(TermDbSygus* tds, Node nvn, Node x) { TypeNode tn = nvn.getType(); Node nbv = tds->sygusToBuiltin(nvn, tn); - Node nbvr = tds->getExtRewriter()->extendedRewrite(nbv); + Node nbvr = Rewriter::callExtendedRewrite(nbv); if (tds->involvesDivByZero(nbvr)) { Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin(nvn) @@ -212,7 +212,7 @@ bool NegContainsSygusInvarianceTest::invariant(TermDbSygus* tds, { TypeNode tn = nvn.getType(); Node nbv = tds->sygusToBuiltin(nvn, tn); - Node nbvr = tds->getExtRewriter()->extendedRewrite(nbv); + Node nbvr = Rewriter::callExtendedRewrite(nbv); // if for any of the examples, it is not contained, then we can exclude for (unsigned i = 0; i < d_neg_con_indices.size(); i++) { diff --git a/src/theory/quantifiers/sygus/sygus_pbe.cpp b/src/theory/quantifiers/sygus/sygus_pbe.cpp index 7601e2117..52bca1586 100644 --- a/src/theory/quantifiers/sygus/sygus_pbe.cpp +++ b/src/theory/quantifiers/sygus/sygus_pbe.cpp @@ -131,7 +131,7 @@ bool SygusPbe::initialize(Node conj, // Apply extended rewriting on the lemma. This helps utilities like // SygusEnumerator more easily recognize the shape of this lemma, e.g. // ( ~is-ite(x) or ( ~is-ite(x) ^ P ) ) --> ~is-ite(x). - lem = d_tds->getExtRewriter()->extendedRewrite(lem); + lem = extendedRewrite(lem); Trace("sygus-pbe") << " static redundant op lemma : " << lem << std::endl; // Register as a symmetry breaking lemma with the term database. diff --git a/src/theory/quantifiers/sygus/sygus_unif_io.cpp b/src/theory/quantifiers/sygus/sygus_unif_io.cpp index 9626f7af4..3fb80f917 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_io.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_io.cpp @@ -569,7 +569,7 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector& lemmas) std::vector base_results; TypeNode xtn = e.getType(); Node bv = d_tds->sygusToBuiltin(v, xtn); - bv = d_tds->getExtRewriter()->extendedRewrite(bv); + bv = extendedRewrite(bv); Trace("sygus-sui-enum") << "PBE Compute Examples for " << bv << std::endl; // compte the results (should be cached) ExampleEvalCache* eec = d_parent->getExampleEvalCache(e); diff --git a/src/theory/quantifiers/sygus/synth_conjecture.cpp b/src/theory/quantifiers/sygus/synth_conjecture.cpp index 3e7095c12..e87857c3b 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.cpp +++ b/src/theory/quantifiers/sygus/synth_conjecture.cpp @@ -45,12 +45,14 @@ namespace cvc5 { namespace theory { namespace quantifiers { -SynthConjecture::SynthConjecture(QuantifiersState& qs, +SynthConjecture::SynthConjecture(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr, SygusStatistics& s) - : d_qstate(qs), + : EnvObj(env), + d_qstate(qs), d_qim(qim), d_qreg(qr), d_treg(tr), @@ -58,11 +60,11 @@ SynthConjecture::SynthConjecture(QuantifiersState& qs, d_tds(tr.getTermDatabaseSygus()), d_verify(qs.options(), qs.getLogicInfo(), d_tds), d_hasSolution(false), - d_ceg_si(new CegSingleInv(qs.getEnv(), tr, s)), + d_ceg_si(new CegSingleInv(env, tr, s)), d_templInfer(new SygusTemplateInfer), d_ceg_proc(new SynthConjectureProcess), d_ceg_gc(new CegGrammarConstructor(d_tds, this)), - d_sygus_rconst(new SygusRepairConst(qs.getEnv(), d_tds)), + d_sygus_rconst(new SygusRepairConst(env, d_tds)), d_exampleInfer(new ExampleInfer(d_tds)), d_ceg_pbe(new SygusPbe(qs, qim, d_tds, this)), d_ceg_cegis(new Cegis(qs, qim, d_tds, this)), @@ -609,8 +611,7 @@ bool SynthConjecture::checkSideCondition(const std::vector& cvals) const } Trace("sygus-engine") << "Check side condition..." << std::endl; Trace("cegqi-debug") << "Check side condition : " << sc << std::endl; - Env& env = d_qstate.getEnv(); - Result r = checkWithSubsolver(sc, env.getOptions(), env.getLogicInfo()); + Result r = checkWithSubsolver(sc, options(), logicInfo()); Trace("cegqi-debug") << "...got side condition : " << r << std::endl; if (r == Result::UNSAT) { @@ -763,8 +764,8 @@ EnumValueManager* SynthConjecture::getEnumValueManagerFor(Node e) Node f = d_tds->getSynthFunForEnumerator(e); bool hasExamples = (d_exampleInfer->hasExamples(f) && d_exampleInfer->getNumExamples(f) != 0); - d_enumManager[e].reset( - new EnumValueManager(e, d_qstate, d_qim, d_treg, d_stats, hasExamples)); + d_enumManager[e].reset(new EnumValueManager( + d_env, d_qstate, d_qim, d_treg, d_stats, e, hasExamples)); EnumValueManager* eman = d_enumManager[e].get(); // set up the examples if (hasExamples) @@ -885,7 +886,7 @@ void SynthConjecture::printSynthSolutionInternal(std::ostream& out) d_exprm.find(prog); if (its == d_exprm.end()) { - d_exprm[prog].reset(new ExpressionMinerManager(d_qstate.getEnv())); + d_exprm[prog].reset(new ExpressionMinerManager(d_env)); ExpressionMinerManager* emm = d_exprm[prog].get(); emm->initializeSygus( d_tds, d_candidates[i], options::sygusSamples(), true); diff --git a/src/theory/quantifiers/sygus/synth_conjecture.h b/src/theory/quantifiers/sygus/synth_conjecture.h index 9cc488fd2..d7635c816 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.h +++ b/src/theory/quantifiers/sygus/synth_conjecture.h @@ -21,6 +21,7 @@ #include +#include "smt/env_obj.h" #include "theory/quantifiers/expr_miner_manager.h" #include "theory/quantifiers/sygus/ce_guided_single_inv.h" #include "theory/quantifiers/sygus/cegis.h" @@ -51,10 +52,11 @@ class EnumValueManager; * determines which approach and optimizations are applicable to the * conjecture, and has interfaces for implementing them. */ -class SynthConjecture +class SynthConjecture : protected EnvObj { public: - SynthConjecture(QuantifiersState& qs, + SynthConjecture(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr, diff --git a/src/theory/quantifiers/sygus/synth_engine.cpp b/src/theory/quantifiers/sygus/synth_engine.cpp index cdcbeb85d..64227793d 100644 --- a/src/theory/quantifiers/sygus/synth_engine.cpp +++ b/src/theory/quantifiers/sygus/synth_engine.cpp @@ -26,14 +26,15 @@ namespace cvc5 { namespace theory { namespace quantifiers { -SynthEngine::SynthEngine(QuantifiersState& qs, +SynthEngine::SynthEngine(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr) : QuantifiersModule(qs, qim, qr, tr), d_conj(nullptr), d_sqp(qs.getEnv()) { d_conjs.push_back(std::unique_ptr( - new SynthConjecture(qs, qim, qr, tr, d_statistics))); + new SynthConjecture(env, qs, qim, qr, tr, d_statistics))); d_conj = d_conjs.back().get(); } @@ -153,8 +154,8 @@ void SynthEngine::assignConjecture(Node q) // allocate a new synthesis conjecture if not assigned if (d_conjs.back()->isAssigned()) { - d_conjs.push_back(std::unique_ptr( - new SynthConjecture(d_qstate, d_qim, d_qreg, d_treg, d_statistics))); + d_conjs.push_back(std::unique_ptr(new SynthConjecture( + d_env, d_qstate, d_qim, d_qreg, d_treg, d_statistics))); } d_conjs.back()->assign(q); } diff --git a/src/theory/quantifiers/sygus/synth_engine.h b/src/theory/quantifiers/sygus/synth_engine.h index d37df4e28..c623d9c0f 100644 --- a/src/theory/quantifiers/sygus/synth_engine.h +++ b/src/theory/quantifiers/sygus/synth_engine.h @@ -34,7 +34,8 @@ class SynthEngine : public QuantifiersModule typedef context::CDHashMap NodeBoolMap; public: - SynthEngine(QuantifiersState& qs, + SynthEngine(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr); diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index 3b0ea3312..9c9a90255 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -51,10 +51,10 @@ std::ostream& operator<<(std::ostream& os, EnumeratorRole r) return os; } -TermDbSygus::TermDbSygus(QuantifiersState& qs) - : d_qstate(qs), +TermDbSygus::TermDbSygus(Env& env, QuantifiersState& qs) + : EnvObj(env), + d_qstate(qs), d_syexp(new SygusExplain(this)), - d_ext_rw(new ExtendedRewriter(true)), d_eval(new Evaluator), d_funDefEval(new FunDefEvaluator), d_eval_unfold(new SygusEvalUnfold(this)) @@ -1036,7 +1036,7 @@ Node TermDbSygus::evaluateWithUnfolding(Node n, } if (options::sygusExtRew()) { - ret = getExtRewriter()->extendedRewrite(ret); + ret = extendedRewrite(ret); } // use rewriting, possibly involving recursive functions ret = rewriteNode(ret); diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index 80411b258..a44ebd297 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -21,6 +21,7 @@ #include #include "expr/dtype.h" +#include "smt/env_obj.h" #include "theory/evaluator.h" #include "theory/quantifiers/extended_rewrite.h" #include "theory/quantifiers/fun_def_evaluator.h" @@ -53,9 +54,10 @@ enum EnumeratorRole std::ostream& operator<<(std::ostream& os, EnumeratorRole r); // TODO :issue #1235 split and document this class -class TermDbSygus { +class TermDbSygus : protected EnvObj +{ public: - TermDbSygus(QuantifiersState& qs); + TermDbSygus(Env& env, QuantifiersState& qs); ~TermDbSygus() {} /** Finish init, which sets the inference manager */ void finishInit(QuantifiersInferenceManager* qim); @@ -78,8 +80,6 @@ class TermDbSygus { //------------------------------utilities /** get the explanation utility */ SygusExplain* getExplain() { return d_syexp.get(); } - /** get the extended rewrite utility */ - ExtendedRewriter* getExtRewriter() { return d_ext_rw.get(); } /** get the evaluator */ Evaluator* getEvaluator() { return d_eval.get(); } /** (recursive) function evaluator utility */ @@ -324,8 +324,6 @@ class TermDbSygus { //------------------------------utilities /** sygus explanation */ std::unique_ptr d_syexp; - /** extended rewriter */ - std::unique_ptr d_ext_rw; /** evaluator */ std::unique_ptr d_eval; /** (recursive) function evaluator utility */ @@ -461,7 +459,6 @@ class TermDbSygus { /** get anchor */ static Node getAnchor( Node n ); static unsigned getAnchorDepth( Node n ); - }; } // namespace quantifiers diff --git a/src/theory/quantifiers/term_registry.cpp b/src/theory/quantifiers/term_registry.cpp index 324217798..36dc8865c 100644 --- a/src/theory/quantifiers/term_registry.cpp +++ b/src/theory/quantifiers/term_registry.cpp @@ -29,7 +29,9 @@ namespace cvc5 { namespace theory { namespace quantifiers { -TermRegistry::TermRegistry(QuantifiersState& qs, QuantifiersRegistry& qr) +TermRegistry::TermRegistry(Env& env, + QuantifiersState& qs, + QuantifiersRegistry& qr) : d_presolve(qs.getUserContext(), true), d_presolveCache(qs.getUserContext()), d_termEnum(new TermEnumeration), @@ -42,7 +44,7 @@ TermRegistry::TermRegistry(QuantifiersState& qs, QuantifiersRegistry& qr) if (options::sygus() || options::sygusInst()) { // must be constructed here since it is required for datatypes finistInit - d_sygusTdb.reset(new TermDbSygus(qs)); + d_sygusTdb.reset(new TermDbSygus(env, qs)); } Trace("quant-engine-debug") << "Initialize quantifiers engine." << std::endl; Trace("quant-engine-debug") diff --git a/src/theory/quantifiers/term_registry.h b/src/theory/quantifiers/term_registry.h index c3e4fcf4c..e0ce73286 100644 --- a/src/theory/quantifiers/term_registry.h +++ b/src/theory/quantifiers/term_registry.h @@ -42,8 +42,7 @@ class TermRegistry using NodeSet = context::CDHashSet; public: - TermRegistry(QuantifiersState& qs, - QuantifiersRegistry& qr); + TermRegistry(Env& env, QuantifiersState& qs, QuantifiersRegistry& qr); /** Finish init, which sets the inference manager on modules of this class */ void finishInit(FirstOrderModel* fm, QuantifiersInferenceManager* qim); /** Presolve */ diff --git a/src/theory/quantifiers/theory_quantifiers.cpp b/src/theory/quantifiers/theory_quantifiers.cpp index dff0ac979..137e25c89 100644 --- a/src/theory/quantifiers/theory_quantifiers.cpp +++ b/src/theory/quantifiers/theory_quantifiers.cpp @@ -36,13 +36,13 @@ TheoryQuantifiers::TheoryQuantifiers(Env& env, : Theory(THEORY_QUANTIFIERS, env, out, valuation), d_qstate(env, valuation, logicInfo()), d_qreg(), - d_treg(d_qstate, d_qreg), + d_treg(env, d_qstate, d_qreg), d_qim(env, *this, d_qstate, d_qreg, d_treg, d_pnm), d_qengine(nullptr) { // construct the quantifiers engine d_qengine.reset( - new QuantifiersEngine(d_qstate, d_qreg, d_treg, d_qim, d_pnm)); + new QuantifiersEngine(env, d_qstate, d_qreg, d_treg, d_qim, d_pnm)); // indicate we are using the quantifiers theory state object d_theoryState = &d_qstate; -- cgit v1.2.3