diff options
22 files changed, 653 insertions, 470 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 692ae09ac..971648839 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -309,6 +309,8 @@ libcvc4_add_sources( theory/arith/nl/cad/projections.h theory/arith/nl/cad/variable_ordering.cpp theory/arith/nl/cad/variable_ordering.h + theory/arith/nl/ext_theory_callback.cpp + theory/arith/nl/ext_theory_callback.h theory/arith/nl/iand_solver.cpp theory/arith/nl/iand_solver.h theory/arith/nl/inference.cpp diff --git a/src/theory/arith/nl/ext_theory_callback.cpp b/src/theory/arith/nl/ext_theory_callback.cpp new file mode 100644 index 000000000..4518df0de --- /dev/null +++ b/src/theory/arith/nl/ext_theory_callback.cpp @@ -0,0 +1,131 @@ +/********************* */ +/*! \file ext_theory_callback.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The extended theory callback for non-linear arithmetic + **/ + +#include "theory/arith/nl/ext_theory_callback.h" + +#include "theory/arith/arith_utilities.h" + +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { +namespace arith { +namespace nl { + +NlExtTheoryCallback::NlExtTheoryCallback(eq::EqualityEngine* ee) : d_ee(ee) +{ + d_zero = NodeManager::currentNM()->mkConst(Rational(0)); +} + +bool NlExtTheoryCallback::getCurrentSubstitution( + int effort, + const std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node>>& exp) +{ + // get the constant equivalence classes + std::map<Node, std::vector<int>> rep_to_subs_index; + + bool retVal = false; + for (unsigned i = 0; i < vars.size(); i++) + { + Node n = vars[i]; + if (d_ee->hasTerm(n)) + { + Node nr = d_ee->getRepresentative(n); + if (nr.isConst()) + { + subs.push_back(nr); + Trace("nl-subs") << "Basic substitution : " << n << " -> " << nr + << std::endl; + exp[n].push_back(n.eqNode(nr)); + retVal = true; + } + else + { + rep_to_subs_index[nr].push_back(i); + subs.push_back(n); + } + } + else + { + subs.push_back(n); + } + } + + // return true if the substitution is non-trivial + return retVal; +} + +bool NlExtTheoryCallback::isExtfReduced(int effort, + Node n, + Node on, + std::vector<Node>& exp) +{ + if (n != d_zero) + { + Kind k = n.getKind(); + return k != NONLINEAR_MULT && !isTranscendentalKind(k) && k != IAND; + } + Assert(n == d_zero); + if (on.getKind() == NONLINEAR_MULT) + { + Trace("nl-ext-zero-exp") + << "Infer zero : " << on << " == " << n << std::endl; + // minimize explanation if a substitution+rewrite results in zero + const std::set<Node> vars(on.begin(), on.end()); + + for (unsigned i = 0, size = exp.size(); i < size; i++) + { + Trace("nl-ext-zero-exp") + << " exp[" << i << "] = " << exp[i] << std::endl; + std::vector<Node> eqs; + if (exp[i].getKind() == EQUAL) + { + eqs.push_back(exp[i]); + } + else if (exp[i].getKind() == AND) + { + for (const Node& ec : exp[i]) + { + if (ec.getKind() == EQUAL) + { + eqs.push_back(ec); + } + } + } + + for (unsigned j = 0; j < eqs.size(); j++) + { + for (unsigned r = 0; r < 2; r++) + { + if (eqs[j][r] == d_zero && vars.find(eqs[j][1 - r]) != vars.end()) + { + Trace("nl-ext-zero-exp") + << "...single exp : " << eqs[j] << std::endl; + exp.clear(); + exp.push_back(eqs[j]); + return true; + } + } + } + } + } + return true; +} + +} // namespace nl +} // namespace arith +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/arith/nl/ext_theory_callback.h b/src/theory/arith/nl/ext_theory_callback.h new file mode 100644 index 000000000..0d95db166 --- /dev/null +++ b/src/theory/arith/nl/ext_theory_callback.h @@ -0,0 +1,86 @@ +/********************* */ +/*! \file ext_theory_callback.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The extended theory callback for non-linear arithmetic + **/ + +#ifndef CVC4__THEORY__ARITH__NL__EXT_THEORY_CALLBACK_H +#define CVC4__THEORY__ARITH__NL__EXT_THEORY_CALLBACK_H + +#include "expr/node.h" +#include "theory/ext_theory.h" + +namespace CVC4 { +namespace theory { +namespace arith { +namespace nl { + +class NlExtTheoryCallback : public ExtTheoryCallback +{ + public: + NlExtTheoryCallback(eq::EqualityEngine* ee); + ~NlExtTheoryCallback() {} + /** Get current substitution + * + * This function and the one below are + * used for context-dependent + * simplification, see Section 3.1 of + * "Designing Theory Solvers with Extensions" + * by Reynolds et al. FroCoS 2017. + * + * effort : an identifier indicating the stage where + * we are performing context-dependent simplification, + * vars : a set of arithmetic variables. + * + * This function populates subs and exp, such that for 0 <= i < vars.size(): + * ( exp[vars[i]] ) => vars[i] = subs[i] + * where exp[vars[i]] is a set of assertions + * that hold in the current context. We call { vars -> subs } a "derivable + * substituion" (see Reynolds et al. FroCoS 2017). + */ + bool getCurrentSubstitution(int effort, + const std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node>>& exp) override; + /** Is the term n in reduced form? + * + * Used for context-dependent simplification. + * + * effort : an identifier indicating the stage where + * we are performing context-dependent simplification, + * on : the original term that we reduced to n, + * exp : an explanation such that ( exp => on = n ). + * + * We return a pair ( b, exp' ) such that + * if b is true, then: + * n is in reduced form + * if exp' is non-null, then ( exp' => on = n ) + * The second part of the pair is used for constructing + * minimal explanations for context-dependent simplifications. + */ + bool isExtfReduced(int effort, + Node n, + Node on, + std::vector<Node>& exp) override; + + private: + /** The underlying equality engine. */ + eq::EqualityEngine* d_ee; + /** Commonly used nodes */ + Node d_zero; +}; + +} // namespace nl +} // namespace arith +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__ARITH__NL__EXT_THEORY_CALLBACK_H */ diff --git a/src/theory/arith/nl/nonlinear_extension.cpp b/src/theory/arith/nl/nonlinear_extension.cpp index ada6aa11a..733912969 100644 --- a/src/theory/arith/nl/nonlinear_extension.cpp +++ b/src/theory/arith/nl/nonlinear_extension.cpp @@ -39,7 +39,11 @@ NonlinearExtension::NonlinearExtension(TheoryArith& containing, d_ee(ee), d_needsLastCall(false), d_checkCounter(0), - d_extTheory(&containing), + d_extTheoryCb(ee), + d_extTheory(d_extTheoryCb, + containing.getSatContext(), + containing.getUserContext(), + containing.getOutputChannel()), d_model(containing.getSatContext()), d_trSlv(d_model), d_nlSlv(containing, d_model), @@ -67,101 +71,6 @@ void NonlinearExtension::preRegisterTerm(TNode n) d_extTheory.registerTermRec(n); } -bool NonlinearExtension::getCurrentSubstitution( - int effort, - const std::vector<Node>& vars, - std::vector<Node>& subs, - std::map<Node, std::vector<Node>>& exp) -{ - // get the constant equivalence classes - std::map<Node, std::vector<int>> rep_to_subs_index; - - bool retVal = false; - for (unsigned i = 0; i < vars.size(); i++) - { - Node n = vars[i]; - if (d_ee->hasTerm(n)) - { - Node nr = d_ee->getRepresentative(n); - if (nr.isConst()) - { - subs.push_back(nr); - Trace("nl-subs") << "Basic substitution : " << n << " -> " << nr - << std::endl; - exp[n].push_back(n.eqNode(nr)); - retVal = true; - } - else - { - rep_to_subs_index[nr].push_back(i); - subs.push_back(n); - } - } - else - { - subs.push_back(n); - } - } - - // return true if the substitution is non-trivial - return retVal; -} - -std::pair<bool, Node> NonlinearExtension::isExtfReduced( - int effort, Node n, Node on, const std::vector<Node>& exp) const -{ - if (n != d_zero) - { - Kind k = n.getKind(); - return std::make_pair( - k != NONLINEAR_MULT && !isTranscendentalKind(k) && k != IAND, - Node::null()); - } - Assert(n == d_zero); - if (on.getKind() == NONLINEAR_MULT) - { - Trace("nl-ext-zero-exp") - << "Infer zero : " << on << " == " << n << std::endl; - // minimize explanation if a substitution+rewrite results in zero - const std::set<Node> vars(on.begin(), on.end()); - - for (unsigned i = 0, size = exp.size(); i < size; i++) - { - Trace("nl-ext-zero-exp") - << " exp[" << i << "] = " << exp[i] << std::endl; - std::vector<Node> eqs; - if (exp[i].getKind() == EQUAL) - { - eqs.push_back(exp[i]); - } - else if (exp[i].getKind() == AND) - { - for (const Node& ec : exp[i]) - { - if (ec.getKind() == EQUAL) - { - eqs.push_back(ec); - } - } - } - - for (unsigned j = 0; j < eqs.size(); j++) - { - for (unsigned r = 0; r < 2; r++) - { - if (eqs[j][r] == d_zero && vars.find(eqs[j][1 - r]) != vars.end()) - { - Trace("nl-ext-zero-exp") - << "...single exp : " << eqs[j] << std::endl; - return std::make_pair(true, eqs[j]); - } - } - } - } - } - return std::make_pair(true, Node::null()); -} - void NonlinearExtension::sendLemmas(const std::vector<NlLemma>& out) { for (const NlLemma& nlem : out) diff --git a/src/theory/arith/nl/nonlinear_extension.h b/src/theory/arith/nl/nonlinear_extension.h index d035b1056..41f24e769 100644 --- a/src/theory/arith/nl/nonlinear_extension.h +++ b/src/theory/arith/nl/nonlinear_extension.h @@ -27,6 +27,7 @@ #include "expr/kind.h" #include "expr/node.h" #include "theory/arith/nl/cad_solver.h" +#include "theory/arith/nl/ext_theory_callback.h" #include "theory/arith/nl/iand_solver.h" #include "theory/arith/nl/nl_lemma_utils.h" #include "theory/arith/nl/nl_model.h" @@ -77,48 +78,6 @@ class NonlinearExtension * Does non-context dependent setup for a node connected to a theory. */ void preRegisterTerm(TNode n); - /** Get current substitution - * - * This function and the one below are - * used for context-dependent - * simplification, see Section 3.1 of - * "Designing Theory Solvers with Extensions" - * by Reynolds et al. FroCoS 2017. - * - * effort : an identifier indicating the stage where - * we are performing context-dependent simplification, - * vars : a set of arithmetic variables. - * - * This function populates subs and exp, such that for 0 <= i < vars.size(): - * ( exp[vars[i]] ) => vars[i] = subs[i] - * where exp[vars[i]] is a set of assertions - * that hold in the current context. We call { vars -> subs } a "derivable - * substituion" (see Reynolds et al. FroCoS 2017). - */ - bool getCurrentSubstitution(int effort, - const std::vector<Node>& vars, - std::vector<Node>& subs, - std::map<Node, std::vector<Node>>& exp); - /** Is the term n in reduced form? - * - * Used for context-dependent simplification. - * - * effort : an identifier indicating the stage where - * we are performing context-dependent simplification, - * on : the original term that we reduced to n, - * exp : an explanation such that ( exp => on = n ). - * - * We return a pair ( b, exp' ) such that - * if b is true, then: - * n is in reduced form - * if exp' is non-null, then ( exp' => on = n ) - * The second part of the pair is used for constructing - * minimal explanations for context-dependent simplifications. - */ - std::pair<bool, Node> isExtfReduced(int effort, - Node n, - Node on, - const std::vector<Node>& exp) const; /** Check at effort level e. * * This call may result in (possibly multiple) calls to d_out->lemma(...) @@ -300,6 +259,8 @@ class NonlinearExtension * (modelBasedRefinement). This counter is used for interleaving strategies. */ unsigned d_checkCounter; + /** The callback for the extended theory below */ + NlExtTheoryCallback d_extTheoryCb; /** Extended theory, responsible for context-dependent simplification. */ ExtTheory d_extTheory; /** The non-linear model object diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index 762634ce7..fbf25705c 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -117,14 +117,6 @@ TrustNode TheoryArith::explain(TNode n) return TrustNode::mkTrustPropExp(n, exp, nullptr); } -bool TheoryArith::getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) { - return d_internal->getCurrentSubstitution( effort, vars, subs, exp ); -} - -bool TheoryArith::isExtfReduced( int effort, Node n, Node on, std::vector< Node >& exp ) { - return d_internal->isExtfReduced( effort, n, on, exp ); -} - void TheoryArith::propagate(Effort e) { d_internal->propagate(e); } diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index 6adf8f66a..71a25ac12 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -75,14 +75,6 @@ class TheoryArith : public Theory { bool needsCheckLastEffort() override; void propagate(Effort e) override; TrustNode explain(TNode n) override; - bool getCurrentSubstitution(int effort, - std::vector<Node>& vars, - std::vector<Node>& subs, - std::map<Node, std::vector<Node> >& exp) override; - bool isExtfReduced(int effort, - Node n, - Node on, - std::vector<Node>& exp) override; bool collectModelInfo(TheoryModel* m) override; diff --git a/src/theory/arith/theory_arith_private.cpp b/src/theory/arith/theory_arith_private.cpp index 8a780116c..1b49b7350 100644 --- a/src/theory/arith/theory_arith_private.cpp +++ b/src/theory/arith/theory_arith_private.cpp @@ -3877,31 +3877,6 @@ Node TheoryArithPrivate::explain(TNode n) } } -bool TheoryArithPrivate::getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) { - if (d_nonlinearExtension != nullptr) - { - return d_nonlinearExtension->getCurrentSubstitution( effort, vars, subs, exp ); - }else{ - return false; - } -} - -bool TheoryArithPrivate::isExtfReduced(int effort, Node n, Node on, - std::vector<Node>& exp) { - if (d_nonlinearExtension != nullptr) - { - std::pair<bool, Node> reduced = - d_nonlinearExtension->isExtfReduced(effort, n, on, exp); - if (!reduced.second.isNull()) { - exp.clear(); - exp.push_back(reduced.second); - } - return reduced.first; - } else { - return false; // d_containing.isExtfReduced( effort, n, on ); - } -} - void TheoryArithPrivate::propagate(Theory::Effort e) { // This uses model values for safety. Disable for now. if (d_qflraStatus == Result::SAT diff --git a/src/theory/arith/theory_arith_private.h b/src/theory/arith/theory_arith_private.h index d96b5e2d3..d0428f2ef 100644 --- a/src/theory/arith/theory_arith_private.h +++ b/src/theory/arith/theory_arith_private.h @@ -452,8 +452,6 @@ public: bool needsCheckLastEffort(); void propagate(Theory::Effort e); Node explain(TNode n); - bool getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ); - bool isExtfReduced( int effort, Node n, Node on, std::vector< Node >& exp ); Rational deltaValueForTotalOrder() const; diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp index 38c5cb482..b341b0671 100644 --- a/src/theory/bv/bv_subtheory_core.cpp +++ b/src/theory/bv/bv_subtheory_core.cpp @@ -31,7 +31,65 @@ using namespace CVC4::theory; using namespace CVC4::theory::bv; using namespace CVC4::theory::bv::utils; -CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt) +bool CoreSolverExtTheoryCallback::getCurrentSubstitution( + int effort, + const std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node> >& exp) +{ + if (d_equalityEngine == nullptr) + { + return false; + } + // get the constant equivalence classes + bool retVal = false; + for (const Node& n : vars) + { + if (d_equalityEngine->hasTerm(n)) + { + Node nr = d_equalityEngine->getRepresentative(n); + if (nr.isConst()) + { + subs.push_back(nr); + exp[n].push_back(n.eqNode(nr)); + retVal = true; + } + else + { + subs.push_back(n); + } + } + else + { + subs.push_back(n); + } + } + // return true if the substitution is non-trivial + return retVal; +} + +bool CoreSolverExtTheoryCallback::getReduction(int effort, + Node n, + Node& nr, + bool& satDep) +{ + Trace("bv-ext") << "TheoryBV::checkExt : non-reduced : " << n << std::endl; + if (n.getKind() == kind::BITVECTOR_TO_NAT) + { + nr = utils::eliminateBv2Nat(n); + satDep = false; + return true; + } + else if (n.getKind() == kind::INT_TO_BITVECTOR) + { + nr = utils::eliminateInt2Bv(n); + satDep = false; + return true; + } + return false; +} + +CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv) : SubtheorySolver(c, bv), d_notify(*this), d_isComplete(c, true), @@ -39,9 +97,18 @@ CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt) d_preregisterCalled(false), d_checkCalled(false), d_bv(bv), - d_extTheory(extt), - d_reasons(c) + d_extTheoryCb(), + d_extTheory(new ExtTheory(d_extTheoryCb, + bv->getSatContext(), + bv->getUserContext(), + bv->getOutputChannel())), + d_reasons(c), + d_needsLastCallCheck(false), + d_extf_range_infer(bv->getUserContext()), + d_extf_collapse_infer(bv->getUserContext()) { + d_extTheory->addFunctionKind(kind::BITVECTOR_TO_NAT); + d_extTheory->addFunctionKind(kind::INT_TO_BITVECTOR); } CoreSolver::~CoreSolver() {} @@ -431,3 +498,141 @@ CoreSolver::Statistics::Statistics() CoreSolver::Statistics::~Statistics() { smtStatisticsRegistry()->unregisterStat(&d_numCallstoCheck); } + +void CoreSolver::checkExtf(Theory::Effort e) +{ + if (e == Theory::EFFORT_LAST_CALL) + { + std::vector<Node> nred = d_extTheory->getActive(); + doExtfReductions(nred); + } + Assert(e == Theory::EFFORT_FULL); + // do inferences (adds external lemmas) TODO: this can be improved to add + // internal inferences + std::vector<Node> nred; + if (d_extTheory->doInferences(0, nred)) + { + return; + } + d_needsLastCallCheck = false; + if (!nred.empty()) + { + // other inferences involving bv2nat, int2bv + if (options::bvAlgExtf()) + { + if (doExtfInferences(nred)) + { + return; + } + } + if (!options::bvLazyReduceExtf()) + { + if (doExtfReductions(nred)) + { + return; + } + } + else + { + d_needsLastCallCheck = true; + } + } +} + +bool CoreSolver::needsCheckLastEffort() const { return d_needsLastCallCheck; } + +bool CoreSolver::doExtfInferences(std::vector<Node>& terms) +{ + NodeManager* nm = NodeManager::currentNM(); + bool sentLemma = false; + eq::EqualityEngine* ee = d_equalityEngine; + std::map<Node, Node> op_map; + for (unsigned j = 0; j < terms.size(); j++) + { + TNode n = terms[j]; + Assert(n.getKind() == kind::BITVECTOR_TO_NAT + || n.getKind() == kind::INT_TO_BITVECTOR); + if (n.getKind() == kind::BITVECTOR_TO_NAT) + { + // range lemmas + if (d_extf_range_infer.find(n) == d_extf_range_infer.end()) + { + d_extf_range_infer.insert(n); + unsigned bvs = n[0].getType().getBitVectorSize(); + Node min = nm->mkConst(Rational(0)); + Node max = nm->mkConst(Rational(Integer(1).multiplyByPow2(bvs))); + Node lem = nm->mkNode(kind::AND, + nm->mkNode(kind::GEQ, n, min), + nm->mkNode(kind::LT, n, max)); + Trace("bv-extf-lemma") + << "BV extf lemma (range) : " << lem << std::endl; + d_bv->getOutputChannel().lemma(lem); + sentLemma = true; + } + } + Node r = (ee && ee->hasTerm(n[0])) ? ee->getRepresentative(n[0]) : n[0]; + op_map[r] = n; + } + for (unsigned j = 0; j < terms.size(); j++) + { + TNode n = terms[j]; + Node r = (ee && ee->hasTerm(n[0])) ? ee->getRepresentative(n) : n; + std::map<Node, Node>::iterator it = op_map.find(r); + if (it != op_map.end()) + { + Node parent = it->second; + // Node cterm = parent[0]==n ? parent : nm->mkNode( parent.getOperator(), + // n ); + Node cterm = parent[0].eqNode(n); + Trace("bv-extf-lemma-debug") + << "BV extf collapse based on : " << cterm << std::endl; + if (d_extf_collapse_infer.find(cterm) == d_extf_collapse_infer.end()) + { + d_extf_collapse_infer.insert(cterm); + + Node t = n[0]; + if (t.getType() == parent.getType()) + { + if (n.getKind() == kind::INT_TO_BITVECTOR) + { + Assert(t.getType().isInteger()); + // congruent modulo 2^( bv width ) + unsigned bvs = n.getType().getBitVectorSize(); + Node coeff = nm->mkConst(Rational(Integer(1).multiplyByPow2(bvs))); + Node k = nm->mkSkolem( + "int_bv_cong", t.getType(), "for int2bv/bv2nat congruence"); + t = nm->mkNode(kind::PLUS, t, nm->mkNode(kind::MULT, coeff, k)); + } + Node lem = parent.eqNode(t); + + if (parent[0] != n) + { + Assert(ee->areEqual(parent[0], n)); + lem = nm->mkNode(kind::IMPLIES, parent[0].eqNode(n), lem); + } + // this handles inferences of the form, e.g.: + // ((_ int2bv w) (bv2nat x)) == x (if x is bit-width w) + // (bv2nat ((_ int2bv w) x)) == x + k*2^w for some k + Trace("bv-extf-lemma") + << "BV extf lemma (collapse) : " << lem << std::endl; + d_bv->getOutputChannel().lemma(lem); + sentLemma = true; + } + } + Trace("bv-extf-lemma-debug") + << "BV extf f collapse based on : " << cterm << std::endl; + } + } + return sentLemma; +} + +bool CoreSolver::doExtfReductions(std::vector<Node>& terms) +{ + std::vector<Node> nredr; + if (d_extTheory->doReductions(0, terms, nredr)) + { + return true; + } + Assert(nredr.empty()); + return false; +} diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h index 381804681..32bc36164 100644 --- a/src/theory/bv/bv_subtheory_core.h +++ b/src/theory/bv/bv_subtheory_core.h @@ -31,6 +31,23 @@ namespace theory { namespace bv { class Base; + +/** An extended theory callback used by the core solver */ +class CoreSolverExtTheoryCallback : public ExtTheoryCallback +{ + public: + CoreSolverExtTheoryCallback() : d_equalityEngine(nullptr) {} + /** Get current substitution based on the underlying equality engine. */ + bool getCurrentSubstitution(int effort, + const std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node> >& exp) override; + /** Get reduction. */ + bool getReduction(int effort, Node n, Node& nr, bool& satDep) override; + /** The underlying equality engine */ + eq::EqualityEngine* d_equalityEngine; +}; + /** * Bitvector equality solver */ @@ -83,8 +100,10 @@ class CoreSolver : public SubtheorySolver { TheoryBV* d_bv; /** Pointer to the equality engine of the parent */ eq::EqualityEngine* d_equalityEngine; - /** Pointer to the extended theory module. */ - ExtTheory* d_extTheory; + /** The extended theory callback */ + CoreSolverExtTheoryCallback d_extTheoryCb; + /** Extended theory module, for context-dependent simplification. */ + std::unique_ptr<ExtTheory> d_extTheory; /** To make sure we keep the explanations */ context::CDHashSet<Node, NodeHashFunction> d_reasons; @@ -96,8 +115,38 @@ class CoreSolver : public SubtheorySolver { bool isCompleteForTerm(TNode term, TNodeBoolMap& seen); Statistics d_statistics; + /** Whether we need a last call effort check */ + bool d_needsLastCallCheck; + /** For extended functions */ + context::CDHashSet<Node, NodeHashFunction> d_extf_range_infer; + context::CDHashSet<Node, NodeHashFunction> d_extf_collapse_infer; + + /** do extended function inferences + * + * This method adds lemmas on the output channel of TheoryBV based on + * reasoning about extended functions, such as bv2nat and int2bv. Examples + * of lemmas added by this method include: + * 0 <= ((_ int2bv w) x) < 2^w + * ((_ int2bv w) (bv2nat x)) = x + * (bv2nat ((_ int2bv w) x)) == x + k*2^w + * The purpose of these lemmas is to recognize easy conflicts before fully + * reducing extended functions based on their full semantics. + */ + bool doExtfInferences(std::vector<Node>& terms); + /** do extended function reductions + * + * This method adds lemmas on the output channel of TheoryBV based on + * reducing all extended function applications that are preregistered to + * this theory and have not already been reduced by context-dependent + * simplification (see theory/ext_theory.h). Examples of lemmas added by + * this method include: + * (bv2nat x) = (ite ((_ extract w w-1) x) 2^{w-1} 0) + ... + + * (ite ((_ extract 1 0) x) 1 0) + */ + bool doExtfReductions(std::vector<Node>& terms); + public: - CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt); + CoreSolver(context::Context* c, TheoryBV* bv); ~CoreSolver(); bool needsEqualityEngine(EeSetupInfo& esi); void finishInit(); @@ -111,9 +160,11 @@ class CoreSolver : public SubtheorySolver { EqualityStatus getEqualityStatus(TNode a, TNode b) override; bool hasTerm(TNode node) const; void addTermToEqualityEngine(TNode node); + /** check extended functions at the given effort */ + void checkExtf(Theory::Effort e); + bool needsCheckLastEffort() const; }; - } } } diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index d6492f177..815656d8f 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -29,7 +29,6 @@ #include "theory/bv/theory_bv_rewrite_rules_simplification.h" #include "theory/bv/theory_bv_rewriter.h" #include "theory/bv/theory_bv_utils.h" -#include "theory/ext_theory.h" #include "theory/theory_model.h" #include "theory/valuation.h" @@ -63,18 +62,12 @@ TheoryBV::TheoryBV(context::Context* c, d_invalidateModelCache(c, true), d_literalsToPropagate(c), d_literalsToPropagateIndex(c, 0), - d_extTheory(new ExtTheory(this)), d_propagatedBy(c), d_eagerSolver(), d_abstractionModule(new AbstractionModule(getStatsPrefix(THEORY_BV))), d_calledPreregister(false), - d_needsLastCallCheck(false), - d_extf_range_infer(u), - d_extf_collapse_infer(u), d_state(c, u, valuation) { - d_extTheory->addFunctionKind(kind::BITVECTOR_TO_NAT); - d_extTheory->addFunctionKind(kind::INT_TO_BITVECTOR); if (options::bitblastMode() == options::BitblastMode::EAGER) { d_eagerSolver.reset(new EagerBitblastSolver(c, this)); @@ -83,7 +76,7 @@ TheoryBV::TheoryBV(context::Context* c, if (options::bitvectorEqualitySolver()) { - d_subtheories.emplace_back(new CoreSolver(c, this, d_extTheory.get())); + d_subtheories.emplace_back(new CoreSolver(c, this)); d_subtheoryMap[SUB_CORE] = d_subtheories.back().get(); } @@ -331,8 +324,12 @@ void TheoryBV::check(Effort e) //last call : do reductions on extended bitvector functions if (e == Theory::EFFORT_LAST_CALL) { - std::vector<Node> nred = d_extTheory->getActive(); - doExtfReductions(nred); + CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE]; + if (core) + { + // check extended functions at last call effort + core->checkExtf(e); + } return; } @@ -414,131 +411,24 @@ void TheoryBV::check(Effort e) //check extended functions if (Theory::fullEffort(e)) { - //do inferences (adds external lemmas) TODO: this can be improved to add internal inferences - std::vector< Node > nred; - if (d_extTheory->doInferences(0, nred)) + CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE]; + if (core) { - return; - } - d_needsLastCallCheck = false; - if( !nred.empty() ){ - //other inferences involving bv2nat, int2bv - if( options::bvAlgExtf() ){ - if( doExtfInferences( nred ) ){ - return; - } - } - if( !options::bvLazyReduceExtf() ){ - if( doExtfReductions( nred ) ){ - return; - } - } - else - { - d_needsLastCallCheck = true; - } + // check extended functions at full effort + core->checkExtf(e); } } } -bool TheoryBV::doExtfInferences(std::vector<Node>& terms) +bool TheoryBV::needsCheckLastEffort() { - NodeManager* nm = NodeManager::currentNM(); - bool sentLemma = false; - eq::EqualityEngine* ee = getEqualityEngine(); - std::map<Node, Node> op_map; - for (unsigned j = 0; j < terms.size(); j++) - { - TNode n = terms[j]; - Assert(n.getKind() == kind::BITVECTOR_TO_NAT - || n.getKind() == kind::INT_TO_BITVECTOR); - if (n.getKind() == kind::BITVECTOR_TO_NAT) - { - // range lemmas - if (d_extf_range_infer.find(n) == d_extf_range_infer.end()) - { - d_extf_range_infer.insert(n); - unsigned bvs = n[0].getType().getBitVectorSize(); - Node min = nm->mkConst(Rational(0)); - Node max = nm->mkConst(Rational(Integer(1).multiplyByPow2(bvs))); - Node lem = nm->mkNode(kind::AND, - nm->mkNode(kind::GEQ, n, min), - nm->mkNode(kind::LT, n, max)); - Trace("bv-extf-lemma") - << "BV extf lemma (range) : " << lem << std::endl; - d_out->lemma(lem); - sentLemma = true; - } - } - Node r = (ee && ee->hasTerm(n[0])) ? ee->getRepresentative(n[0]) : n[0]; - op_map[r] = n; - } - for (unsigned j = 0; j < terms.size(); j++) - { - TNode n = terms[j]; - Node r = (ee && ee->hasTerm(n[0])) ? ee->getRepresentative(n) : n; - std::map<Node, Node>::iterator it = op_map.find(r); - if (it != op_map.end()) - { - Node parent = it->second; - // Node cterm = parent[0]==n ? parent : nm->mkNode( parent.getOperator(), - // n ); - Node cterm = parent[0].eqNode(n); - Trace("bv-extf-lemma-debug") - << "BV extf collapse based on : " << cterm << std::endl; - if (d_extf_collapse_infer.find(cterm) == d_extf_collapse_infer.end()) - { - d_extf_collapse_infer.insert(cterm); - - Node t = n[0]; - if (t.getType() == parent.getType()) - { - if (n.getKind() == kind::INT_TO_BITVECTOR) - { - Assert(t.getType().isInteger()); - // congruent modulo 2^( bv width ) - unsigned bvs = n.getType().getBitVectorSize(); - Node coeff = nm->mkConst(Rational(Integer(1).multiplyByPow2(bvs))); - Node k = nm->mkSkolem( - "int_bv_cong", t.getType(), "for int2bv/bv2nat congruence"); - t = nm->mkNode(kind::PLUS, t, nm->mkNode(kind::MULT, coeff, k)); - } - Node lem = parent.eqNode(t); - - if (parent[0] != n) - { - Assert(ee->areEqual(parent[0], n)); - lem = nm->mkNode(kind::IMPLIES, parent[0].eqNode(n), lem); - } - // this handles inferences of the form, e.g.: - // ((_ int2bv w) (bv2nat x)) == x (if x is bit-width w) - // (bv2nat ((_ int2bv w) x)) == x + k*2^w for some k - Trace("bv-extf-lemma") - << "BV extf lemma (collapse) : " << lem << std::endl; - d_out->lemma(lem); - sentLemma = true; - } - } - Trace("bv-extf-lemma-debug") - << "BV extf f collapse based on : " << cterm << std::endl; - } - } - return sentLemma; -} - -bool TheoryBV::doExtfReductions( std::vector< Node >& terms ) { - std::vector< Node > nredr; - if (d_extTheory->doReductions(0, terms, nredr)) + CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE]; + if (core) { - return true; + return core->needsCheckLastEffort(); } - Assert(nredr.empty()); return false; } - -bool TheoryBV::needsCheckLastEffort() { - return d_needsLastCallCheck; -} bool TheoryBV::collectModelInfo(TheoryModel* m) { Assert(!inConflict()); @@ -595,48 +485,6 @@ void TheoryBV::propagate(Effort e) { } } -bool TheoryBV::getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) { - eq::EqualityEngine * ee = getEqualityEngine(); - if( ee ){ - //get the constant equivalence classes - bool retVal = false; - for( unsigned i=0; i<vars.size(); i++ ){ - Node n = vars[i]; - if( ee->hasTerm( n ) ){ - Node nr = ee->getRepresentative( n ); - if( nr.isConst() ){ - subs.push_back( nr ); - exp[n].push_back( n.eqNode( nr ) ); - retVal = true; - }else{ - subs.push_back( n ); - } - }else{ - subs.push_back( n ); - } - } - //return true if the substitution is non-trivial - return retVal; - } - return false; -} - -int TheoryBV::getReduction(int effort, Node n, Node& nr) -{ - Trace("bv-ext") << "TheoryBV::checkExt : non-reduced : " << n << std::endl; - if (n.getKind() == kind::BITVECTOR_TO_NAT) - { - nr = utils::eliminateBv2Nat(n); - return -1; - } - else if (n.getKind() == kind::INT_TO_BITVECTOR) - { - nr = utils::eliminateInt2Bv(n); - return -1; - } - return 0; -} - Theory::PPAssertStatus TheoryBV::ppAssert(TNode in, SubstitutionMap& outSubstitutions) { diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index 2f63f1a52..7475feccc 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -33,11 +33,7 @@ #include "util/statistics_registry.h" namespace CVC4 { - namespace theory { - -class ExtTheory; - namespace bv { class CoreSolver; @@ -101,12 +97,6 @@ class TheoryBV : public Theory { std::string identify() const override { return std::string("TheoryBV"); } - bool getCurrentSubstitution(int effort, - std::vector<Node>& vars, - std::vector<Node>& subs, - std::map<Node, std::vector<Node>>& exp) override; - int getReduction(int effort, Node n, Node& nr) override; - PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; TrustNode ppRewrite(TNode t) override; @@ -177,9 +167,6 @@ class TheoryBV : public Theory { /** Index of the next literal to propagate */ context::CDO<unsigned> d_literalsToPropagateIndex; - /** Extended theory module, for context-dependent simplification. */ - std::unique_ptr<ExtTheory> d_extTheory; - /** * Keeps a map from nodes to the subtheory that propagated it so that we can explain it * properly. @@ -191,34 +178,6 @@ class TheoryBV : public Theory { std::unique_ptr<AbstractionModule> d_abstractionModule; bool d_calledPreregister; - //for extended functions - bool d_needsLastCallCheck; - context::CDHashSet<Node, NodeHashFunction> d_extf_range_infer; - context::CDHashSet<Node, NodeHashFunction> d_extf_collapse_infer; - /** do extended function inferences - * - * This method adds lemmas on the output channel of TheoryBV based on - * reasoning about extended functions, such as bv2nat and int2bv. Examples - * of lemmas added by this method include: - * 0 <= ((_ int2bv w) x) < 2^w - * ((_ int2bv w) (bv2nat x)) = x - * (bv2nat ((_ int2bv w) x)) == x + k*2^w - * The purpose of these lemmas is to recognize easy conflicts before fully - * reducing extended functions based on their full semantics. - */ - bool doExtfInferences( std::vector< Node >& terms ); - /** do extended function reductions - * - * This method adds lemmas on the output channel of TheoryBV based on - * reducing all extended function applications that are preregistered to - * this theory and have not already been reduced by context-dependent - * simplification (see theory/ext_theory.h). Examples of lemmas added by - * this method include: - * (bv2nat x) = (ite ((_ extract w w-1) x) 2^{w-1} 0) + ... + - * (ite ((_ extract 1 0) x) 1 0) - */ - bool doExtfReductions( std::vector< Node >& terms ); - bool wasPropagatedBySubtheory(TNode literal) const { return d_propagatedBy.find(literal) != d_propagatedBy.end(); } diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 5253414a9..585f13d82 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -1965,10 +1965,6 @@ TNode TheoryDatatypes::getRepresentative( TNode a ){ } } -bool TheoryDatatypes::getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) { - return false; -} - void TheoryDatatypes::printModelDebug( const char* c ){ if(! (Trace.isOn(c))) { return; diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index 37a4f81f7..bf5d33177 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -273,10 +273,6 @@ private: { return std::string("TheoryDatatypes"); } - bool getCurrentSubstitution(int effort, - std::vector<Node>& vars, - std::vector<Node>& subs, - std::map<Node, std::vector<Node> >& exp) override; /** debug print */ void printModelDebug( const char* c ); /** entailment check */ diff --git a/src/theory/ext_theory.cpp b/src/theory/ext_theory.cpp index bdcd5dcff..e8ed60ae4 100644 --- a/src/theory/ext_theory.cpp +++ b/src/theory/ext_theory.cpp @@ -28,13 +28,41 @@ using namespace std; namespace CVC4 { namespace theory { -ExtTheory::ExtTheory(Theory* p, bool cacheEnabled) +bool ExtTheoryCallback::getCurrentSubstitution( + int effort, + const std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node> >& exp) +{ + return false; +} +bool ExtTheoryCallback::isExtfReduced(int effort, + Node n, + Node on, + std::vector<Node>& exp) +{ + return n.isConst(); +} +bool ExtTheoryCallback::getReduction(int effort, + Node n, + Node& nr, + bool& isSatDep) +{ + return false; +} + +ExtTheory::ExtTheory(ExtTheoryCallback& p, + context::Context* c, + context::UserContext* u, + OutputChannel& out, + bool cacheEnabled) : d_parent(p), - d_ext_func_terms(p->getSatContext()), - d_ci_inactive(p->getUserContext()), - d_has_extf(p->getSatContext()), - d_lemmas(p->getUserContext()), - d_pp_lemmas(p->getUserContext()), + d_out(out), + d_ext_func_terms(c), + d_ci_inactive(u), + d_has_extf(c), + d_lemmas(u), + d_pp_lemmas(u), d_cacheEnabled(cacheEnabled) { d_true = NodeManager::currentNM()->mkConst(true); @@ -61,7 +89,6 @@ std::vector<Node> ExtTheory::collectVars(Node n) // (commented below) if (current.getNumChildren() > 0) { - //&& Theory::theoryOf(n)==d_parent->getId() ){ worklist.insert(worklist.end(), current.begin(), current.end()); } else @@ -140,7 +167,7 @@ void ExtTheory::getSubstitutedTerms(int effort, } } } - bool useSubs = d_parent->getCurrentSubstitution(effort, vars, sub, expc); + bool useSubs = d_parent.getCurrentSubstitution(effort, vars, sub, expc); // get the current substitution for all variables Assert(!useSubs || vars.size() == sub.size()); for (const Node& n : terms) @@ -206,8 +233,8 @@ bool ExtTheory::doInferencesInternal(int effort, { Node nr; // note: could do reduction with substitution here - int ret = d_parent->getReduction(effort, n, nr); - if (ret == 0) + bool satDep = false; + if (!d_parent.getReduction(effort, n, nr, satDep)) { nred.push_back(n); } @@ -223,7 +250,7 @@ bool ExtTheory::doInferencesInternal(int effort, addedLemma = true; } } - markReduced(n, ret < 0); + markReduced(n, satDep); } } } @@ -242,7 +269,7 @@ bool ExtTheory::doInferencesInternal(int effort, Node sr = Rewriter::rewrite(sterms[i]); // ask the theory if this term is reduced, e.g. is it constant or it // is a non-extf term. - if (d_parent->isExtfReduced(effort, sr, terms[i], exp[i])) + if (d_parent.isExtfReduced(effort, sr, terms[i], exp[i])) { processed = true; markReduced(terms[i]); @@ -344,7 +371,7 @@ bool ExtTheory::sendLemma(Node lem, bool preprocess) if (d_pp_lemmas.find(lem) == d_pp_lemmas.end()) { d_pp_lemmas.insert(lem); - d_parent->getOutputChannel().lemma(lem, LemmaProperty::PREPROCESS); + d_out.lemma(lem, LemmaProperty::PREPROCESS); return true; } } @@ -353,7 +380,7 @@ bool ExtTheory::sendLemma(Node lem, bool preprocess) if (d_lemmas.find(lem) == d_lemmas.end()) { d_lemmas.insert(lem); - d_parent->getOutputChannel().lemma(lem); + d_out.lemma(lem); return true; } } @@ -403,8 +430,7 @@ void ExtTheory::registerTerm(Node n) { if (d_ext_func_terms.find(n) == d_ext_func_terms.end()) { - Trace("extt-debug") << "Found extended function : " << n << " in " - << d_parent->getId() << std::endl; + Trace("extt-debug") << "Found extended function : " << n << std::endl; d_ext_func_terms[n] = true; d_has_extf = n; d_extf_info[n].d_vars = collectVars(n); @@ -435,13 +461,13 @@ void ExtTheory::registerTermRec(Node n) } // mark reduced -void ExtTheory::markReduced(Node n, bool contextDepend) +void ExtTheory::markReduced(Node n, bool satDep) { Trace("extt-debug") << "Mark reduced " << n << std::endl; registerTerm(n); Assert(d_ext_func_terms.find(n) != d_ext_func_terms.end()); d_ext_func_terms[n] = false; - if (!contextDepend) + if (!satDep) { d_ci_inactive.insert(n); } diff --git a/src/theory/ext_theory.h b/src/theory/ext_theory.h index 2721bc89e..efd24e2c8 100644 --- a/src/theory/ext_theory.h +++ b/src/theory/ext_theory.h @@ -45,6 +45,57 @@ namespace CVC4 { namespace theory { +/** + * A callback class for ExtTheory below. This class is responsible for + * determining how to apply context-dependent simplification. + */ +class ExtTheoryCallback +{ + public: + virtual ~ExtTheoryCallback() {} + /* + * Get current substitution at an effort + * @param effort The effort identifier + * @param vars The variables to get a substitution for + * @param subs The terms to substitute for variables, in order. This vector + * should be updated to one the same size as vars. + * @param exp The map containing the explanation for each variable. Together + * with subs, we have that: + * ( exp[vars[i]] => vars[i] = subs[i] ) holds for all i + * @return true if any (non-identity) substitution was added to subs. + */ + virtual bool getCurrentSubstitution(int effort, + const std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node> >& exp); + + /* + * Is extended function n reduced? This returns true if n is reduced to a + * form that requires no further interaction from the theory. + * + * @param effort The effort identifier + * @param n The term to reduce + * @param on The original form of n, before substitution + * @param exp The explanation of on = n + * @return true if n is reduced. + */ + virtual bool isExtfReduced(int effort, + Node n, + Node on, + std::vector<Node>& exp); + + /** + * Get reduction for node n. + * If return value is true, then n is reduced. + * If satDep is updated to false, then n is reduced independent of the + * SAT context (e.g. by a lemma that persists at this + * user-context level). + * If nr is non-null, then ( n = nr ) should be added as a lemma by caller, + * and return value of this method should be true. + */ + virtual bool getReduction(int effort, Node n, Node& nr, bool& satDep); +}; + /** Extended theory class * * This class is used for constructing generic extensions to theory solvers. @@ -73,7 +124,11 @@ class ExtTheory * * If cacheEnabled is false, we do not cache results of getSubstitutedTerm. */ - ExtTheory(Theory* p, bool cacheEnabled = false); + ExtTheory(ExtTheoryCallback& p, + context::Context* c, + context::UserContext* u, + OutputChannel& out, + bool cacheEnabled = false); virtual ~ExtTheory() {} /** Tells this class to treat terms with Kind k as extended functions */ void addFunctionKind(Kind k) { d_extf_kind[k] = true; } @@ -93,10 +148,10 @@ class ExtTheory void registerTermRec(Node n); /** set n as reduced/inactive * - * If contextDepend = false, then n remains inactive in the duration of this + * If satDep = false, then n remains inactive in the duration of this * user-context level */ - void markReduced(Node n, bool contextDepend = true); + void markReduced(Node n, bool satDep = true); /** * Mark that a and b are congruent terms. This sets b inactive, and sets a to * inactive if b was inactive. @@ -194,10 +249,12 @@ class ExtTheory std::vector<Node>& nred, bool batch, bool isRed); - /** send lemma on the output channel of d_parent */ + /** send lemma on the output channel */ bool sendLemma(Node lem, bool preprocess = false); - /** reference to the underlying theory */ - Theory* d_parent; + /** reference to the callback */ + ExtTheoryCallback& d_parent; + /** Reference to the output channel we are using */ + OutputChannel& d_out; /** the true node */ Node d_true; /** extended function terms, map to whether they are active */ diff --git a/src/theory/strings/extf_solver.cpp b/src/theory/strings/extf_solver.cpp index b028da38a..6fcd5785d 100644 --- a/src/theory/strings/extf_solver.cpp +++ b/src/theory/strings/extf_solver.cpp @@ -700,6 +700,23 @@ std::vector<Node> ExtfSolver::getActive(Kind k) const return d_extt.getActive(k); } +bool StringsExtfCallback::getCurrentSubstitution( + int effort, + const std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node> >& exp) +{ + Trace("strings-subs") << "getCurrentSubstitution, effort = " << effort + << std::endl; + for (const Node& v : vars) + { + Trace("strings-subs") << " get subs for " << v << "..." << std::endl; + Node s = d_esolver->getCurrentSubstitutionFor(effort, v, exp[v]); + subs.push_back(s); + } + return true; +} + } // namespace strings } // namespace theory } // namespace CVC4 diff --git a/src/theory/strings/extf_solver.h b/src/theory/strings/extf_solver.h index 4ba38bfc6..5b11b6faf 100644 --- a/src/theory/strings/extf_solver.h +++ b/src/theory/strings/extf_solver.h @@ -214,6 +214,23 @@ class ExtfSolver NodeSet d_reduced; }; +/** An extended theory callback */ +class StringsExtfCallback : public ExtTheoryCallback +{ + public: + StringsExtfCallback() : d_esolver(nullptr) {} + /** + * Get current substitution based on the underlying extended function + * solver. + */ + bool getCurrentSubstitution(int effort, + const std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node> >& exp) override; + /** The extended function solver */ + ExtfSolver* d_esolver; +}; + } // namespace strings } // namespace theory } // namespace CVC4 diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index 3e60cbc44..f248cb330 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -45,7 +45,8 @@ TheoryStrings::TheoryStrings(context::Context* c, d_statistics(), d_state(c, u, d_valuation), d_termReg(d_state, out, d_statistics, nullptr), - d_extTheory(this), + d_extTheoryCb(), + d_extTheory(d_extTheoryCb, c, u, out), d_im(*this, d_state, d_termReg, d_extTheory, d_statistics, pnm), d_rewriter(&d_statistics.d_rewrites), d_bsolver(d_state, d_im), @@ -75,6 +76,9 @@ TheoryStrings::TheoryStrings(context::Context* c, d_cardSize = utils::getAlphabetCardinality(); + // set up the extended function callback + d_extTheoryCb.d_esolver = &d_esolver; + ProofChecker* pc = pnm != nullptr ? pnm->getChecker() : nullptr; if (pc != nullptr) { @@ -202,18 +206,6 @@ TrustNode TheoryStrings::explain(TNode literal) return TrustNode::mkTrustPropExp(literal, ret, nullptr); } -bool TheoryStrings::getCurrentSubstitution( int effort, std::vector< Node >& vars, - std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) { - Trace("strings-subs") << "getCurrentSubstitution, effort = " << effort << std::endl; - for( unsigned i=0; i<vars.size(); i++ ){ - Node n = vars[i]; - Trace("strings-subs") << " get subs for " << n << "..." << std::endl; - Node s = d_esolver.getCurrentSubstitutionFor(effort, n, exp[n]); - subs.push_back(s); - } - return true; -} - void TheoryStrings::presolve() { Debug("strings-presolve") << "TheoryStrings::Presolving : get fmf options " << (options::stringFMF() ? "true" : "false") << std::endl; d_strat.initializeStrategy(); diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index 0f59e73dc..cbe6000bf 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -86,11 +86,6 @@ class TheoryStrings : public Theory { std::string identify() const override; /** Explain */ TrustNode explain(TNode literal) override; - /** Get current substitution */ - bool getCurrentSubstitution(int effort, - std::vector<Node>& vars, - std::vector<Node>& subs, - std::map<Node, std::vector<Node> >& exp) override; /** presolve */ void presolve() override; /** shutdown */ @@ -262,6 +257,8 @@ class TheoryStrings : public Theory { SolverState d_state; /** The term registry for this theory */ TermRegistry d_termReg; + /** The extended theory callback */ + StringsExtfCallback d_extTheoryCb; /** Extended theory, responsible for context-dependent simplification. */ ExtTheory d_extTheory; /** The (custom) output channel of the theory of strings */ diff --git a/src/theory/theory.h b/src/theory/theory.h index 176d4b672..77652f874 100644 --- a/src/theory/theory.h +++ b/src/theory/theory.h @@ -897,30 +897,6 @@ class Theory { * E |= lit in the theory. */ virtual std::pair<bool, Node> entailmentCheck(TNode lit); - - /* get current substitution at an effort - * input : vars - * output : subs, exp - * where ( exp[vars[i]] => vars[i] = subs[i] ) holds for all i - */ - virtual bool getCurrentSubstitution(int effort, std::vector<Node>& vars, - std::vector<Node>& subs, - std::map<Node, std::vector<Node> >& exp) { - return false; - } - - /* is extended function reduced */ - virtual bool isExtfReduced( int effort, Node n, Node on, std::vector< Node >& exp ) { return n.isConst(); } - - /** - * Get reduction for node - * If return value is not 0, then n is reduced. - * If return value <0 then n is reduced SAT-context-independently (e.g. by a - * lemma that persists at this user-context level). - * If nr is non-null, then ( n = nr ) should be added as a lemma by caller, - * and return value should be <0. - */ - virtual int getReduction( int effort, Node n, Node& nr ) { return 0; } };/* class Theory */ std::ostream& operator<<(std::ostream& os, theory::Theory::Effort level); |