diff options
44 files changed, 277 insertions, 180 deletions
diff --git a/src/smt/env.cpp b/src/smt/env.cpp index f42a51dd0..c77b8cfba 100644 --- a/src/smt/env.cpp +++ b/src/smt/env.cpp @@ -24,6 +24,7 @@ #include "proof/conv_proof_generator.h" #include "smt/dump_manager.h" #include "smt/smt_engine_stats.h" +#include "theory/evaluator.h" #include "theory/rewriter.h" #include "theory/trust_substitutions.h" #include "util/resource_manager.h" @@ -39,6 +40,8 @@ Env::Env(NodeManager* nm, const Options* opts) d_nodeManager(nm), d_proofNodeManager(nullptr), d_rewriter(new theory::Rewriter()), + d_evalRew(new theory::Evaluator(d_rewriter.get())), + d_eval(new theory::Evaluator(nullptr)), d_topLevelSubs(new theory::TrustSubstitutionMap(d_userContext.get())), d_dumpManager(new DumpManager(d_userContext.get())), d_logic(), @@ -132,4 +135,55 @@ const Printer& Env::getPrinter() std::ostream& Env::getDumpOut() { return *d_options.base.out; } +Node Env::evaluate(TNode n, + const std::vector<Node>& args, + const std::vector<Node>& vals, + bool useRewriter) const +{ + std::unordered_map<Node, Node> visited; + return evaluate(n, args, vals, visited, useRewriter); +} + +Node Env::evaluate(TNode n, + const std::vector<Node>& args, + const std::vector<Node>& vals, + const std::unordered_map<Node, Node>& visited, + bool useRewriter) const +{ + if (useRewriter) + { + return d_evalRew->eval(n, args, vals, visited); + } + return d_eval->eval(n, args, vals, visited); +} + +Node Env::rewriteViaMethod(TNode n, MethodId idr) +{ + if (idr == MethodId::RW_REWRITE) + { + return d_rewriter->rewrite(n); + } + if (idr == MethodId::RW_EXT_REWRITE) + { + return d_rewriter->extendedRewrite(n); + } + if (idr == MethodId::RW_REWRITE_EQ_EXT) + { + return d_rewriter->rewriteEqualityExt(n); + } + if (idr == MethodId::RW_EVALUATE) + { + return evaluate(n, {}, {}, false); + } + if (idr == MethodId::RW_IDENTITY) + { + // does nothing + return n; + } + // unknown rewriter + Unhandled() << "Env::rewriteViaMethod: no rewriter for " << idr + << std::endl; + return n; +} + } // namespace cvc5 diff --git a/src/smt/env.h b/src/smt/env.h index d95e70226..2f2fe19ce 100644 --- a/src/smt/env.h +++ b/src/smt/env.h @@ -22,6 +22,7 @@ #include <memory> #include "options/options.h" +#include "proof/method_id.h" #include "theory/logic_info.h" #include "util/statistics_registry.h" @@ -44,6 +45,7 @@ class PfManager; } namespace theory { +class Evaluator; class Rewriter; class TrustSubstitutionMap; } @@ -137,6 +139,39 @@ class Env */ std::ostream& getDumpOut(); + /* Rewrite helpers--------------------------------------------------------- */ + /** + * Evaluate node n under the substitution args -> vals. For details, see + * theory/evaluator.h. + * + * @param n The node to evaluate + * @param args The domain of the substitution + * @param vals The range of the substitution + * @param useRewriter if true, we use this rewriter to rewrite subterms of + * n that cannot be evaluated to a constant. + * @return the rewritten, evaluated form of n under the given substitution. + */ + Node evaluate(TNode n, + const std::vector<Node>& args, + const std::vector<Node>& vals, + bool useRewriter) const; + /** Same as above, with a visited cache. */ + Node evaluate(TNode n, + const std::vector<Node>& args, + const std::vector<Node>& vals, + const std::unordered_map<Node, Node>& visited, + bool useRewriter = true) const; + /** + * Apply rewrite on n via the rewrite method identifier idr (see method_id.h). + * This encapsulates the exact behavior of a REWRITE step in a proof. + * + * @param n The node to rewrite, + * @param idr The method identifier of the rewriter, by default RW_REWRITE + * specifying a call to rewrite. + * @return The rewritten form of n. + */ + Node rewriteViaMethod(TNode n, MethodId idr = MethodId::RW_REWRITE); + private: /* Private initialization ------------------------------------------------- */ @@ -173,6 +208,10 @@ class Env * specific to an SmtEngine/TheoryEngine instance. */ std::unique_ptr<theory::Rewriter> d_rewriter; + /** Evaluator that also invokes the rewriter */ + std::unique_ptr<theory::Evaluator> d_evalRew; + /** Evaluator that does not invoke the rewriter */ + std::unique_ptr<theory::Evaluator> d_eval; /** The top level substitutions */ std::unique_ptr<theory::TrustSubstitutionMap> d_topLevelSubs; /** The dump manager */ diff --git a/src/smt/env_obj.cpp b/src/smt/env_obj.cpp index fcbcc92d2..b9aebbe83 100644 --- a/src/smt/env_obj.cpp +++ b/src/smt/env_obj.cpp @@ -33,6 +33,21 @@ Node EnvObj::extendedRewrite(TNode node, bool aggr) const { return d_env.getRewriter()->extendedRewrite(node, aggr); } +Node EnvObj::evaluate(TNode n, + const std::vector<Node>& args, + const std::vector<Node>& vals, + bool useRewriter) const +{ + return d_env.evaluate(n, args, vals, useRewriter); +} +Node EnvObj::evaluate(TNode n, + const std::vector<Node>& args, + const std::vector<Node>& vals, + const std::unordered_map<Node, Node>& visited, + bool useRewriter) const +{ + return d_env.evaluate(n, args, vals, visited, useRewriter); +} const LogicInfo& EnvObj::logicInfo() const { return d_env.getLogicInfo(); } diff --git a/src/smt/env_obj.h b/src/smt/env_obj.h index ef9a82b17..75b97fda9 100644 --- a/src/smt/env_obj.h +++ b/src/smt/env_obj.h @@ -55,6 +55,20 @@ class EnvObj * This is a wrapper around theory::Rewriter::extendedRewrite via Env. */ Node extendedRewrite(TNode node, bool aggr = true) const; + /** + * Evaluate node n under the substitution args -> vals. + * This is a wrapper about theory::Rewriter::evaluate via Env. + */ + Node evaluate(TNode n, + const std::vector<Node>& args, + const std::vector<Node>& vals, + bool useRewriter = true) const; + /** Same as above, with a visited cache. */ + Node evaluate(TNode n, + const std::vector<Node>& args, + const std::vector<Node>& vals, + const std::unordered_map<Node, Node>& visited, + bool useRewriter = true) const; /** Get the current logic information. */ const LogicInfo& logicInfo() const; diff --git a/src/smt/proof_post_processor.cpp b/src/smt/proof_post_processor.cpp index 3866c1e0e..f5db349e1 100644 --- a/src/smt/proof_post_processor.cpp +++ b/src/smt/proof_post_processor.cpp @@ -478,8 +478,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, rargs.push_back(args[3]); } } - Rewriter* rr = d_env.getRewriter(); - Node tr = rr->rewriteViaMethod(ts, idr); + Node tr = d_env.rewriteViaMethod(ts, idr); Trace("smt-proof-pp-debug") << "...eq intro rewrite equality is " << ts << " == " << tr << ", from " << idr << std::endl; @@ -954,7 +953,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, getMethodId(args[1], idr); } Rewriter* rr = d_env.getRewriter(); - Node ret = rr->rewriteViaMethod(args[0], idr); + Node ret = d_env.rewriteViaMethod(args[0], idr); Node eq = args[0].eqNode(ret); if (idr == MethodId::RW_REWRITE || idr == MethodId::RW_REWRITE_EQ_EXT) { diff --git a/src/theory/builtin/proof_checker.cpp b/src/theory/builtin/proof_checker.cpp index d71b3635b..1309a05f9 100644 --- a/src/theory/builtin/proof_checker.cpp +++ b/src/theory/builtin/proof_checker.cpp @@ -18,10 +18,10 @@ #include "expr/skolem_manager.h" #include "smt/env.h" #include "smt/term_formula_removal.h" -#include "theory/evaluator.h" #include "theory/rewriter.h" #include "theory/substitutions.h" #include "theory/theory.h" +#include "util/rational.h" using namespace cvc5::kind; @@ -67,7 +67,7 @@ Node BuiltinProofRuleChecker::applySubstitutionRewrite( MethodId idr) { Node nks = applySubstitution(n, exp, ids, ida); - return d_env.getRewriter()->rewriteViaMethod(nks, idr); + return d_env.rewriteViaMethod(nks, idr); } bool BuiltinProofRuleChecker::getSubstitutionForLit(Node exp, @@ -249,7 +249,7 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, { return Node::null(); } - Node res = d_env.getRewriter()->rewriteViaMethod(args[0], idr); + Node res = d_env.rewriteViaMethod(args[0], idr); if (res.isNull()) { return Node::null(); @@ -260,7 +260,7 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, { Assert(children.empty()); Assert(args.size() == 1); - Node res = d_env.getRewriter()->rewriteViaMethod(args[0], MethodId::RW_EVALUATE); + Node res = d_env.rewriteViaMethod(args[0], MethodId::RW_EVALUATE); if (res.isNull()) { return Node::null(); @@ -302,7 +302,7 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, << SkolemManager::getOriginalForm(res) << std::endl; // **** NOTE: can rewrite the witness form here. This enables certain lemmas // to be provable, e.g. (= k t) where k is a purification Skolem for t. - res = Rewriter::rewrite(SkolemManager::getOriginalForm(res)); + res = d_env.getRewriter()->rewrite(SkolemManager::getOriginalForm(res)); if (!res.isConst() || !res.getConst<bool>()) { Trace("builtin-pfcheck") @@ -349,8 +349,8 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, if (res1 != res2) { // can rewrite the witness forms - res1 = Rewriter::rewrite(SkolemManager::getOriginalForm(res1)); - res2 = Rewriter::rewrite(SkolemManager::getOriginalForm(res2)); + res1 = d_env.getRewriter()->rewrite(SkolemManager::getOriginalForm(res1)); + res2 = d_env.getRewriter()->rewrite(SkolemManager::getOriginalForm(res2)); if (res1.isNull() || res1 != res2) { Trace("builtin-pfcheck") << "Failed to match results" << std::endl; diff --git a/src/theory/datatypes/sygus_datatype_utils.cpp b/src/theory/datatypes/sygus_datatype_utils.cpp index f1f7b45a4..12c255f57 100644 --- a/src/theory/datatypes/sygus_datatype_utils.cpp +++ b/src/theory/datatypes/sygus_datatype_utils.cpp @@ -391,7 +391,7 @@ Node sygusToBuiltin(Node n, bool isExternal) Node sygusToBuiltinEval(Node n, const std::vector<Node>& args) { NodeManager* nm = NodeManager::currentNM(); - Evaluator eval; + Evaluator eval(nullptr); // constant arguments? bool constArgs = true; for (const Node& a : args) diff --git a/src/theory/datatypes/sygus_extension.cpp b/src/theory/datatypes/sygus_extension.cpp index 2411013b2..d666cdac5 100644 --- a/src/theory/datatypes/sygus_extension.cpp +++ b/src/theory/datatypes/sygus_extension.cpp @@ -35,6 +35,7 @@ #include "theory/rewriter.h" #include "theory/theory_model.h" #include "theory/theory_state.h" +#include "util/rational.h" using namespace cvc5; using namespace cvc5::kind; @@ -1101,16 +1102,20 @@ Node SygusExtension::registerSearchValue(Node a, if (bv != bvr) { // add to the sampler database object - std::map<TypeNode, quantifiers::SygusSampler>::iterator its = - d_sampler[a].find(tn); - if (its == d_sampler[a].end()) + std::map<TypeNode, std::unique_ptr<quantifiers::SygusSampler>>& smap = + d_sampler[a]; + std::map<TypeNode, + std::unique_ptr<quantifiers::SygusSampler>>::iterator its = + smap.find(tn); + if (its == smap.end()) { - d_sampler[a][tn].initializeSygus( + smap[tn].reset(new quantifiers::SygusSampler(d_env)); + smap[tn]->initializeSygus( d_tds, nv, options::sygusSamples(), false); its = d_sampler[a].find(tn); } // check equivalent - its->second.checkEquivalent(bv, bvr, *options().base.out); + its->second->checkEquivalent(bv, bvr, *options().base.out); } } diff --git a/src/theory/datatypes/sygus_extension.h b/src/theory/datatypes/sygus_extension.h index c7a9e7893..2fd0110b4 100644 --- a/src/theory/datatypes/sygus_extension.h +++ b/src/theory/datatypes/sygus_extension.h @@ -289,7 +289,8 @@ private: * This is used for the sygusRewVerify() option to verify the correctness of * the rewriter. */ - std::map<Node, std::map<TypeNode, quantifiers::SygusSampler>> d_sampler; + std::map<Node, std::map<TypeNode, std::unique_ptr<quantifiers::SygusSampler>>> + d_sampler; /** Assert tester internal * * This function is called when the tester with index tindex is asserted for diff --git a/src/theory/datatypes/sygus_simple_sym.cpp b/src/theory/datatypes/sygus_simple_sym.cpp index 36dfc710b..63e60a478 100644 --- a/src/theory/datatypes/sygus_simple_sym.cpp +++ b/src/theory/datatypes/sygus_simple_sym.cpp @@ -17,6 +17,7 @@ #include "expr/dtype_cons.h" #include "theory/quantifiers/term_util.h" +#include "util/rational.h" using namespace std; using namespace cvc5::kind; diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 4a8976876..427e0251f 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -22,6 +22,7 @@ #include "expr/dtype_cons.h" #include "expr/kind.h" #include "expr/skolem_manager.h" +#include "expr/uninterpreted_constant.h" #include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "options/smt_options.h" @@ -38,6 +39,7 @@ #include "theory/theory_state.h" #include "theory/type_enumerator.h" #include "theory/valuation.h" +#include "util/rational.h" using namespace std; using namespace cvc5::kind; diff --git a/src/theory/evaluator.cpp b/src/theory/evaluator.cpp index 2a274426f..75c878065 100644 --- a/src/theory/evaluator.cpp +++ b/src/theory/evaluator.cpp @@ -127,19 +127,22 @@ Node EvalResult::toNode() const } } +Evaluator::Evaluator(Rewriter* rr) + : d_rr(rr), d_alphaCard(strings::utils::getAlphabetCardinality()) +{ +} + Node Evaluator::eval(TNode n, const std::vector<Node>& args, - const std::vector<Node>& vals, - bool useRewriter) const + const std::vector<Node>& vals) const { std::unordered_map<Node, Node> visited; - return eval(n, args, vals, visited, useRewriter); + return eval(n, args, vals, visited); } Node Evaluator::eval(TNode n, const std::vector<Node>& args, const std::vector<Node>& vals, - const std::unordered_map<Node, Node>& visited, - bool useRewriter) const + const std::unordered_map<Node, Node>& visited) const { Trace("evaluator") << "Evaluating " << n << " under substitution " << args << " " << vals << " with visited size = " << visited.size() @@ -150,36 +153,36 @@ Node Evaluator::eval(TNode n, for (const std::pair<const Node, Node>& p : visited) { Trace("evaluator") << "Add " << p.first << " == " << p.second << std::endl; - results[p.first] = evalInternal(p.second, args, vals, evalAsNode, results, useRewriter); + results[p.first] = evalInternal(p.second, args, vals, evalAsNode, results); if (results[p.first].d_tag == EvalResult::INVALID) { // could not evaluate, use the evalAsNode map std::unordered_map<TNode, Node>::iterator itn = evalAsNode.find(p.second); Assert(itn != evalAsNode.end()); Node val = itn->second; - if (useRewriter) + if (d_rr != nullptr) { - val = Rewriter::rewrite(val); + val = d_rr->rewrite(val); } evalAsNode[p.first] = val; } } Trace("evaluator") << "Run eval internal..." << std::endl; - Node ret = evalInternal(n, args, vals, evalAsNode, results, useRewriter).toNode(); + Node ret = evalInternal(n, args, vals, evalAsNode, results).toNode(); // if we failed to evaluate - if (ret.isNull() && useRewriter) + if (ret.isNull() && d_rr != nullptr) { // should be stored in the evaluation-as-node map std::unordered_map<TNode, Node>::iterator itn = evalAsNode.find(n); Assert(itn != evalAsNode.end()); - ret = Rewriter::rewrite(itn->second); + ret = d_rr->rewrite(itn->second); } // should be the same as substitution + rewriting, or possibly null if - // useRewriter is false - Assert((ret.isNull() && !useRewriter) + // d_rr is nullptr + Assert((ret.isNull() && d_rr == nullptr) || ret - == Rewriter::rewrite(n.substitute( - args.begin(), args.end(), vals.begin(), vals.end()))); + == d_rr->rewrite(n.substitute( + args.begin(), args.end(), vals.begin(), vals.end()))); return ret; } @@ -188,8 +191,7 @@ EvalResult Evaluator::evalInternal( const std::vector<Node>& args, const std::vector<Node>& vals, std::unordered_map<TNode, Node>& evalAsNode, - std::unordered_map<TNode, EvalResult>& results, - bool useRewriter) const + std::unordered_map<TNode, EvalResult>& results) const { std::vector<TNode> queue; queue.emplace_back(n); @@ -290,11 +292,11 @@ EvalResult Evaluator::evalInternal( // successfully evaluated, and the children that did not. Trace("evaluator") << "Evaluator: collect arguments" << std::endl; currNodeVal = reconstruct(currNodeVal, results, evalAsNode); - if (useRewriter) + if (d_rr != nullptr) { // Rewrite the result now, if we use the rewriter. We will see below // if we are able to turn it into a valid EvalResult. - currNodeVal = Rewriter::rewrite(currNodeVal); + currNodeVal = d_rr->rewrite(currNodeVal); } } needsReconstruct = false; @@ -360,12 +362,8 @@ EvalResult Evaluator::evalInternal( // evalAsNodeC but favor avoiding this copy for performance reasons. std::unordered_map<TNode, Node> evalAsNodeC; std::unordered_map<TNode, EvalResult> resultsC; - results[currNode] = evalInternal(op[1], - lambdaArgs, - lambdaVals, - evalAsNodeC, - resultsC, - useRewriter); + results[currNode] = evalInternal( + op[1], lambdaArgs, lambdaVals, evalAsNodeC, resultsC); Trace("evaluator") << "Evaluated via arguments to " << results[currNode].d_tag << std::endl; if (results[currNode].d_tag == EvalResult::INVALID) @@ -676,7 +674,7 @@ EvalResult Evaluator::evalInternal( case kind::STRING_FROM_CODE: { Integer i = results[currNode[0]].d_rat.getNumerator(); - if (i >= 0 && i < strings::utils::getAlphabetCardinality()) + if (i >= 0 && i < d_alphaCard) { std::vector<unsigned> svec = {i.toUnsignedInt()}; results[currNode] = EvalResult(String(svec)); diff --git a/src/theory/evaluator.h b/src/theory/evaluator.h index 42cc34749..2e96952b8 100644 --- a/src/theory/evaluator.h +++ b/src/theory/evaluator.h @@ -80,6 +80,8 @@ struct EvalResult Node toNode() const; }; +class Rewriter; + /** * The class that performs the actual evaluation of a term under a * substitution. Right now, the class does not cache anything between different @@ -88,6 +90,7 @@ struct EvalResult class Evaluator { public: + Evaluator(Rewriter* rr); /** * Evaluates node `n` under the substitution described by the variable names * `args` and the corresponding values `vals`. This method uses evaluation @@ -103,22 +106,20 @@ class Evaluator * The result of this call is either equivalent to: * (1) Rewriter::rewrite(n.substitute(args,vars)) * (2) Node::null(). - * If useRewriter is true, then we are always in the first case. If - * useRewriter is false, then we may be in case (2) if computing the + * If d_rr is non-null, then we are always in the first case. If + * useRewriter is null, then we may be in case (2) if computing the * rewritten, substituted form of n could not be determined by evaluation. */ Node eval(TNode n, const std::vector<Node>& args, - const std::vector<Node>& vals, - bool useRewriter = true) const; + const std::vector<Node>& vals) const; /** * Same as above, but with a precomputed visited map. */ Node eval(TNode n, const std::vector<Node>& args, const std::vector<Node>& vals, - const std::unordered_map<Node, Node>& visited, - bool useRewriter = true) const; + const std::unordered_map<Node, Node>& visited) const; private: /** @@ -141,8 +142,7 @@ class Evaluator const std::vector<Node>& args, const std::vector<Node>& vals, std::unordered_map<TNode, Node>& evalAsNode, - std::unordered_map<TNode, EvalResult>& results, - bool useRewriter) const; + std::unordered_map<TNode, EvalResult>& results) const; /** reconstruct * * This function reconstructs the result of evaluating n using a combination @@ -155,6 +155,10 @@ class Evaluator Node reconstruct(TNode n, std::unordered_map<TNode, EvalResult>& eresults, std::unordered_map<TNode, Node>& evalAsNode) const; + /** The (optional) rewriter to be used */ + Rewriter* d_rr; + /** The cardinality of the alphabet of strings */ + uint32_t d_alphaCard; }; } // namespace theory diff --git a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp index 88da629a0..4b06589b3 100644 --- a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp +++ b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp @@ -15,15 +15,14 @@ #include "theory/quantifiers/cegqi/ceg_instantiator.h" -#include "theory/quantifiers/cegqi/ceg_arith_instantiator.h" -#include "theory/quantifiers/cegqi/ceg_bv_instantiator.h" -#include "theory/quantifiers/cegqi/ceg_dt_instantiator.h" - #include "expr/dtype.h" #include "expr/dtype_cons.h" #include "expr/node_algorithm.h" #include "options/quantifiers_options.h" #include "theory/arith/arith_msum.h" +#include "theory/quantifiers/cegqi/ceg_arith_instantiator.h" +#include "theory/quantifiers/cegqi/ceg_bv_instantiator.h" +#include "theory/quantifiers/cegqi/ceg_dt_instantiator.h" #include "theory/quantifiers/cegqi/inst_strategy_cegqi.h" #include "theory/quantifiers/first_order_model.h" #include "theory/quantifiers/quantifiers_attributes.h" @@ -31,6 +30,7 @@ #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace std; using namespace cvc5::kind; diff --git a/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp b/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp index 1ccfd8ede..04fa1d2fe 100644 --- a/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp +++ b/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp @@ -25,6 +25,7 @@ #include "theory/quantifiers/term_registry.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace std; using namespace cvc5::kind; diff --git a/src/theory/quantifiers/ematching/inst_match_generator.cpp b/src/theory/quantifiers/ematching/inst_match_generator.cpp index 5380fc7d5..d8e3b7950 100644 --- a/src/theory/quantifiers/ematching/inst_match_generator.cpp +++ b/src/theory/quantifiers/ematching/inst_match_generator.cpp @@ -29,6 +29,7 @@ #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_registry.h" #include "theory/quantifiers/term_util.h" +#include "util/rational.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/expr_miner_manager.cpp b/src/theory/quantifiers/expr_miner_manager.cpp index 8af456ea8..e53fd9424 100644 --- a/src/theory/quantifiers/expr_miner_manager.cpp +++ b/src/theory/quantifiers/expr_miner_manager.cpp @@ -16,6 +16,7 @@ #include "theory/quantifiers/expr_miner_manager.h" #include "options/quantifiers_options.h" +#include "smt/env.h" namespace cvc5 { namespace theory { @@ -33,7 +34,8 @@ ExpressionMinerManager::ExpressionMinerManager(Env& env) options::sygusRewSynthAccel(), false), d_qg(env), - d_sols(env) + d_sols(env), + d_sampler(env) { } diff --git a/src/theory/quantifiers/fmf/bounded_integers.cpp b/src/theory/quantifiers/fmf/bounded_integers.cpp index 4a3e13dd0..44352c6fe 100644 --- a/src/theory/quantifiers/fmf/bounded_integers.cpp +++ b/src/theory/quantifiers/fmf/bounded_integers.cpp @@ -29,6 +29,7 @@ #include "theory/quantifiers/term_enumeration.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace cvc5; using namespace std; diff --git a/src/theory/quantifiers/fun_def_evaluator.cpp b/src/theory/quantifiers/fun_def_evaluator.cpp index 36f557db8..78a09641b 100644 --- a/src/theory/quantifiers/fun_def_evaluator.cpp +++ b/src/theory/quantifiers/fun_def_evaluator.cpp @@ -26,7 +26,7 @@ namespace cvc5 { namespace theory { namespace quantifiers { -FunDefEvaluator::FunDefEvaluator() {} +FunDefEvaluator::FunDefEvaluator(Env& env) : EnvObj(env) {} void FunDefEvaluator::assertDefinition(Node q) { @@ -51,11 +51,11 @@ void FunDefEvaluator::assertDefinition(Node q) << fdi.d_args << " / " << fdi.d_body << std::endl; } -Node FunDefEvaluator::evaluate(Node n) const +Node FunDefEvaluator::evaluateDefinitions(Node n) const { // should do standard rewrite before this call Assert(Rewriter::rewrite(n) == n); - Trace("fd-eval") << "FunDefEvaluator: evaluate " << n << std::endl; + Trace("fd-eval") << "FunDefEvaluator: evaluateDefinitions " << n << std::endl; NodeManager* nm = NodeManager::currentNM(); std::unordered_map<TNode, unsigned> funDefCount; std::unordered_map<TNode, unsigned>::iterator itCount; @@ -185,7 +185,7 @@ Node FunDefEvaluator::evaluate(Node n) const if (!args.empty()) { // invoke it on arguments using the evaluator - sbody = d_eval.eval(sbody, args, children); + sbody = evaluate(sbody, args, children); if (Trace.isOn("fd-eval-debug2")) { Trace("fd-eval-debug2") diff --git a/src/theory/quantifiers/fun_def_evaluator.h b/src/theory/quantifiers/fun_def_evaluator.h index a3b79bec7..c8e811968 100644 --- a/src/theory/quantifiers/fun_def_evaluator.h +++ b/src/theory/quantifiers/fun_def_evaluator.h @@ -20,8 +20,9 @@ #include <map> #include <vector> + #include "expr/node.h" -#include "theory/evaluator.h" +#include "smt/env_obj.h" namespace cvc5 { namespace theory { @@ -30,10 +31,10 @@ namespace quantifiers { /** * Techniques for evaluating recursively defined functions. */ -class FunDefEvaluator +class FunDefEvaluator : protected EnvObj { public: - FunDefEvaluator(); + FunDefEvaluator(Env& env); ~FunDefEvaluator() {} /** * Assert definition of a (recursive) function definition given by quantified @@ -45,7 +46,7 @@ class FunDefEvaluator * class. If n cannot be simplified to a constant, then this method returns * null. */ - Node evaluate(Node n) const; + Node evaluateDefinitions(Node n) const; /** * Has a call to assertDefinition been made? If this returns false, then * the evaluate method is the same as calling the rewriter, and returning @@ -74,8 +75,6 @@ class FunDefEvaluator std::map<Node, FunDefInfo> d_funDefMap; /** list of all definitions */ std::vector<Node> d_funDefs; - /** evaluator utility */ - Evaluator d_eval; }; } // namespace quantifiers diff --git a/src/theory/quantifiers/quant_bound_inference.cpp b/src/theory/quantifiers/quant_bound_inference.cpp index a78f66c51..83e48bf9c 100644 --- a/src/theory/quantifiers/quant_bound_inference.cpp +++ b/src/theory/quantifiers/quant_bound_inference.cpp @@ -17,6 +17,7 @@ #include "theory/quantifiers/fmf/bounded_integers.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/quant_conflict_find.cpp b/src/theory/quantifiers/quant_conflict_find.cpp index 1de60422f..b26b65018 100644 --- a/src/theory/quantifiers/quant_conflict_find.cpp +++ b/src/theory/quantifiers/quant_conflict_find.cpp @@ -28,6 +28,7 @@ #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace cvc5::kind; using namespace std; diff --git a/src/theory/quantifiers/quantifiers_attributes.cpp b/src/theory/quantifiers/quantifiers_attributes.cpp index deed1c761..1a0e03bfc 100644 --- a/src/theory/quantifiers/quantifiers_attributes.cpp +++ b/src/theory/quantifiers/quantifiers_attributes.cpp @@ -19,6 +19,8 @@ #include "theory/arith/arith_msum.h" #include "theory/quantifiers/sygus/synth_engine.h" #include "theory/quantifiers/term_util.h" +#include "util/rational.h" +#include "util/string.h" using namespace std; using namespace cvc5::kind; diff --git a/src/theory/quantifiers/skolemize.cpp b/src/theory/quantifiers/skolemize.cpp index bb0fa3899..a34547f45 100644 --- a/src/theory/quantifiers/skolemize.cpp +++ b/src/theory/quantifiers/skolemize.cpp @@ -28,6 +28,7 @@ #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" #include "theory/sort_inference.h" +#include "util/rational.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/solution_filter.cpp b/src/theory/quantifiers/solution_filter.cpp index 8844950c7..19bfcab66 100644 --- a/src/theory/quantifiers/solution_filter.cpp +++ b/src/theory/quantifiers/solution_filter.cpp @@ -19,6 +19,7 @@ #include "options/base_options.h" #include "options/quantifiers_options.h" +#include "smt/env.h" #include "util/random.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/sygus/cegis.cpp b/src/theory/quantifiers/sygus/cegis.cpp index f5774c761..d9e4b61af 100644 --- a/src/theory/quantifiers/sygus/cegis.cpp +++ b/src/theory/quantifiers/sygus/cegis.cpp @@ -39,6 +39,7 @@ Cegis::Cegis(Env& env, SynthConjecture* p) : SygusModule(env, qs, qim, tds, p), d_eval_unfold(tds->getEvalUnfold()), + d_cegis_sampler(env), d_usingSymCons(false) { } @@ -594,7 +595,6 @@ bool Cegis::checkRefinementEvalLemmas(const std::vector<Node>& vs, } } - Evaluator* eval = d_tds->getEvaluator(); for (unsigned r = 0; r < 2; r++) { std::unordered_set<Node>& rlemmas = @@ -603,7 +603,7 @@ bool Cegis::checkRefinementEvalLemmas(const std::vector<Node>& vs, { // We may have computed the evaluation of some function applications // via example-based symmetry breaking, stored in evalVisited. - Node lemcsu = eval->eval(lem, vs, ms, evalVisited); + Node lemcsu = evaluate(lem, vs, ms, evalVisited); if (lemcsu.isConst() && !lemcsu.getConst<bool>()) { return true; diff --git a/src/theory/quantifiers/sygus/cegis.h b/src/theory/quantifiers/sygus/cegis.h index d72805950..8e0fffdd1 100644 --- a/src/theory/quantifiers/sygus/cegis.h +++ b/src/theory/quantifiers/sygus/cegis.h @@ -28,6 +28,8 @@ namespace cvc5 { namespace theory { namespace quantifiers { +class SygusEvalUnfold; + /** Cegis * * The default sygus module for synthesis, counterexample-guided inductive diff --git a/src/theory/quantifiers/sygus/cegis_core_connective.cpp b/src/theory/quantifiers/sygus/cegis_core_connective.cpp index a42323227..b9066b079 100644 --- a/src/theory/quantifiers/sygus/cegis_core_connective.cpp +++ b/src/theory/quantifiers/sygus/cegis_core_connective.cpp @@ -311,7 +311,7 @@ bool CegisCoreConnective::constructSolution( Assert(candidates.size() == 1 && candidates[0] == d_candidate); TNode cval = candidate_values[0]; Node ets = d_eterm.substitute(d_candidate, cval); - Node etsr = Rewriter::rewrite(ets); + Node etsr = rewrite(ets); Trace("sygus-ccore-debug") << "...predicate is: " << etsr << std::endl; NodeManager* nm = NodeManager::currentNM(); for (unsigned d = 0; d < 2; d++) @@ -476,7 +476,7 @@ Node CegisCoreConnective::Component::getRefinementPt( visited.insert(id); Trace("sygus-ccore-ref") << "...eval " << std::endl; // check if it is true - Node en = p->evaluate(n, id, ctx); + Node en = p->evaluatePt(n, id, ctx); if (en.isConst() && en.getConst<bool>()) { ss = ctx; @@ -553,7 +553,7 @@ bool CegisCoreConnective::Component::addToAsserts(CegisCoreConnective* p, for (unsigned i = currIndex, psize = passerts.size(); i < psize; i++) { Node cn = passerts[i]; - Node cne = p->evaluate(cn, mvId, mvs); + Node cne = p->evaluatePt(cn, mvId, mvs); if (cne.isConst() && !cne.getConst<bool>()) { n = cn; @@ -635,9 +635,9 @@ Result CegisCoreConnective::checkSat(Node n, std::vector<Node>& mvs) const return r; } -Node CegisCoreConnective::evaluate(Node n, - Node id, - const std::vector<Node>& mvs) +Node CegisCoreConnective::evaluatePt(Node n, + Node id, + const std::vector<Node>& mvs) { Kind nk = n.getKind(); if (nk == AND || nk == OR) @@ -647,7 +647,7 @@ Node CegisCoreConnective::evaluate(Node n, // split AND/OR for (const Node& nc : n) { - Node enc = evaluate(nc, id, mvs); + Node enc = evaluatePt(nc, id, mvs); Assert(enc.isConst()); if (enc.getConst<bool>() == expRes) { @@ -666,12 +666,8 @@ Node CegisCoreConnective::evaluate(Node n, } } // use evaluator - Node cn = d_eval.eval(n, d_vars, mvs); - if (cn.isNull()) - { - cn = n.substitute(d_vars.begin(), d_vars.end(), mvs.begin(), mvs.end()); - cn = Rewriter::rewrite(cn); - } + Node cn = evaluate(n, d_vars, mvs); + Assert(!cn.isNull()); if (!id.isNull()) { ec[id] = cn; @@ -844,7 +840,7 @@ Node CegisCoreConnective::constructSolutionFromPool(Component& ccheck, mvs.clear(); getModel(*checkSol, mvs); // should evaluate to true - Node ean = evaluate(an, Node::null(), mvs); + Node ean = evaluatePt(an, Node::null(), mvs); Assert(ean.isConst() && ean.getConst<bool>()); Trace("sygus-ccore") << "--- Add refinement point " << mvs << std::endl; // In terms of Variant #2, this is the line: diff --git a/src/theory/quantifiers/sygus/cegis_core_connective.h b/src/theory/quantifiers/sygus/cegis_core_connective.h index 80ba6f26e..ebcd871aa 100644 --- a/src/theory/quantifiers/sygus/cegis_core_connective.h +++ b/src/theory/quantifiers/sygus/cegis_core_connective.h @@ -23,7 +23,6 @@ #include "expr/node.h" #include "expr/node_trie.h" #include "smt/env_obj.h" -#include "theory/evaluator.h" #include "theory/quantifiers/sygus/cegis.h" #include "util/result.h" @@ -365,11 +364,9 @@ class CegisCoreConnective : public Cegis * If id is non-null, then id is a unique identifier for mvs, and we cache * the result of n for this point. */ - Node evaluate(Node n, Node id, const std::vector<Node>& mvs); + Node evaluatePt(Node n, Node id, const std::vector<Node>& mvs); /** A cache of the above function */ std::unordered_map<Node, std::unordered_map<Node, Node>> d_eval_cache; - /** The evaluator utility used for the above function */ - Evaluator d_eval; //-----------------------------------end for evaluation /** Construct solution from pool diff --git a/src/theory/quantifiers/sygus/cegis_unif.cpp b/src/theory/quantifiers/sygus/cegis_unif.cpp index 6b260bb81..42306383b 100644 --- a/src/theory/quantifiers/sygus/cegis_unif.cpp +++ b/src/theory/quantifiers/sygus/cegis_unif.cpp @@ -23,6 +23,7 @@ #include "theory/quantifiers/sygus/sygus_unif_rl.h" #include "theory/quantifiers/sygus/synth_conjecture.h" #include "theory/quantifiers/sygus/term_database_sygus.h" +#include "util/rational.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/sygus/enum_value_manager.cpp b/src/theory/quantifiers/sygus/enum_value_manager.cpp index e7b3bbaa9..937537ce9 100644 --- a/src/theory/quantifiers/sygus/enum_value_manager.cpp +++ b/src/theory/quantifiers/sygus/enum_value_manager.cpp @@ -106,7 +106,7 @@ Node EnumValueManager::getEnumeratedValue(bool& activeIncomplete) std::ostream* out = nullptr; if (options::sygusRewVerify()) { - d_samplerRrV.reset(new SygusSampler); + d_samplerRrV.reset(new SygusSampler(d_env)); d_samplerRrV->initializeSygus( d_tds, e, options::sygusSamples(), false); // use the default output for the output of sygusRewVerify diff --git a/src/theory/quantifiers/sygus/rcons_type_info.cpp b/src/theory/quantifiers/sygus/rcons_type_info.cpp index 78f8d303c..72a8e6a56 100644 --- a/src/theory/quantifiers/sygus/rcons_type_info.cpp +++ b/src/theory/quantifiers/sygus/rcons_type_info.cpp @@ -16,8 +16,10 @@ #include "theory/quantifiers/sygus/rcons_type_info.h" #include "expr/skolem_manager.h" +#include "smt/env.h" #include "theory/datatypes/sygus_datatype_utils.h" #include "theory/quantifiers/sygus/rcons_obligation.h" +#include "theory/quantifiers/sygus_sampler.h" namespace cvc5 { namespace theory { @@ -37,8 +39,9 @@ void RConsTypeInfo::initialize(Env& env, d_crd.reset(new CandidateRewriteDatabase(env, true, false, true, false)); // since initial samples are not always useful for equivalence checks, set // their number to 0 - d_sygusSampler.initialize(stn, builtinVars, 0); - d_crd->initialize(builtinVars, &d_sygusSampler); + d_sygusSampler.reset(new SygusSampler(env)); + d_sygusSampler->initialize(stn, builtinVars, 0); + d_crd->initialize(builtinVars, d_sygusSampler.get()); } Node RConsTypeInfo::nextEnum() diff --git a/src/theory/quantifiers/sygus/rcons_type_info.h b/src/theory/quantifiers/sygus/rcons_type_info.h index 294454fe2..5f68993ad 100644 --- a/src/theory/quantifiers/sygus/rcons_type_info.h +++ b/src/theory/quantifiers/sygus/rcons_type_info.h @@ -20,7 +20,6 @@ #include "theory/quantifiers/candidate_rewrite_database.h" #include "theory/quantifiers/sygus/sygus_enumerator.h" -#include "theory/quantifiers/sygus_sampler.h" namespace cvc5 { namespace theory { @@ -28,6 +27,7 @@ namespace quantifiers { class RConsObligation; class CandidateRewriteDatabase; +class SygusSampler; /** * A utility class for Sygus Reconstruct datatype types (grammar non-terminals). @@ -93,7 +93,7 @@ class RConsTypeInfo /** Candidate rewrite database for this class' sygus datatype type */ std::unique_ptr<CandidateRewriteDatabase> d_crd; /** Sygus sampler needed for initializing the candidate rewrite database */ - SygusSampler d_sygusSampler; + std::unique_ptr<SygusSampler> d_sygusSampler; /** A map from a builtin term to its obligation. * * Each sygus datatype type has its own version of this map because it is diff --git a/src/theory/quantifiers/sygus/sygus_enumerator.cpp b/src/theory/quantifiers/sygus/sygus_enumerator.cpp index fca09c43d..959532d98 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator.cpp @@ -25,6 +25,7 @@ #include "theory/quantifiers/sygus/synth_engine.h" #include "theory/quantifiers/sygus/type_node_id_trie.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp index 7072b77e1..43c958ff9 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp @@ -32,6 +32,7 @@ #include "theory/rewriter.h" #include "theory/strings/word.h" #include "util/floatingpoint.h" +#include "util/string.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/sygus/sygus_unif_io.cpp b/src/theory/quantifiers/sygus/sygus_unif_io.cpp index 3fb80f917..e703569d9 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_io.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_io.cpp @@ -17,7 +17,6 @@ #include "options/quantifiers_options.h" #include "theory/datatypes/sygus_datatype_utils.h" -#include "theory/evaluator.h" #include "theory/quantifiers/sygus/example_infer.h" #include "theory/quantifiers/sygus/synth_conjecture.h" #include "theory/quantifiers/sygus/term_database_sygus.h" diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index 035db433e..2e528b213 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -55,8 +55,7 @@ TermDbSygus::TermDbSygus(Env& env, QuantifiersState& qs) : EnvObj(env), d_qstate(qs), d_syexp(new SygusExplain(this)), - d_eval(new Evaluator), - d_funDefEval(new FunDefEvaluator), + d_funDefEval(new FunDefEvaluator(env)), d_eval_unfold(new SygusEvalUnfold(this)) { d_true = NodeManager::currentNM()->mkConst( true ); @@ -759,7 +758,7 @@ Node TermDbSygus::rewriteNode(Node n) const { // If recursive functions are enabled, then we use the recursive function // evaluation utility. - Node fres = d_funDefEval->evaluate(res); + Node fres = d_funDefEval->evaluateDefinitions(res); if (!fres.isNull()) { return fres; @@ -996,7 +995,7 @@ Node TermDbSygus::evaluateBuiltin(TypeNode tn, // This may fail if there is a subterm of bn under the // substitution that is not constant, or if an operator in bn is not // supported by the evaluator - res = d_eval->eval(bn, varlist, args); + res = evaluate(bn, varlist, args); } if (res.isNull()) { diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index 7b05c70e4..59e0f4776 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -22,7 +22,6 @@ #include "expr/dtype.h" #include "smt/env_obj.h" -#include "theory/evaluator.h" #include "theory/quantifiers/extended_rewrite.h" #include "theory/quantifiers/fun_def_evaluator.h" #include "theory/quantifiers/sygus/sygus_eval_unfold.h" @@ -80,8 +79,6 @@ class TermDbSygus : protected EnvObj //------------------------------utilities /** get the explanation utility */ SygusExplain* getExplain() { return d_syexp.get(); } - /** get the evaluator */ - Evaluator* getEvaluator() { return d_eval.get(); } /** (recursive) function evaluator utility */ FunDefEvaluator* getFunDefEvaluator() { return d_funDefEval.get(); } /** evaluation unfolding utility */ @@ -309,8 +306,6 @@ class TermDbSygus : protected EnvObj //------------------------------utilities /** sygus explanation */ std::unique_ptr<SygusExplain> d_syexp; - /** evaluator */ - std::unique_ptr<Evaluator> d_eval; /** (recursive) function evaluator utility */ std::unique_ptr<FunDefEvaluator> d_funDefEval; /** evaluation function unfolding utility */ diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp index 0cbc4df5b..08fab59eb 100644 --- a/src/theory/quantifiers/sygus_sampler.cpp +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -24,17 +24,20 @@ #include "options/quantifiers_options.h" #include "printer/printer.h" #include "theory/quantifiers/lazy_trie.h" +#include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/rewriter.h" #include "util/bitvector.h" #include "util/random.h" +#include "util/rational.h" #include "util/sampler.h" +#include "util/string.h" namespace cvc5 { namespace theory { namespace quantifiers { -SygusSampler::SygusSampler() - : d_tds(nullptr), d_use_sygus_type(false), d_is_valid(false) +SygusSampler::SygusSampler(Env& env) + : d_env(env), d_tds(nullptr), d_use_sygus_type(false), d_is_valid(false) { } @@ -471,21 +474,11 @@ Node SygusSampler::evaluate(Node n, unsigned index) { Assert(index < d_samples.size()); // do beta-reductions in n first - n = Rewriter::rewrite(n); + n = d_env.getRewriter()->rewrite(n); // use efficient rewrite for substitution + rewrite - Node ev = d_eval.eval(n, d_vars, d_samples[index]); + Node ev = d_env.evaluate(n, d_vars, d_samples[index], true); + Assert(!ev.isNull()); Trace("sygus-sample-ev") << "Evaluate ( " << n << ", " << index << " ) -> "; - if (!ev.isNull()) - { - Trace("sygus-sample-ev") << ev << std::endl; - return ev; - } - Trace("sygus-sample-ev") << "null" << std::endl; - Trace("sygus-sample-ev") << "Rewrite -> "; - // substitution + rewrite - std::vector<Node>& pt = d_samples[index]; - ev = n.substitute(d_vars.begin(), d_vars.end(), pt.begin(), pt.end()); - ev = Rewriter::rewrite(ev); Trace("sygus-sample-ev") << ev << std::endl; return ev; } @@ -617,7 +610,7 @@ Node SygusSampler::getRandomValue(TypeNode tn) // negative ret = nm->mkNode(kind::UMINUS, ret); } - ret = Rewriter::rewrite(ret); + ret = d_env.getRewriter()->rewrite(ret); Assert(ret.isConst()); return ret; } @@ -715,7 +708,7 @@ Node SygusSampler::getSygusRandomValue(TypeNode tn, Trace("sygus-sample-grammar") << "mkGeneric" << std::endl; Node ret = d_tds->mkGeneric(dt, cindex, pre); Trace("sygus-sample-grammar") << "...returned " << ret << std::endl; - ret = Rewriter::rewrite(ret); + ret = d_env.getRewriter()->rewrite(ret); Trace("sygus-sample-grammar") << "...after rewrite " << ret << std::endl; // A rare case where we generate a non-constant value from constant // leaves is (/ n 0). diff --git a/src/theory/quantifiers/sygus_sampler.h b/src/theory/quantifiers/sygus_sampler.h index 85606adc6..3695270e1 100644 --- a/src/theory/quantifiers/sygus_sampler.h +++ b/src/theory/quantifiers/sygus_sampler.h @@ -19,15 +19,18 @@ #define CVC5__THEORY__QUANTIFIERS__SYGUS_SAMPLER_H #include <map> -#include "theory/evaluator.h" #include "theory/quantifiers/lazy_trie.h" -#include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_enumeration.h" namespace cvc5 { + +class Env; + namespace theory { namespace quantifiers { +class TermDbSygus; + /** SygusSampler * * This class can be used to test whether two expressions are equivalent @@ -65,7 +68,7 @@ namespace quantifiers { class SygusSampler : public LazyTrieEvaluator { public: - SygusSampler(); + SygusSampler(Env& env); ~SygusSampler() override {} /** initialize @@ -178,14 +181,14 @@ class SygusSampler : public LazyTrieEvaluator void checkEquivalent(Node bv, Node bvr, std::ostream& out); protected: + /** The environment we are using to evaluate terms and samples */ + Env& d_env; /** sygus term database of d_qe */ TermDbSygus* d_tds; /** term enumerator object (used for random sampling) */ TermEnumeration d_tenum; /** samples */ std::vector<std::vector<Node> > d_samples; - /** evaluator class */ - Evaluator d_eval; /** data structure to check duplication of sample points */ class PtTrie { diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index 460813084..4e571a66b 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -479,35 +479,5 @@ void Rewriter::clearCaches() clearCachesInternal(); } -Node Rewriter::rewriteViaMethod(TNode n, MethodId idr) -{ - if (idr == MethodId::RW_REWRITE) - { - return rewrite(n); - } - if (idr == MethodId::RW_EXT_REWRITE) - { - return extendedRewrite(n); - } - if (idr == MethodId::RW_REWRITE_EQ_EXT) - { - return rewriteEqualityExt(n); - } - if (idr == MethodId::RW_EVALUATE) - { - Evaluator eval; - return eval.eval(n, {}, {}, false); - } - if (idr == MethodId::RW_IDENTITY) - { - // does nothing - return n; - } - // unknown rewriter - Unhandled() << "Rewriter::rewriteViaMethod: no rewriter for " << idr - << std::endl; - return n; -} - } // namespace theory } // namespace cvc5 diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h index d87043a67..697253e03 100644 --- a/src/theory/rewriter.h +++ b/src/theory/rewriter.h @@ -18,22 +18,24 @@ #pragma once #include "expr/node.h" -#include "proof/method_id.h" #include "theory/theory_rewriter.h" namespace cvc5 { +class Env; class TConvProofGenerator; class ProofNodeManager; class TrustNode; namespace theory { +class Evaluator; + /** * The main rewriter class. */ class Rewriter { - + friend class cvc5::Env; // to initialize the evaluators of this class public: Rewriter(); @@ -62,6 +64,9 @@ class Rewriter { Node rewriteEqualityExt(TNode node); /** + * !!! Temporary until static access to rewriter is eliminated. This method + * should be moved to same place as evaluate (currently in Env). + * * Extended rewrite of the given node. This method is implemented by a * custom ExtendRewriter class that wraps this class to perform custom * rewrites (usually those that are not useful for solving, but e.g. useful @@ -103,17 +108,6 @@ class Rewriter { /** Get the theory rewriter for the given id */ TheoryRewriter* getTheoryRewriter(theory::TheoryId theoryId); - /** - * Apply rewrite on n via the rewrite method identifier idr (see method_id.h). - * This encapsulates the exact behavior of a REWRITE step in a proof. - * - * @param n The node to rewrite, - * @param idr The method identifier of the rewriter, by default RW_REWRITE - * specifying a call to rewrite. - * @return The rewritten form of n. - */ - Node rewriteViaMethod(TNode n, MethodId idr = MethodId::RW_REWRITE); - private: /** * Get the rewriter associated with the SmtEngine in scope. diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h index 36d320fb7..e86d748fd 100644 --- a/src/theory/rewriter_tables_template.h +++ b/src/theory/rewriter_tables_template.h @@ -80,10 +80,7 @@ ${post_rewrite_set_cache} } } -Rewriter::Rewriter() : d_tpg(nullptr) -{ - -} +Rewriter::Rewriter() : d_tpg(nullptr) {} void Rewriter::clearCachesInternal() { diff --git a/test/unit/theory/evaluator_white.cpp b/test/unit/theory/evaluator_white.cpp index a1f56eaba..c2c6cf77e 100644 --- a/test/unit/theory/evaluator_white.cpp +++ b/test/unit/theory/evaluator_white.cpp @@ -59,10 +59,11 @@ TEST_F(TestTheoryWhiteEvaluator, simple) std::vector<Node> args = {w, x, y, z}; std::vector<Node> vals = {c1, zero, one, c1}; - Evaluator eval; + Rewriter* rr = d_smtEngine->getRewriter(); + Evaluator eval(rr); Node r = eval.eval(t, args, vals); ASSERT_EQ(r, - Rewriter::rewrite(t.substitute( + rr->rewrite(t.substitute( args.begin(), args.end(), vals.begin(), vals.end()))); } @@ -90,10 +91,11 @@ TEST_F(TestTheoryWhiteEvaluator, loop) std::vector<Node> args = {x}; std::vector<Node> vals = {c}; - Evaluator eval; + Rewriter* rr = d_smtEngine->getRewriter(); + Evaluator eval(rr); Node r = eval.eval(t, args, vals); ASSERT_EQ(r, - Rewriter::rewrite(t.substitute( + rr->rewrite(t.substitute( args.begin(), args.end(), vals.begin(), vals.end()))); } @@ -106,30 +108,31 @@ TEST_F(TestTheoryWhiteEvaluator, strIdOf) std::vector<Node> args; std::vector<Node> vals; - Evaluator eval; + Rewriter* rr = d_smtEngine->getRewriter(); + Evaluator eval(rr); { Node n = d_nodeManager->mkNode(kind::STRING_INDEXOF, a, empty, one); Node r = eval.eval(n, args, vals); - ASSERT_EQ(r, Rewriter::rewrite(n)); + ASSERT_EQ(r, rr->rewrite(n)); } { Node n = d_nodeManager->mkNode(kind::STRING_INDEXOF, a, a, one); Node r = eval.eval(n, args, vals); - ASSERT_EQ(r, Rewriter::rewrite(n)); + ASSERT_EQ(r, rr->rewrite(n)); } { Node n = d_nodeManager->mkNode(kind::STRING_INDEXOF, a, empty, two); Node r = eval.eval(n, args, vals); - ASSERT_EQ(r, Rewriter::rewrite(n)); + ASSERT_EQ(r, rr->rewrite(n)); } { Node n = d_nodeManager->mkNode(kind::STRING_INDEXOF, a, a, two); Node r = eval.eval(n, args, vals); - ASSERT_EQ(r, Rewriter::rewrite(n)); + ASSERT_EQ(r, rr->rewrite(n)); } } @@ -140,7 +143,8 @@ TEST_F(TestTheoryWhiteEvaluator, code) std::vector<Node> args; std::vector<Node> vals; - Evaluator eval; + Rewriter* rr = d_smtEngine->getRewriter(); + Evaluator eval(rr); // (str.code "A") ---> 65 { |