From 92cdcc09e9a8bece8053c3aba9e68d0028b41a8e Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Thu, 17 Sep 2020 15:54:02 +0200 Subject: Use new inference manager in transcendental solver (#5022) This refactors the transcendental solver to add lemmas to the new inference manager instead of using the old lemma collection scheme. --- src/theory/arith/inference_manager.cpp | 12 ++++++ src/theory/arith/inference_manager.h | 3 ++ src/theory/arith/nl/nonlinear_extension.cpp | 50 ++++++++++++------------ src/theory/arith/nl/transcendental_solver.cpp | 56 ++++++++++++--------------- src/theory/arith/nl/transcendental_solver.h | 27 ++++++------- 5 files changed, 79 insertions(+), 69 deletions(-) (limited to 'src') diff --git a/src/theory/arith/inference_manager.cpp b/src/theory/arith/inference_manager.cpp index d4c5d17c5..5c1602a1a 100644 --- a/src/theory/arith/inference_manager.cpp +++ b/src/theory/arith/inference_manager.cpp @@ -32,6 +32,8 @@ InferenceManager::InferenceManager(TheoryArith& ta, void InferenceManager::addPendingArithLemma(std::unique_ptr lemma, bool isWaiting) { + Trace("arith::infman") << "Add " << lemma->d_inference << " " << lemma->d_node + << (isWaiting ? " as waiting" : "") << std::endl; lemma->d_node = Rewriter::rewrite(lemma->d_node); if (hasCachedLemma(lemma->d_node, lemma->d_property)) { @@ -77,6 +79,9 @@ void InferenceManager::flushWaitingLemmas() { for (auto& lem : d_waitingLem) { + Trace("arith::infman") << "Flush waiting lemma to pending: " + << lem->d_inference << " " << lem->d_node + << std::endl; d_pendingLem.emplace_back(std::move(lem)); } d_waitingLem.clear(); @@ -84,6 +89,8 @@ void InferenceManager::flushWaitingLemmas() void InferenceManager::addConflict(const Node& conf, InferenceId inftype) { + Trace("arith::infman") << "Adding conflict: " << inftype << " " << conf + << std::endl; conflict(Rewriter::rewrite(conf)); } @@ -92,6 +99,11 @@ bool InferenceManager::hasUsed() const return hasSent() || hasPending(); } +bool InferenceManager::hasWaitingLemma() const +{ + return !d_waitingLem.empty(); +} + std::size_t InferenceManager::numWaitingLemmas() const { return d_waitingLem.size(); diff --git a/src/theory/arith/inference_manager.h b/src/theory/arith/inference_manager.h index e1e386bec..f4806cc9a 100644 --- a/src/theory/arith/inference_manager.h +++ b/src/theory/arith/inference_manager.h @@ -87,6 +87,9 @@ class InferenceManager : public InferenceManagerBuffered */ bool hasUsed() const; + /** Checks whether we have waiting lemmas. */ + bool hasWaitingLemma() const; + /** Returns the number of pending lemmas. */ std::size_t numWaitingLemmas() const; diff --git a/src/theory/arith/nl/nonlinear_extension.cpp b/src/theory/arith/nl/nonlinear_extension.cpp index 3bf547ceb..df3a304be 100644 --- a/src/theory/arith/nl/nonlinear_extension.cpp +++ b/src/theory/arith/nl/nonlinear_extension.cpp @@ -44,7 +44,7 @@ NonlinearExtension::NonlinearExtension(TheoryArith& containing, containing.getUserContext(), containing.getOutputChannel()), d_model(containing.getSatContext()), - d_trSlv(d_model), + d_trSlv(d_im, d_model), d_nlSlv(containing, d_model), d_cadSlv(d_im, d_model), d_iandSlv(containing, d_model), @@ -386,9 +386,7 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, // initialize the non-linear solver d_nlSlv.initLastCall(assertions, false_asserts, xts); // initialize the trancendental function solver - d_trSlv.initLastCall(assertions, false_asserts, xts, lemmas); - // process lemmas that may have been generated by the transcendental solver - filterLemmas(lemmas, lems); + d_trSlv.initLastCall(assertions, false_asserts, xts); } if (options::nlCad()) { @@ -398,11 +396,12 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, // init last call with IAND d_iandSlv.initLastCall(assertions, false_asserts, xts); - if (!lems.empty()) + if (d_im.hasUsed() || !lems.empty()) { - Trace("nl-ext") << " ...finished with " << lems.size() + unsigned count = lems.size() + d_im.numPendingLemmas() + d_im.numSentLemmas(); + Trace("nl-ext") << " ...finished with " << count << " new lemmas during registration." << std::endl; - return lems.size(); + return count; } //----------------------------------- possibly split on zero @@ -423,13 +422,13 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, if (options::nlExt()) { // functions - lemmas = d_trSlv.checkTranscendentalInitialRefine(); - filterLemmas(lemmas, lems); - if (!lems.empty()) + d_trSlv.checkTranscendentalInitialRefine(); + if (d_im.hasUsed()) { - Trace("nl-ext") << " ...finished with " << lems.size() << " new lemmas." + unsigned count = lems.size() + d_im.numPendingLemmas() + d_im.numSentLemmas(); + Trace("nl-ext") << " ...finished with " << count << " new lemmas." << std::endl; - return lems.size(); + return count; } } //-----------------------------------initial lemmas for iand @@ -456,13 +455,13 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, } //-----------------------------------monotonicity of transdental functions - lemmas = d_trSlv.checkTranscendentalMonotonic(); - filterLemmas(lemmas, lems); - if (!lems.empty()) + d_trSlv.checkTranscendentalMonotonic(); + if (d_im.hasUsed()) { - Trace("nl-ext") << " ...finished with " << lems.size() << " new lemmas." + unsigned count = lems.size() + d_im.numPendingLemmas() + d_im.numSentLemmas(); + Trace("nl-ext") << " ...finished with " << count << " new lemmas." << std::endl; - return lems.size(); + return count; } //------------------------lemmas based on magnitude of non-zero monomials @@ -551,8 +550,7 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, } if (options::nlExtTfTangentPlanes()) { - lemmas = d_trSlv.checkTranscendentalTangentPlanes(); - filterLemmas(lemmas, wlems); + d_trSlv.checkTranscendentalTangentPlanes(); } } if (options::nlCad()) @@ -572,8 +570,9 @@ int NonlinearExtension::checkLastCall(const std::vector& assertions, lemmas = d_iandSlv.checkFullRefine(); filterLemmas(lemmas, wlems); - Trace("nl-ext") << " ...finished with " << wlems.size() << " waiting lemmas." - << std::endl; + Trace("nl-ext") << " ...finished with " + << (wlems.size() + d_im.numWaitingLemmas()) + << " waiting lemmas." << std::endl; return 0; } @@ -614,6 +613,7 @@ void NonlinearExtension::check(Theory::Effort e) d_im.doPendingFacts(); d_im.doPendingLemmas(); d_im.doPendingPhaseRequirements(); + d_im.reset(); return; } // Otherwise, we will answer SAT. The values that we approximated are @@ -728,7 +728,7 @@ bool NonlinearExtension::modelBasedRefinement(std::vector& mlems) { complete_status = num_shared_wrong_value > 0 ? -1 : 0; checkLastCall(assertions, false_asserts, xts, mlems, wlems); - if (!mlems.empty()) + if (!mlems.empty() || d_im.hasSentLemma() || d_im.hasPendingLemma()) { return true; } @@ -768,10 +768,12 @@ bool NonlinearExtension::modelBasedRefinement(std::vector& mlems) if (complete_status != 1) { // flush the waiting lemmas - if (!wlems.empty()) + if (!wlems.empty() || d_im.hasWaitingLemma()) { + std::size_t count = wlems.size() + d_im.numWaitingLemmas(); mlems.insert(mlems.end(), wlems.begin(), wlems.end()); - Trace("nl-ext") << "...added " << wlems.size() << " waiting lemmas." + d_im.flushWaitingLemmas(); + Trace("nl-ext") << "...added " << count << " waiting lemmas." << std::endl; return true; } diff --git a/src/theory/arith/nl/transcendental_solver.cpp b/src/theory/arith/nl/transcendental_solver.cpp index d075d5037..b22cf990e 100644 --- a/src/theory/arith/nl/transcendental_solver.cpp +++ b/src/theory/arith/nl/transcendental_solver.cpp @@ -31,7 +31,7 @@ namespace theory { namespace arith { namespace nl { -TranscendentalSolver::TranscendentalSolver(NlModel& m) : d_model(m) +TranscendentalSolver::TranscendentalSolver(InferenceManager& im, NlModel& m) : d_im(im), d_model(m) { NodeManager* nm = NodeManager::currentNM(); d_true = nm->mkConst(true); @@ -49,8 +49,7 @@ TranscendentalSolver::~TranscendentalSolver() {} void TranscendentalSolver::initLastCall(const std::vector& assertions, const std::vector& false_asserts, - const std::vector& xts, - std::vector& lems) + const std::vector& xts) { d_funcCongClass.clear(); d_funcMap.clear(); @@ -136,7 +135,7 @@ void TranscendentalSolver::initLastCall(const std::vector& assertions, } Node expn = exp.size() == 1 ? exp[0] : nm->mkNode(AND, exp); Node cong_lemma = nm->mkNode(OR, expn.negate(), a.eqNode(aa)); - lems.emplace_back(cong_lemma, InferenceId::NL_CONGRUENCE); + d_im.addPendingArithLemma(cong_lemma, InferenceId::NL_CONGRUENCE); } } else @@ -160,10 +159,10 @@ void TranscendentalSolver::initLastCall(const std::vector& assertions, if (needPi && d_pi.isNull()) { mkPi(); - getCurrentPiBounds(lems); + getCurrentPiBounds(); } - if (!lems.empty()) + if (d_im.hasUsed()) { return; } @@ -212,9 +211,8 @@ void TranscendentalSolver::initLastCall(const std::vector& assertions, // note we must do preprocess on this lemma Trace("nl-ext-lemma") << "NonlinearExtension::Lemma : purify : " << lem << std::endl; - NlLemma nlem( - lem, LemmaProperty::PREPROCESS, nullptr, InferenceId::NL_T_PURIFY_ARG); - lems.emplace_back(nlem); + NlLemma nlem(lem, LemmaProperty::PREPROCESS, nullptr, InferenceId::NL_T_PURIFY_ARG); + d_im.addPendingArithLemma(nlem); } if (Trace.isOn("nl-ext-mv")) @@ -363,19 +361,18 @@ void TranscendentalSolver::mkPi() } } -void TranscendentalSolver::getCurrentPiBounds(std::vector& lemmas) +void TranscendentalSolver::getCurrentPiBounds() { NodeManager* nm = NodeManager::currentNM(); Node pi_lem = nm->mkNode(AND, nm->mkNode(GEQ, d_pi, d_pi_bound[0]), nm->mkNode(LEQ, d_pi, d_pi_bound[1])); - lemmas.emplace_back(pi_lem, InferenceId::NL_T_PI_BOUND); + d_im.addPendingArithLemma(pi_lem, InferenceId::NL_T_PI_BOUND); } -std::vector TranscendentalSolver::checkTranscendentalInitialRefine() +void TranscendentalSolver::checkTranscendentalInitialRefine() { NodeManager* nm = NodeManager::currentNM(); - std::vector lemmas; Trace("nl-ext") << "Get initial refinement lemmas for transcendental functions..." << std::endl; @@ -454,18 +451,15 @@ std::vector TranscendentalSolver::checkTranscendentalInitialRefine() } if (!lem.isNull()) { - lemmas.emplace_back(lem, InferenceId::NL_T_INIT_REFINE); + d_im.addPendingArithLemma(lem, InferenceId::NL_T_INIT_REFINE); } } } } - - return lemmas; } -std::vector TranscendentalSolver::checkTranscendentalMonotonic() +void TranscendentalSolver::checkTranscendentalMonotonic() { - std::vector lemmas; Trace("nl-ext") << "Get monotonicity lemmas for transcendental functions..." << std::endl; @@ -630,7 +624,8 @@ std::vector TranscendentalSolver::checkTranscendentalMonotonic() } Trace("nl-ext-tf-mono") << "Monotonicity lemma : " << mono_lem << std::endl; - lemmas.emplace_back(mono_lem, InferenceId::NL_T_MONOTONICITY); + + d_im.addPendingArithLemma(mono_lem, InferenceId::NL_T_MONOTONICITY); } } // store the previous values @@ -642,12 +637,10 @@ std::vector TranscendentalSolver::checkTranscendentalMonotonic() } } } - return lemmas; } -std::vector TranscendentalSolver::checkTranscendentalTangentPlanes() +void TranscendentalSolver::checkTranscendentalTangentPlanes() { - std::vector lemmas; Trace("nl-ext") << "Get tangent plane lemmas for transcendental functions..." << std::endl; // this implements Figure 3 of "Satisfiaility Modulo Transcendental Functions @@ -682,11 +675,13 @@ std::vector TranscendentalSolver::checkTranscendentalTangentPlanes() for (unsigned d = 1; d <= d_taylor_degree; d++) { Trace("nl-ext-tftp") << "- run at degree " << d << "..." << std::endl; - unsigned prev = lemmas.size(); - if (checkTfTangentPlanesFun(tf, d, lemmas)) + unsigned prev = d_im.numPendingLemmas() + d_im.numWaitingLemmas(); + if (checkTfTangentPlanesFun(tf, d)) { Trace("nl-ext-tftp") - << "...fail, #lemmas = " << (lemmas.size() - prev) << std::endl; + << "...fail, #lemmas = " + << (d_im.numPendingLemmas() + d_im.numWaitingLemmas() - prev) + << std::endl; break; } else @@ -696,13 +691,10 @@ std::vector TranscendentalSolver::checkTranscendentalTangentPlanes() } } } - - return lemmas; } bool TranscendentalSolver::checkTfTangentPlanesFun(Node tf, - unsigned d, - std::vector& lemmas) + unsigned d) { NodeManager* nm = NodeManager::currentNM(); Kind k = tf.getKind(); @@ -883,7 +875,7 @@ bool TranscendentalSolver::checkTfTangentPlanesFun(Node tf, << "*** Tangent plane lemma : " << lem << std::endl; Assert(d_model.computeAbstractModelValue(lem) == d_false); // Figure 3 : line 9 - lemmas.emplace_back(lem, InferenceId::NL_T_TANGENT); + d_im.addPendingArithLemma(lem, InferenceId::NL_T_TANGENT, true); } else if (is_secant) { @@ -1017,11 +1009,11 @@ bool TranscendentalSolver::checkTfTangentPlanesFun(Node tf, Assert(!lemmaConj.empty()); Node lem = lemmaConj.size() == 1 ? lemmaConj[0] : nm->mkNode(AND, lemmaConj); - NlLemma nlem(lem, InferenceId::NL_T_SECANT); + NlLemma nlem(lem, LemmaProperty::NONE, nullptr, InferenceId::NL_T_SECANT); // The side effect says that if lem is added, then we should add the // secant point c for (tf,d). nlem.d_secantPoint.push_back(std::make_tuple(tf, d, c)); - lemmas.emplace_back(nlem); + d_im.addPendingArithLemma(nlem, true); } return true; } diff --git a/src/theory/arith/nl/transcendental_solver.h b/src/theory/arith/nl/transcendental_solver.h index c80fa99e6..2ac2ae2f3 100644 --- a/src/theory/arith/nl/transcendental_solver.h +++ b/src/theory/arith/nl/transcendental_solver.h @@ -21,7 +21,7 @@ #include #include "expr/node.h" -#include "theory/arith/nl/nl_lemma_utils.h" +#include "theory/arith/inference_manager.h" #include "theory/arith/nl/nl_model.h" namespace CVC4 { @@ -44,7 +44,7 @@ namespace nl { class TranscendentalSolver { public: - TranscendentalSolver(NlModel& m); + TranscendentalSolver(InferenceManager& im, NlModel& m); ~TranscendentalSolver(); /** init last call @@ -60,8 +60,7 @@ class TranscendentalSolver */ void initLastCall(const std::vector& assertions, const std::vector& false_asserts, - const std::vector& xts, - std::vector& lems); + const std::vector& xts); /** increment taylor degree */ void incrementTaylorDegree(); /** get taylor degree */ @@ -80,7 +79,7 @@ class TranscendentalSolver //-------------------------------------------- lemma schemas /** check transcendental initial refine * - * Returns a set of valid theory lemmas, based on + * Constructs a set of valid theory lemmas, based on * simple facts about transcendental functions. * This mostly follows the initial axioms described in * Section 4 of "Satisfiability @@ -94,11 +93,11 @@ class TranscendentalSolver * exp( x )>0 * x<0 => exp( x )<1 */ - std::vector checkTranscendentalInitialRefine(); + void checkTranscendentalInitialRefine(); /** check transcendental monotonic * - * Returns a set of valid theory lemmas, based on a + * Constructs a set of valid theory lemmas, based on a * lemma scheme that ensures that applications * of transcendental functions respect monotonicity. * @@ -108,11 +107,11 @@ class TranscendentalSolver * PI/2 > x > y > 0 => sin( x ) > sin( y ) * PI > x > y > PI/2 => sin( x ) < sin( y ) */ - std::vector checkTranscendentalMonotonic(); + void checkTranscendentalMonotonic(); /** check transcendental tangent planes * - * Returns a set of valid theory lemmas, based on + * Constructs a set of valid theory lemmas, based on * computing an "incremental linearization" of * transcendental functions based on the model values * of transcendental functions and their arguments. @@ -168,7 +167,8 @@ class TranscendentalSolver * where c1, c2 are rationals (for brevity, omitted here) * such that c1 ~= .277 and c2 ~= 2.032. */ - std::vector checkTranscendentalTangentPlanes(); + void checkTranscendentalTangentPlanes(); + private: /** check transcendental function refinement for tf * * This method is called by the above method for each "master" @@ -186,9 +186,8 @@ class TranscendentalSolver * It returns false if the bounds are not precise enough to add a * secant or tangent plane lemma. */ - bool checkTfTangentPlanesFun(Node tf, unsigned d, std::vector& lems); + bool checkTfTangentPlanesFun(Node tf, unsigned d); //-------------------------------------------- end lemma schemas - private: /** polynomial approximation bounds * * This adds P_l+[x], P_l-[x], P_u+[x], P_u-[x] to pbounds, where x is @@ -268,10 +267,12 @@ class TranscendentalSolver Node getDerivative(Node n, Node x); void mkPi(); - void getCurrentPiBounds(std::vector& lemmas); + void getCurrentPiBounds(); /** Make the node -pi <= a <= pi */ static Node mkValidPhase(Node a, Node pi); + /** The inference manager that we push conflicts and lemmas to. */ + InferenceManager& d_im; /** Reference to the non-linear model object */ NlModel& d_model; /** commonly used terms */ -- cgit v1.2.3