diff options
Diffstat (limited to 'src/theory/arith/nl')
29 files changed, 622 insertions, 354 deletions
diff --git a/src/theory/arith/nl/cad/cdcac.cpp b/src/theory/arith/nl/cad/cdcac.cpp index 2fc77be1b..18ccf7aca 100644 --- a/src/theory/arith/nl/cad/cdcac.cpp +++ b/src/theory/arith/nl/cad/cdcac.cpp @@ -105,16 +105,7 @@ std::vector<CACInterval> CDCAC::getUnsatIntervals(std::size_t cur_variable) { std::vector<CACInterval> res; LazardEvaluation le; - if (options().arith.nlCadLifting - == options::NlCadLiftingMode::LAZARD) - { - for (size_t vid = 0; vid < cur_variable; ++vid) - { - const auto& val = d_assignment.get(d_variableOrdering[vid]); - le.add(d_variableOrdering[vid], val); - } - le.addFreeVariable(d_variableOrdering[cur_variable]); - } + prepareRootIsolation(le, cur_variable); for (const auto& c : d_constraints.getConstraints()) { const poly::Polynomial& p = std::get<0>(c); @@ -428,11 +419,17 @@ CACInterval CDCAC::intervalFromCharacterization( m.pushDownPolys(d, d_variableOrdering[cur_variable]); // Collect -oo, all roots, oo + + LazardEvaluation le; + prepareRootIsolation(le, cur_variable); std::vector<poly::Value> roots; roots.emplace_back(poly::Value::minus_infty()); for (const auto& p : m) { - auto tmp = isolate_real_roots(p, d_assignment); + Trace("cdcac") << "Isolating real roots of " << p << " over " + << d_assignment << std::endl; + + auto tmp = isolateRealRoots(le, p); roots.insert(roots.end(), tmp.begin(), tmp.end()); } roots.emplace_back(poly::Value::plus_infty()); @@ -464,6 +461,8 @@ CACInterval CDCAC::intervalFromCharacterization( d_assignment.set(d_variableOrdering[cur_variable], lower); for (const auto& p : m) { + Trace("cdcac") << "Evaluating " << p << " = 0 over " << d_assignment + << std::endl; if (evaluate_constraint(p, d_assignment, poly::SignCondition::EQ)) { l.add(p, true); @@ -477,6 +476,8 @@ CACInterval CDCAC::intervalFromCharacterization( d_assignment.set(d_variableOrdering[cur_variable], upper); for (const auto& p : m) { + Trace("cdcac") << "Evaluating " << p << " = 0 over " << d_assignment + << std::endl; if (evaluate_constraint(p, d_assignment, poly::SignCondition::EQ)) { u.add(p, true); @@ -570,8 +571,10 @@ std::vector<CACInterval> CDCAC::getUnsatCoverImpl(std::size_t curVariable, d_assignment.unset(d_variableOrdering[curVariable]); + Trace("cdcac") << "Building interval..." << std::endl; auto newInterval = intervalFromCharacterization(characterization, curVariable, sample); + Trace("cdcac") << "New interval: " << newInterval.d_interval << std::endl; newInterval.d_origins = collectConstraints(cov); intervals.emplace_back(newInterval); if (isProofEnabled()) @@ -730,6 +733,30 @@ void CDCAC::pruneRedundantIntervals(std::vector<CACInterval>& intervals) } } +void CDCAC::prepareRootIsolation(LazardEvaluation& le, + size_t cur_variable) const +{ + if (options().arith.nlCadLifting == options::NlCadLiftingMode::LAZARD) + { + for (size_t vid = 0; vid < cur_variable; ++vid) + { + const auto& val = d_assignment.get(d_variableOrdering[vid]); + le.add(d_variableOrdering[vid], val); + } + le.addFreeVariable(d_variableOrdering[cur_variable]); + } +} + +std::vector<poly::Value> CDCAC::isolateRealRoots( + LazardEvaluation& le, const poly::Polynomial& p) const +{ + if (options().arith.nlCadLifting == options::NlCadLiftingMode::LAZARD) + { + return le.isolateRealRoots(p); + } + return poly::isolate_real_roots(p, d_assignment); +} + } // namespace cad } // namespace nl } // namespace arith diff --git a/src/theory/arith/nl/cad/cdcac.h b/src/theory/arith/nl/cad/cdcac.h index 04b5cab24..8317c0813 100644 --- a/src/theory/arith/nl/cad/cdcac.h +++ b/src/theory/arith/nl/cad/cdcac.h @@ -29,6 +29,7 @@ #include "smt/env_obj.h" #include "theory/arith/nl/cad/cdcac_utils.h" #include "theory/arith/nl/cad/constraints.h" +#include "theory/arith/nl/cad/lazard_evaluation.h" #include "theory/arith/nl/cad/proof_generator.h" #include "theory/arith/nl/cad/variable_ordering.h" @@ -196,6 +197,20 @@ class CDCAC : protected EnvObj void pruneRedundantIntervals(std::vector<CACInterval>& intervals); /** + * Prepare the lazard evaluation object with the current assignment, if the + * lazard lifting is enabled. Otherwise, this function does nothing. + */ + void prepareRootIsolation(LazardEvaluation& le, size_t cur_variable) const; + + /** + * Isolates the real roots of the polynomial `p`. If the lazard lifting is + * enabled, this function uses `le.isolateRealRoots()`, otherwise uses the + * regular `poly::isolate_real_roots()`. + */ + std::vector<poly::Value> isolateRealRoots(LazardEvaluation& le, + const poly::Polynomial& p) const; + + /** * The current assignment. When the method terminates with SAT, it contains a * model for the input constraints. */ diff --git a/src/theory/arith/nl/cad/lazard_evaluation.cpp b/src/theory/arith/nl/cad/lazard_evaluation.cpp index aec0d46e3..032565d3d 100644 --- a/src/theory/arith/nl/cad/lazard_evaluation.cpp +++ b/src/theory/arith/nl/cad/lazard_evaluation.cpp @@ -821,22 +821,11 @@ std::vector<poly::Polynomial> LazardEvaluation::reducePolynomial( return {p}; } -/** - * Compute the infeasible regions of the given polynomial according to a sign - * condition. We first reduce the polynomial and isolate the real roots of every - * resulting polynomial. We store all roots (except for -infty, +infty and none) - * in a set. Then, we transform the set of roots into a list of infeasible - * regions by generating intervals between -infty and the first root, in between - * every two consecutive roots and between the last root and +infty. While doing - * this, we only keep those intervals that are actually infeasible for the - * original polynomial q over the partial assignment. Finally, we go over the - * intervals and aggregate consecutive intervals that connect. - */ -std::vector<poly::Interval> LazardEvaluation::infeasibleRegions( - const poly::Polynomial& q, poly::SignCondition sc) const +std::vector<poly::Value> LazardEvaluation::isolateRealRoots( + const poly::Polynomial& q) const { poly::Assignment a; - std::set<poly::Value> roots; + std::vector<poly::Value> roots; // reduce q to a set of reduced polynomials p for (const auto& p : reducePolynomial(q)) { @@ -849,9 +838,28 @@ std::vector<poly::Interval> LazardEvaluation::infeasibleRegions( if (poly::is_minus_infinity(r)) continue; if (poly::is_none(r)) continue; if (poly::is_plus_infinity(r)) continue; - roots.insert(r); + roots.emplace_back(r); } } + std::sort(roots.begin(), roots.end()); + return roots; +} + +/** + * Compute the infeasible regions of the given polynomial according to a sign + * condition. We first reduce the polynomial and isolate the real roots of every + * resulting polynomial. We store all roots (except for -infty, +infty and none) + * in a set. Then, we transform the set of roots into a list of infeasible + * regions by generating intervals between -infty and the first root, in between + * every two consecutive roots and between the last root and +infty. While doing + * this, we only keep those intervals that are actually infeasible for the + * original polynomial q over the partial assignment. Finally, we go over the + * intervals and aggregate consecutive intervals that connect. + */ +std::vector<poly::Interval> LazardEvaluation::infeasibleRegions( + const poly::Polynomial& q, poly::SignCondition sc) const +{ + std::vector<poly::Value> roots = isolateRealRoots(q); // generate all intervals // (-infty,root_0), [root_0], (root_0,root_1), ..., [root_m], (root_m,+infty) @@ -962,6 +970,16 @@ std::vector<poly::Polynomial> LazardEvaluation::reducePolynomial( { return {p}; } + +std::vector<poly::Value> LazardEvaluation::isolateRealRoots( + const poly::Polynomial& q) const +{ + WarningOnce() + << "CAD::LazardEvaluation is disabled because CoCoA is not available. " + "Falling back to regular real root isolation." + << std::endl; + return poly::isolate_real_roots(q, d_state->d_assignment); +} std::vector<poly::Interval> LazardEvaluation::infeasibleRegions( const poly::Polynomial& q, poly::SignCondition sc) const { diff --git a/src/theory/arith/nl/cad/lazard_evaluation.h b/src/theory/arith/nl/cad/lazard_evaluation.h index 3bb971c4c..2afccb462 100644 --- a/src/theory/arith/nl/cad/lazard_evaluation.h +++ b/src/theory/arith/nl/cad/lazard_evaluation.h @@ -94,6 +94,11 @@ class LazardEvaluation const poly::Polynomial& q) const; /** + * Isolates the real roots of the given polynomials. + */ + std::vector<poly::Value> isolateRealRoots(const poly::Polynomial& q) const; + + /** * Compute the infeasible regions of q under the given sign condition. * Uses reducePolynomial and then performs real root isolation on the * resulting polynomials to obtain the intervals. Mimics diff --git a/src/theory/arith/nl/cad_solver.cpp b/src/theory/arith/nl/cad_solver.cpp index 721308a3d..f4582ac20 100644 --- a/src/theory/arith/nl/cad_solver.cpp +++ b/src/theory/arith/nl/cad_solver.cpp @@ -16,12 +16,14 @@ #include "theory/arith/nl/cad_solver.h" #include "expr/skolem_manager.h" +#include "options/arith_options.h" #include "smt/env.h" #include "theory/arith/inference_manager.h" #include "theory/arith/nl/cad/cdcac.h" #include "theory/arith/nl/nl_model.h" #include "theory/arith/nl/poly_conversion.h" #include "theory/inference_id.h" +#include "theory/theory.h" namespace cvc5 { namespace theory { @@ -36,7 +38,8 @@ CadSolver::CadSolver(Env& env, InferenceManager& im, NlModel& model) #endif d_foundSatisfiability(false), d_im(im), - d_model(model) + d_model(model), + d_eqsubs(env) { NodeManager* nm = NodeManager::currentNM(); SkolemManager* sm = nm->getSkolemManager(); @@ -65,11 +68,41 @@ void CadSolver::initLastCall(const std::vector<Node>& assertions) Trace("nl-cad") << " " << a << std::endl; } } - // store or process assertions - d_CAC.reset(); - for (const Node& a : assertions) + if (options().arith.nlCadVarElim) { - d_CAC.getConstraints().addConstraint(a); + d_eqsubs.reset(); + std::vector<Node> processed = d_eqsubs.eliminateEqualities(assertions); + if (d_eqsubs.hasConflict()) + { + Node lem = NodeManager::currentNM()->mkAnd(d_eqsubs.getConflict()).negate(); + d_im.addPendingLemma(lem, InferenceId::ARITH_NL_CAD_CONFLICT, nullptr); + Trace("nl-cad") << "Found conflict: " << lem << std::endl; + return; + } + if (Trace.isOn("nl-cad")) + { + Trace("nl-cad") << "After simplifications" << std::endl; + Trace("nl-cad") << "* Assertions: " << std::endl; + for (const Node& a : processed) + { + Trace("nl-cad") << " " << a << std::endl; + } + } + d_CAC.reset(); + for (const Node& a : processed) + { + Assert(!a.isConst()); + d_CAC.getConstraints().addConstraint(a); + } + } + else + { + d_CAC.reset(); + for (const Node& a : assertions) + { + Assert(!a.isConst()); + d_CAC.getConstraints().addConstraint(a); + } } d_CAC.computeVariableOrdering(); d_CAC.retrieveInitialAssignment(d_model, d_ranVariable); @@ -84,6 +117,7 @@ void CadSolver::checkFull() { #ifdef CVC5_POLY_IMP if (d_CAC.getConstraints().getConstraints().empty()) { + d_foundSatisfiability = true; Trace("nl-cad") << "No constraints. Return." << std::endl; return; } @@ -101,6 +135,8 @@ void CadSolver::checkFull() Trace("nl-cad") << "Collected MIS: " << mis << std::endl; Assert(!mis.empty()) << "Infeasible subset can not be empty"; Trace("nl-cad") << "UNSAT with MIS: " << mis << std::endl; + d_eqsubs.postprocessConflict(mis); + Trace("nl-cad") << "After postprocessing: " << mis << std::endl; Node lem = NodeManager::currentNM()->mkAnd(mis).negate(); ProofGenerator* proof = d_CAC.closeProof(mis); d_im.addPendingLemma(lem, InferenceId::ARITH_NL_CAD_CONFLICT, proof); @@ -170,10 +206,15 @@ bool CadSolver::constructModelIfAvailable(std::vector<Node>& assertions) return false; } bool foundNonVariable = false; + for (const auto& sub: d_eqsubs.getSubstitutions()) + { + d_model.addSubstitution(sub.first, sub.second); + Trace("nl-cad") << "-> " << sub.first << " = " << sub.second << std::endl; + } for (const auto& v : d_CAC.getVariableOrdering()) { Node variable = d_CAC.getConstraints().varMapper()(v); - if (!variable.isVar()) + if (!Theory::isLeafOf(variable, TheoryId::THEORY_ARITH)) { Trace("nl-cad") << "Not a variable: " << variable << std::endl; foundNonVariable = true; diff --git a/src/theory/arith/nl/cad_solver.h b/src/theory/arith/nl/cad_solver.h index bedffcaa9..73d09378b 100644 --- a/src/theory/arith/nl/cad_solver.h +++ b/src/theory/arith/nl/cad_solver.h @@ -23,6 +23,7 @@ #include "smt/env_obj.h" #include "theory/arith/nl/cad/cdcac.h" #include "theory/arith/nl/cad/proof_checker.h" +#include "theory/arith/nl/equality_substitution.h" namespace cvc5 { @@ -104,6 +105,9 @@ class CadSolver: protected EnvObj InferenceManager& d_im; /** Reference to the non-linear model object */ NlModel& d_model; + /** Utility to eliminate variables from simple equalities before going into + * the actual coverings solver */ + EqualitySubstitution d_eqsubs; }; /* class CadSolver */ } // namespace nl diff --git a/src/theory/arith/nl/equality_substitution.cpp b/src/theory/arith/nl/equality_substitution.cpp new file mode 100644 index 000000000..9b3a79cd4 --- /dev/null +++ b/src/theory/arith/nl/equality_substitution.cpp @@ -0,0 +1,183 @@ +/****************************************************************************** + * Top contributors (to current version): + * Gereon Kremer, Andrew Reynolds, Andres Noetzli + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Implementation of new non-linear solver. + */ + +#include "theory/arith/nl/equality_substitution.h" + +#include "smt/env.h" + +namespace cvc5 { +namespace theory { +namespace arith { +namespace nl { + +EqualitySubstitution::EqualitySubstitution(Env& env) + : EnvObj(env), d_substitutions(std::make_unique<SubstitutionMap>()) +{ +} +void EqualitySubstitution::reset() +{ + d_substitutions = std::make_unique<SubstitutionMap>(); + d_conflict.clear(); + d_conflictMap.clear(); + d_trackOrigin.clear(); +} + +std::vector<Node> EqualitySubstitution::eliminateEqualities( + const std::vector<Node>& assertions) +{ + Trace("nl-eqs") << "Input:" << std::endl; + for (const auto& a : assertions) + { + Trace("nl-eqs") << "\t" << a << std::endl; + } + std::set<TNode> tracker; + std::vector<Node> asserts = assertions; + std::vector<Node> next; + + size_t last_size = 0; + while (asserts.size() != last_size) + { + last_size = asserts.size(); + // collect all eliminations from original into d_substitutions + for (const auto& orig : asserts) + { + if (orig.getKind() != Kind::EQUAL) continue; + tracker.clear(); + d_substitutions->invalidateCache(); + Node o = d_substitutions->apply(orig, d_env.getRewriter(), &tracker); + Trace("nl-eqs") << "Simplified for subst " << orig << " -> " << o + << std::endl; + if (o.getKind() != Kind::EQUAL) continue; + Assert(o.getNumChildren() == 2); + for (size_t i = 0; i < 2; ++i) + { + const auto& l = o[i]; + const auto& r = o[1 - i]; + if (l.isConst()) continue; + if (!Theory::isLeafOf(l, TheoryId::THEORY_ARITH)) continue; + if (d_substitutions->hasSubstitution(l)) continue; + if (expr::hasSubterm(r, l, true)) continue; + Trace("nl-eqs") << "Found substitution " << l << " -> " << r + << std::endl + << " from " << o << " / " << orig << std::endl; + d_substitutions->addSubstitution(l, r); + d_trackOrigin.emplace(l, o); + if (o != orig) + { + addToConflictMap(o, orig, tracker); + } + break; + } + } + + // simplify with subs from original into next + next.clear(); + for (const auto& a : asserts) + { + tracker.clear(); + d_substitutions->invalidateCache(); + Node simp = d_substitutions->apply(a, d_env.getRewriter(), &tracker); + if (simp.isConst()) + { + if (simp.getConst<bool>()) + { + continue; + } + Trace("nl-eqs") << "Simplified " << a << " to " << simp << std::endl; + for (TNode t : tracker) + { + Trace("nl-eqs") << "Tracker has " << t << std::endl; + auto toit = d_trackOrigin.find(t); + Assert(toit != d_trackOrigin.end()); + d_conflict.emplace_back(toit->second); + } + d_conflict.emplace_back(a); + postprocessConflict(d_conflict); + Trace("nl-eqs") << "Direct conflict: " << d_conflict << std::endl; + Trace("nl-eqs") << std::endl + << d_conflict.size() << " vs " + << std::distance(d_substitutions->begin(), + d_substitutions->end()) + << std::endl + << std::endl; + return {}; + } + if (simp != a) + { + Trace("nl-eqs") << "Simplified " << a << " to " << simp << std::endl; + addToConflictMap(simp, a, tracker); + } + next.emplace_back(simp); + } + asserts = std::move(next); + } + d_conflict.clear(); + return asserts; +} +void EqualitySubstitution::postprocessConflict( + std::vector<Node>& conflict) const +{ + Trace("nl-eqs") << "Postprocessing " << conflict << std::endl; + std::set<Node> result; + for (const auto& c : conflict) + { + auto it = d_conflictMap.find(c); + if (it == d_conflictMap.end()) + { + result.insert(c); + } + else + { + Trace("nl-eqs") << "Origin of " << c << ": " << it->second << std::endl; + result.insert(it->second.begin(), it->second.end()); + } + } + conflict.clear(); + conflict.insert(conflict.end(), result.begin(), result.end()); + Trace("nl-eqs") << "-> " << conflict << std::endl; +} +void EqualitySubstitution::insertOrigins(std::set<Node>& dest, + const Node& n) const +{ + auto it = d_conflictMap.find(n); + if (it == d_conflictMap.end()) + { + dest.insert(n); + } + else + { + dest.insert(it->second.begin(), it->second.end()); + } +} +void EqualitySubstitution::addToConflictMap(const Node& n, + const Node& orig, + const std::set<TNode>& tracker) +{ + std::set<Node> origins; + insertOrigins(origins, orig); + for (const auto& t : tracker) + { + auto tit = d_trackOrigin.find(t); + Assert(tit != d_trackOrigin.end()); + insertOrigins(origins, tit->second); + } + Trace("nl-eqs") << "ConflictMap: " << n << " -> " << origins << std::endl; + d_conflictMap.emplace(n, std::vector<Node>(origins.begin(), origins.end())); +} + +} // namespace nl +} // namespace arith +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/arith/nl/equality_substitution.h b/src/theory/arith/nl/equality_substitution.h new file mode 100644 index 000000000..b095af8df --- /dev/null +++ b/src/theory/arith/nl/equality_substitution.h @@ -0,0 +1,102 @@ +/****************************************************************************** + * Top contributors (to current version): + * Gereon Kremer + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * CAD-based solver based on https://arxiv.org/pdf/2003.05633.pdf. + */ + +#ifndef CVC5__THEORY__ARITH__NL__EQUALITY_SUBSTITUTION_H +#define CVC5__THEORY__ARITH__NL__EQUALITY_SUBSTITUTION_H + +#include <vector> + +#include "context/context.h" +#include "expr/node.h" +#include "expr/node_algorithm.h" +#include "smt/env_obj.h" +#include "theory/substitutions.h" +#include "theory/theory.h" + +namespace cvc5 { +namespace theory { +namespace arith { +namespace nl { + +/** + * This class is a general utility to eliminate variables from a set of + * assertions. + */ +class EqualitySubstitution : protected EnvObj +{ + public: + EqualitySubstitution(Env& env); + /** Reset this object */ + void reset(); + + /** + * Eliminate variables using equalities from the set of assertions. + * Returns a set of assertions where some variables may have been eliminated. + * Substitutions for the eliminated variables can be obtained from + * getSubstitutions(). + */ + std::vector<Node> eliminateEqualities(const std::vector<Node>& assertions); + /** + * Can be called after eliminateEqualities(). Returns the substitutions that + * were found and eliminated. + */ + const SubstitutionMap& getSubstitutions() const { return *d_substitutions; } + /** + * Can be called after eliminateEqualities(). Checks whether a direct conflict + * was found, that is an assertion simplified to false during + * eliminateEqualities(). + */ + bool hasConflict() const { return !d_conflict.empty(); } + /** + * Return the conflict found in eliminateEqualities() as a set of assertions + * that is a subset of the input assertions provided to eliminateEqualities(). + */ + const std::vector<Node>& getConflict() const { return d_conflict; } + /** + * Postprocess a conflict found in the result of eliminateEqualities. + * Replaces assertions within the conflict by their origins, i.e. the input + * assertions and the assertions that gave rise to the substitutions being + * used. + */ + void postprocessConflict(std::vector<Node>& conflict) const; + + private: + /** Utility method for addToConflictMap. Checks for n in d_conflictMap */ + void insertOrigins(std::set<Node>& dest, const Node& n) const; + /** Add n -> { orig, *tracker } to the conflict map. The tracked nodes are + * first resolved using d_trackOrigin, and everything is run through + * insertOrigins to make sure that all origins are input assertions. */ + void addToConflictMap(const Node& n, + const Node& orig, + const std::set<TNode>& tracker); + + // The SubstitutionMap + std::unique_ptr<SubstitutionMap> d_substitutions; + // conflicting assertions, if a conflict was found + std::vector<Node> d_conflict; + // Maps a simplified assertion to the original assertion + set of original + // assertions used for substitutions + std::map<Node, std::vector<Node>> d_conflictMap; + // Maps substituted terms (what will end up in the tracker) to the equality + // from which the substitution was derived. + std::map<Node, Node> d_trackOrigin; +}; + +} // namespace nl +} // namespace arith +} // namespace theory +} // namespace cvc5 + +#endif /* CVC5__THEORY__ARITH__NL__EQUALITY_SUBSTITUTION_H */ diff --git a/src/theory/arith/nl/ext/factoring_check.cpp b/src/theory/arith/nl/ext/factoring_check.cpp index 06d6aeaab..32b630fa8 100644 --- a/src/theory/arith/nl/ext/factoring_check.cpp +++ b/src/theory/arith/nl/ext/factoring_check.cpp @@ -35,7 +35,6 @@ namespace nl { FactoringCheck::FactoringCheck(Env& env, ExtState* data) : EnvObj(env), d_data(data) { - d_zero = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(0)); d_one = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(1)); } @@ -155,7 +154,8 @@ void FactoringCheck::check(const std::vector<Node>& asserts, poly.size() == 1 ? poly[0] : nm->mkNode(Kind::PLUS, poly); Trace("nl-ext-factor") << "...factored polynomial : " << polyn << std::endl; - Node conc_lit = nm->mkNode(atom.getKind(), polyn, d_zero); + Node zero = nm->mkConstRealOrInt(polyn.getType(), Rational(0)); + Node conc_lit = nm->mkNode(atom.getKind(), polyn, zero); conc_lit = rewrite(conc_lit); if (!polarity) { diff --git a/src/theory/arith/nl/ext/monomial.cpp b/src/theory/arith/nl/ext/monomial.cpp index 83d0ff71f..47beb6959 100644 --- a/src/theory/arith/nl/ext/monomial.cpp +++ b/src/theory/arith/nl/ext/monomial.cpp @@ -326,7 +326,6 @@ Node MonomialDb::mkMonomialRemFactor(Node n, children.insert(children.end(), inc, v); } Node ret = safeConstructNary(MULT, children); - ret = Rewriter::rewrite(ret); Trace("nl-ext-mono-factor") << "...return : " << ret << std::endl; return ret; } diff --git a/src/theory/arith/nl/iand_solver.cpp b/src/theory/arith/nl/iand_solver.cpp index 5d4862307..c661dab4b 100644 --- a/src/theory/arith/nl/iand_solver.cpp +++ b/src/theory/arith/nl/iand_solver.cpp @@ -47,9 +47,9 @@ IAndSolver::IAndSolver(Env& env, NodeManager* nm = NodeManager::currentNM(); d_false = nm->mkConst(false); d_true = nm->mkConst(true); - d_zero = nm->mkConst(CONST_RATIONAL, Rational(0)); - d_one = nm->mkConst(CONST_RATIONAL, Rational(1)); - d_two = nm->mkConst(CONST_RATIONAL, Rational(2)); + d_zero = nm->mkConstInt(Rational(0)); + d_one = nm->mkConstInt(Rational(1)); + d_two = nm->mkConstInt(Rational(2)); } IAndSolver::~IAndSolver() {} @@ -100,7 +100,7 @@ void IAndSolver::checkInitialRefine() // conj.push_back(i.eqNode(nm->mkNode(IAND, op, i[1], i[0]))); // 0 <= iand(x,y) < 2^k conj.push_back(nm->mkNode(LEQ, d_zero, i)); - conj.push_back(nm->mkNode(LT, i, d_iandUtils.twoToK(k))); + conj.push_back(nm->mkNode(LT, i, rewrite(d_iandUtils.twoToK(k)))); // iand(x,y)<=x conj.push_back(nm->mkNode(LEQ, i, i[0])); // iand(x,y)<=y @@ -280,8 +280,8 @@ Node IAndSolver::bitwiseLemma(Node i) // compare each bit to bvI Node cond; Node bitIAnd; - unsigned high_bit; - for (unsigned j = 0; j < bvsize; j += granularity) + uint64_t high_bit; + for (uint64_t j = 0; j < bvsize; j += granularity) { high_bit = j + granularity - 1; // don't let high_bit pass bvsize @@ -296,7 +296,9 @@ Node IAndSolver::bitwiseLemma(Node i) bitIAnd = d_iandUtils.createBitwiseIAndNode(x, y, high_bit, j); // enforce bitwise equality lem = nm->mkNode( - AND, lem, d_iandUtils.iextract(high_bit, j, i).eqNode(bitIAnd)); + AND, + lem, + rewrite(d_iandUtils.iextract(high_bit, j, i)).eqNode(bitIAnd)); } } return lem; diff --git a/src/theory/arith/nl/iand_solver.h b/src/theory/arith/nl/iand_solver.h index 0b6a1fac6..997112fee 100644 --- a/src/theory/arith/nl/iand_solver.h +++ b/src/theory/arith/nl/iand_solver.h @@ -103,7 +103,7 @@ class IAndSolver : protected EnvObj /** * convert integer value to bitvector value of bitwidth k, - * equivalent to Rewriter::rewrite( ((_ intToBv k) n) ). + * equivalent to rewrite( ((_ intToBv k) n) ). */ Node convertToBvK(unsigned k, Node n) const; /** make iand */ @@ -115,7 +115,7 @@ class IAndSolver : protected EnvObj /** * Value-based refinement lemma for i of the form ((_ iand k) x y). Returns: * x = M(x) ^ y = M(y) => - * ((_ iand k) x y) = Rewriter::rewrite(((_ iand k) M(x) M(y))) + * ((_ iand k) x y) = rewrite(((_ iand k) M(x) M(y))) */ Node valueBasedLemma(Node i); /** diff --git a/src/theory/arith/nl/iand_utils.cpp b/src/theory/arith/nl/iand_utils.cpp index 50e03bfa5..700eb6de9 100644 --- a/src/theory/arith/nl/iand_utils.cpp +++ b/src/theory/arith/nl/iand_utils.cpp @@ -38,7 +38,7 @@ Node pow2(uint64_t k) { Assert(k >= 0); NodeManager* nm = NodeManager::currentNM(); - return nm->mkConst(CONST_RATIONAL, Rational(intpow2(k))); + return nm->mkConstInt(Rational(intpow2(k))); } bool oneBitAnd(bool a, bool b) { return (a && b); } @@ -60,9 +60,9 @@ Node intExtract(Node x, uint64_t i, uint64_t size) IAndUtils::IAndUtils() { NodeManager* nm = NodeManager::currentNM(); - d_zero = nm->mkConst(CONST_RATIONAL, Rational(0)); - d_one = nm->mkConst(CONST_RATIONAL, Rational(1)); - d_two = nm->mkConst(CONST_RATIONAL, Rational(2)); + d_zero = nm->mkConstInt(Rational(0)); + d_one = nm->mkConstInt(Rational(1)); + d_two = nm->mkConstInt(Rational(2)); } Node IAndUtils::createITEFromTable( @@ -80,8 +80,7 @@ Node IAndUtils::createITEFromTable( Assert(table.size() == 1 + ((uint64_t)(num_of_values * num_of_values))); // start with the default, most common value. // this value is represented in the table by (-1, -1). - Node ite = - nm->mkConst(CONST_RATIONAL, Rational(table.at(std::make_pair(-1, -1)))); + Node ite = nm->mkConstInt(Rational(table.at(std::make_pair(-1, -1)))); for (uint64_t i = 0; i < num_of_values; i++) { for (uint64_t j = 0; j < num_of_values; j++) @@ -94,13 +93,10 @@ Node IAndUtils::createITEFromTable( // append the current value to the ite. ite = nm->mkNode( kind::ITE, - nm->mkNode( - kind::AND, - nm->mkNode( - kind::EQUAL, x, nm->mkConst(CONST_RATIONAL, Rational(i))), - nm->mkNode( - kind::EQUAL, y, nm->mkConst(CONST_RATIONAL, Rational(j)))), - nm->mkConst(CONST_RATIONAL, Rational(table.at(std::make_pair(i, j)))), + nm->mkNode(kind::AND, + nm->mkNode(kind::EQUAL, x, nm->mkConstInt(Rational(i))), + nm->mkNode(kind::EQUAL, y, nm->mkConstInt(Rational(j)))), + nm->mkConstInt(Rational(table.at(std::make_pair(i, j)))), ite); } } @@ -139,7 +135,7 @@ Node IAndUtils::createSumNode(Node x, // number of elements in the sum expression uint64_t sumSize = bvsize / granularity; // initialize the sum - Node sumNode = nm->mkConst(CONST_RATIONAL, Rational(0)); + Node sumNode = nm->mkConstInt(Rational(0)); // compute the table for the current granularity if needed if (d_bvandTable.find(granularity) == d_bvandTable.end()) { @@ -186,9 +182,7 @@ Node IAndUtils::iextract(unsigned i, unsigned j, Node n) const NodeManager* nm = NodeManager::currentNM(); // ((_ extract i j) n) is n / 2^j mod 2^{i-j+1} Node n2j = nm->mkNode(kind::INTS_DIVISION_TOTAL, n, twoToK(j)); - Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, n2j, twoToK(i - j + 1)); - ret = Rewriter::rewrite(ret); - return ret; + return nm->mkNode(kind::INTS_MODULUS_TOTAL, n2j, twoToK(i - j + 1)); } void IAndUtils::computeAndTable(uint64_t granularity) @@ -266,19 +260,14 @@ Node IAndUtils::twoToK(unsigned k) const { // could be faster NodeManager* nm = NodeManager::currentNM(); - Node ret = - nm->mkNode(kind::POW, d_two, nm->mkConst(CONST_RATIONAL, Rational(k))); - ret = Rewriter::rewrite(ret); - return ret; + return nm->mkNode(kind::POW, d_two, nm->mkConstInt(Rational(k))); } Node IAndUtils::twoToKMinusOne(unsigned k) const { // could be faster NodeManager* nm = NodeManager::currentNM(); - Node ret = nm->mkNode(kind::MINUS, twoToK(k), d_one); - ret = Rewriter::rewrite(ret); - return ret; + return nm->mkNode(kind::MINUS, twoToK(k), d_one); } } // namespace nl diff --git a/src/theory/arith/nl/icp/icp_solver.cpp b/src/theory/arith/nl/icp/icp_solver.cpp index 92c7d3ddd..aab63325e 100644 --- a/src/theory/arith/nl/icp/icp_solver.cpp +++ b/src/theory/arith/nl/icp/icp_solver.cpp @@ -66,7 +66,7 @@ inline std::ostream& operator<<(std::ostream& os, const IAWrapper& iaw) } // namespace ICPSolver::ICPSolver(Env& env, InferenceManager& im) - : EnvObj(env), d_im(im), d_state(d_mapper) + : EnvObj(env), d_im(im), d_state(env, d_mapper) { } diff --git a/src/theory/arith/nl/icp/icp_solver.h b/src/theory/arith/nl/icp/icp_solver.h index 8b0fbf583..b849255cc 100644 --- a/src/theory/arith/nl/icp/icp_solver.h +++ b/src/theory/arith/nl/icp/icp_solver.h @@ -86,12 +86,12 @@ class ICPSolver : protected EnvObj std::vector<Node> d_conflict; /** Initialized the variable bounds with a variable mapper */ - ICPState(VariableMapper& vm) {} + ICPState(Env& env, VariableMapper& vm) : d_bounds(env) {} /** Reset this state */ void reset() { - d_bounds = BoundInference(); + d_bounds.reset(); d_candidates.clear(); d_assignment.clear(); d_origins = ContractionOriginManager(); diff --git a/src/theory/arith/nl/nl_model.cpp b/src/theory/arith/nl/nl_model.cpp index d23ddd53d..90138bf3e 100644 --- a/src/theory/arith/nl/nl_model.cpp +++ b/src/theory/arith/nl/nl_model.cpp @@ -32,7 +32,7 @@ namespace theory { namespace arith { namespace nl { -NlModel::NlModel() : d_used_approx(false) +NlModel::NlModel(Env& env) : EnvObj(env), d_used_approx(false) { d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); @@ -122,7 +122,7 @@ Node NlModel::computeModelValue(TNode n, bool isConcrete) children.emplace_back(computeModelValue(n[i], isConcrete)); } ret = NodeManager::currentNM()->mkNode(n.getKind(), children); - ret = Rewriter::rewrite(ret); + ret = rewrite(ret); } } Trace("nl-ext-mv-debug") << "computed " << (isConcrete ? "M" : "M_A") << "[" @@ -246,7 +246,7 @@ bool NlModel::checkModel(const std::vector<Node>& assertions, // apply the substitution to a if (!d_substitutions.empty()) { - av = Rewriter::rewrite(arithSubstitute(av, d_substitutions)); + av = rewrite(arithSubstitute(av, d_substitutions)); } // simple check literal if (!simpleCheckModelLit(av)) @@ -307,7 +307,7 @@ bool NlModel::addSubstitution(TNode v, TNode s) Node ms = arithSubstitute(sub, tmp); if (ms != sub) { - sub = Rewriter::rewrite(ms); + sub = rewrite(ms); } } d_substitutions.add(v, s); @@ -376,7 +376,7 @@ bool NlModel::solveEqualitySimple(Node eq, if (!d_substitutions.empty()) { seq = arithSubstitute(eq, d_substitutions); - seq = Rewriter::rewrite(seq); + seq = rewrite(seq); if (seq.isConst()) { if (seq.getConst<bool>()) @@ -580,7 +580,7 @@ bool NlModel::simpleCheckModelLit(Node lit) { lit2 = lit2.negate(); } - lit2 = Rewriter::rewrite(lit2); + lit2 = rewrite(lit2); bool success = simpleCheckModelLit(lit2); if (success != pol) { @@ -669,7 +669,7 @@ bool NlModel::simpleCheckModelLit(Node lit) b = it->second; t = nm->mkNode(PLUS, t, nm->mkNode(MULT, b, v)); } - t = Rewriter::rewrite(t); + t = rewrite(t); Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic " << t << "..." << std::endl; Trace("nl-ext-cms-debug") << " a = " << a << std::endl; @@ -677,7 +677,7 @@ bool NlModel::simpleCheckModelLit(Node lit) // find maximal/minimal value on the interval Node apex = nm->mkNode( DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a)); - apex = Rewriter::rewrite(apex); + apex = rewrite(apex); Assert(apex.isConst()); // for lower, upper, whether we are greater than the apex bool cmp[2]; @@ -686,7 +686,7 @@ bool NlModel::simpleCheckModelLit(Node lit) { boundn[r] = r == 0 ? bit->second.first : bit->second.second; Node cmpn = nm->mkNode(GT, boundn[r], apex); - cmpn = Rewriter::rewrite(cmpn); + cmpn = rewrite(cmpn); Assert(cmpn.isConst()); cmp[r] = cmpn.getConst<bool>(); } @@ -717,12 +717,12 @@ bool NlModel::simpleCheckModelLit(Node lit) { qsub.d_subs.back() = boundn[r]; Node ts = arithSubstitute(t, qsub); - tcmpn[r] = Rewriter::rewrite(ts); + tcmpn[r] = rewrite(ts); } Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]); Trace("nl-ext-cms-debug") << " ...both sides of apex, compare " << tcmp << std::endl; - tcmp = Rewriter::rewrite(tcmp); + tcmp = rewrite(tcmp); Assert(tcmp.isConst()); unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 1 : 0; Trace("nl-ext-cms-debug") @@ -756,7 +756,7 @@ bool NlModel::simpleCheckModelLit(Node lit) if (!qsub.empty()) { Node slit = arithSubstitute(lit, qsub); - slit = Rewriter::rewrite(slit); + slit = rewrite(slit); return simpleCheckModelLit(slit); } return false; @@ -1003,7 +1003,7 @@ bool NlModel::simpleCheckModelMsum(const std::map<Node, Node>& msum, bool pol) comp = comp.negate(); } Trace("nl-ext-cms") << " comparison is : " << comp << std::endl; - comp = Rewriter::rewrite(comp); + comp = rewrite(comp); Assert(comp.isConst()); Trace("nl-ext-cms") << " returned : " << comp << std::endl; return comp == d_true; @@ -1073,7 +1073,7 @@ void NlModel::getModelValueRepair( witness = nm->mkNode(MULT, nm->mkConst(CONST_RATIONAL, Rational(1, 2)), nm->mkNode(PLUS, l, u)); - witness = Rewriter::rewrite(witness); + witness = rewrite(witness); Trace("nl-model") << v << " witness is " << witness << std::endl; } approximations[v] = std::pair<Node, Node>(pred, witness); diff --git a/src/theory/arith/nl/nl_model.h b/src/theory/arith/nl/nl_model.h index 7dcd89a4a..e195aa9b2 100644 --- a/src/theory/arith/nl/nl_model.h +++ b/src/theory/arith/nl/nl_model.h @@ -23,6 +23,7 @@ #include "expr/kind.h" #include "expr/node.h" #include "expr/subs.h" +#include "smt/env_obj.h" namespace cvc5 { @@ -48,12 +49,12 @@ class NonlinearExtension; * model in the case it can determine that a model exists. These include * techniques based on solving (quadratic) equations and bound analysis. */ -class NlModel +class NlModel : protected EnvObj { friend class NonlinearExtension; public: - NlModel(); + NlModel(Env& env); ~NlModel(); /** * This method is called once at the beginning of a last call effort check, diff --git a/src/theory/arith/nl/nonlinear_extension.cpp b/src/theory/arith/nl/nonlinear_extension.cpp index e75741096..3f60f8596 100644 --- a/src/theory/arith/nl/nonlinear_extension.cpp +++ b/src/theory/arith/nl/nonlinear_extension.cpp @@ -48,7 +48,7 @@ NonlinearExtension::NonlinearExtension(Env& env, d_checkCounter(0), d_extTheoryCb(state.getEqualityEngine()), d_extTheory(env, d_extTheoryCb, d_im), - d_model(), + d_model(env), d_trSlv(d_env, d_im, d_model), d_extState(d_im, d_model, d_env), d_factoringSlv(d_env, &d_extState), @@ -122,7 +122,7 @@ void NonlinearExtension::getAssertions(std::vector<Node>& assertions) } Valuation v = d_containing.getValuation(); - BoundInference bounds; + BoundInference bounds(d_env); std::unordered_set<Node> init_assertions; @@ -353,45 +353,6 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set<Node>& termS } // compute whether shared terms have correct values - unsigned num_shared_wrong_value = 0; - std::vector<Node> shared_term_value_splits; - // must ensure that shared terms are equal to their concrete value - Trace("nl-ext-mv") << "Shared terms : " << std::endl; - for (context::CDList<TNode>::const_iterator its = - d_containing.shared_terms_begin(); - its != d_containing.shared_terms_end(); - ++its) - { - TNode shared_term = *its; - // compute its value in the model, and its evaluation in the model - Node stv0 = d_model.computeConcreteModelValue(shared_term); - Node stv1 = d_model.computeAbstractModelValue(shared_term); - d_model.printModelValue("nl-ext-mv", shared_term); - if (stv0 != stv1) - { - num_shared_wrong_value++; - Trace("nl-ext-mv") << "Bad shared term value : " << shared_term - << std::endl; - if (shared_term != stv0) - { - // split on the value, this is non-terminating in general, TODO : - // improve this - Node eq = shared_term.eqNode(stv0); - shared_term_value_splits.push_back(eq); - } - else - { - // this can happen for transcendental functions - // the problem is that we cannot evaluate transcendental functions - // (they don't have a rewriter that returns constants) - // thus, the actual value in their model can be themselves, hence we - // have no reference point to rule out the current model. In this - // case, we may set incomplete below. - } - } - } - Trace("nl-ext-debug") << " " << num_shared_wrong_value - << " shared terms with wrong model value." << std::endl; bool needsRecheck; do { @@ -402,9 +363,9 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set<Node>& termS int complete_status = 1; // We require a check either if an assertion is false or a shared term has // a wrong value - if (!false_asserts.empty() || num_shared_wrong_value > 0) + if (!false_asserts.empty()) { - complete_status = num_shared_wrong_value > 0 ? -1 : 0; + complete_status = 0; runStrategy(Theory::Effort::EFFORT_FULL, assertions, false_asserts, xts); if (d_im.hasSentLemma() || d_im.hasPendingLemma()) { @@ -446,40 +407,6 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set<Node>& termS << std::endl; return Result::Sat::UNSAT; } - // resort to splitting on shared terms with their model value - // if we did not add any lemmas - if (num_shared_wrong_value > 0) - { - complete_status = -1; - if (!shared_term_value_splits.empty()) - { - for (const Node& eq : shared_term_value_splits) - { - Node req = rewrite(eq); - Node literal = d_containing.getValuation().ensureLiteral(req); - d_containing.getOutputChannel().requirePhase(literal, true); - Trace("nl-ext-debug") << "Split on : " << literal << std::endl; - Node split = literal.orNode(literal.negate()); - d_im.addPendingLemma(split, - InferenceId::ARITH_NL_SHARED_TERM_VALUE_SPLIT, - nullptr, - true); - } - if (d_im.hasWaitingLemma()) - { - d_im.flushWaitingLemmas(); - Trace("nl-ext") << "...added " << d_im.numPendingLemmas() - << " shared term value split lemmas." << std::endl; - return Result::Sat::UNSAT; - } - } - else - { - // this can happen if we are trying to do theory combination with - // trancendental functions - // since their model value cannot even be computed exactly - } - } // we are incomplete if (options().arith.nlExt == options::NlExtMode::FULL diff --git a/src/theory/arith/nl/pow2_solver.cpp b/src/theory/arith/nl/pow2_solver.cpp index e3a26397e..59bf89151 100644 --- a/src/theory/arith/nl/pow2_solver.cpp +++ b/src/theory/arith/nl/pow2_solver.cpp @@ -42,9 +42,9 @@ Pow2Solver::Pow2Solver(Env& env, NodeManager* nm = NodeManager::currentNM(); d_false = nm->mkConst(false); d_true = nm->mkConst(true); - d_zero = nm->mkConst(CONST_RATIONAL, Rational(0)); - d_one = nm->mkConst(CONST_RATIONAL, Rational(1)); - d_two = nm->mkConst(CONST_RATIONAL, Rational(2)); + d_zero = nm->mkConstInt(Rational(0)); + d_one = nm->mkConstInt(Rational(1)); + d_two = nm->mkConstInt(Rational(2)); } Pow2Solver::~Pow2Solver() {} @@ -190,8 +190,7 @@ Node Pow2Solver::valueBasedLemma(Node i) Node valC = nm->mkNode(POW2, valX); valC = rewrite(valC); - Node lem = nm->mkNode(IMPLIES, x.eqNode(valX), i.eqNode(valC)); - return lem; + return nm->mkNode(IMPLIES, x.eqNode(valX), i.eqNode(valC)); } } // namespace nl diff --git a/src/theory/arith/nl/pow2_solver.h b/src/theory/arith/nl/pow2_solver.h index b4e12616c..42586f206 100644 --- a/src/theory/arith/nl/pow2_solver.h +++ b/src/theory/arith/nl/pow2_solver.h @@ -100,7 +100,7 @@ class Pow2Solver : protected EnvObj /** * Value-based refinement lemma for i of the form (pow2 x). Returns: * x = M(x) /\ x>= 0 ----> - * (pow2 x) = Rewriter::rewrite((pow2 M(x))) + * (pow2 x) = rewrite((pow2 M(x))) */ Node valueBasedLemma(Node i); }; /* class Pow2Solver */ diff --git a/src/theory/arith/nl/strategy.cpp b/src/theory/arith/nl/strategy.cpp index b33e45129..a14841f67 100644 --- a/src/theory/arith/nl/strategy.cpp +++ b/src/theory/arith/nl/strategy.cpp @@ -172,10 +172,7 @@ void Strategy::initializeStrategy(const Options& options) one << InferStep::POW2_FULL << InferStep::BREAK; if (options.arith.nlCad) { - one << InferStep::CAD_INIT; - } - if (options.arith.nlCad) - { + one << InferStep::CAD_INIT << InferStep::BREAK; one << InferStep::CAD_FULL << InferStep::BREAK; } diff --git a/src/theory/arith/nl/transcendental/exponential_solver.cpp b/src/theory/arith/nl/transcendental/exponential_solver.cpp index c4f7f6ca9..77e5f9f3f 100644 --- a/src/theory/arith/nl/transcendental/exponential_solver.cpp +++ b/src/theory/arith/nl/transcendental/exponential_solver.cpp @@ -230,7 +230,7 @@ void ExponentialSolver::doTangentLemma(TNode e, proof->addStep(lem, PfRule::ARITH_TRANS_EXP_APPROX_BELOW, {}, - {nm->mkConst(CONST_RATIONAL, Rational(d)), e[0]}); + {nm->mkConstInt(Rational(d)), e[0]}); } d_data->d_im.addPendingLemma( lem, InferenceId::ARITH_NL_T_TANGENT, proof, true); diff --git a/src/theory/arith/nl/transcendental/proof_checker.cpp b/src/theory/arith/nl/transcendental/proof_checker.cpp index ca1afb9f6..3bf1ace98 100644 --- a/src/theory/arith/nl/transcendental/proof_checker.cpp +++ b/src/theory/arith/nl/transcendental/proof_checker.cpp @@ -18,7 +18,7 @@ #include "expr/sequence.h" #include "theory/arith/arith_utilities.h" #include "theory/arith/nl/transcendental/taylor_generator.h" -#include "theory/rewriter.h" +#include "theory/evaluator.h" using namespace cvc5::kind; @@ -42,18 +42,18 @@ Node mkBounds(TNode t, TNode lb, TNode ub) /** * Helper method to construct a secant plane: - * ((evall - evalu) / (l - u)) * (t - l) + evall + * evall + ((evall - evalu) / (l - u)) * (t - l) */ Node mkSecant(TNode t, TNode l, TNode u, TNode evall, TNode evalu) { NodeManager* nm = NodeManager::currentNM(); return nm->mkNode(Kind::PLUS, + evall, nm->mkNode(Kind::MULT, nm->mkNode(Kind::DIVISION, nm->mkNode(Kind::MINUS, evall, evalu), nm->mkNode(Kind::MINUS, l, u)), - nm->mkNode(Kind::MINUS, t, l)), - evall); + nm->mkNode(Kind::MINUS, t, l))); } } // namespace @@ -83,11 +83,11 @@ Node TranscendentalProofRuleChecker::checkInternal( PfRule id, const std::vector<Node>& children, const std::vector<Node>& args) { NodeManager* nm = NodeManager::currentNM(); - auto zero = nm->mkConst<Rational>(CONST_RATIONAL, 0); - auto one = nm->mkConst<Rational>(CONST_RATIONAL, 1); - auto mone = nm->mkConst<Rational>(CONST_RATIONAL, -1); - auto pi = nm->mkNullaryOperator(nm->realType(), Kind::PI); - auto mpi = nm->mkNode(Kind::MULT, mone, pi); + Node zero = nm->mkConstReal(Rational(0)); + Node one = nm->mkConstReal(Rational(1)); + Node mone = nm->mkConstReal(Rational(-1)); + Node pi = nm->mkNullaryOperator(nm->realType(), Kind::PI); + Node mpi = nm->mkNode(Kind::MULT, mone, pi); Trace("nl-trans-checker") << "Checking " << id << std::endl; Trace("nl-trans-checker") << "Children:" << std::endl; for (const auto& c : children) @@ -141,11 +141,10 @@ Node TranscendentalProofRuleChecker::checkInternal( { Assert(children.empty()); Assert(args.size() == 4); - Assert(args[0].isConst() && args[0].getKind() == Kind::CONST_RATIONAL - && args[0].getConst<Rational>().isIntegral()); + Assert(args[0].isConst() && args[0].getType().isInteger()); Assert(args[1].getType().isReal()); - Assert(args[2].isConst() && args[2].getKind() == Kind::CONST_RATIONAL); - Assert(args[3].isConst() && args[3].getKind() == Kind::CONST_RATIONAL); + Assert(args[2].isConst() && args[2].getType().isRealOrInt()); + Assert(args[3].isConst() && args[3].getType().isRealOrInt()); std::uint64_t d = args[0].getConst<Rational>().getNumerator().toUnsignedInt(); Node t = args[1]; @@ -154,26 +153,24 @@ Node TranscendentalProofRuleChecker::checkInternal( TaylorGenerator tg; TaylorGenerator::ApproximationBounds bounds; tg.getPolynomialApproximationBounds(Kind::EXPONENTIAL, d / 2, bounds); - Node evall = Rewriter::rewrite( - bounds.d_upperPos.substitute(tg.getTaylorVariable(), l)); - Node evalu = Rewriter::rewrite( - bounds.d_upperPos.substitute(tg.getTaylorVariable(), u)); + Evaluator eval(nullptr); + Node evall = eval.eval(bounds.d_upperPos, {tg.getTaylorVariable()}, {l}); + Node evalu = eval.eval(bounds.d_upperPos, {tg.getTaylorVariable()}, {u}); Node evalsecant = mkSecant(t, l, u, evall, evalu); Node lem = nm->mkNode( Kind::IMPLIES, mkBounds(t, l, u), nm->mkNode(Kind::LEQ, nm->mkNode(Kind::EXPONENTIAL, t), evalsecant)); - return Rewriter::rewrite(lem); + return lem; } else if (id == PfRule::ARITH_TRANS_EXP_APPROX_ABOVE_NEG) { Assert(children.empty()); Assert(args.size() == 4); - Assert(args[0].isConst() && args[0].getKind() == Kind::CONST_RATIONAL - && args[0].getConst<Rational>().isIntegral()); + Assert(args[0].isConst() && args[0].getType().isInteger()); Assert(args[1].getType().isReal()); - Assert(args[2].isConst() && args[2].getKind() == Kind::CONST_RATIONAL); - Assert(args[3].isConst() && args[3].getKind() == Kind::CONST_RATIONAL); + Assert(args[2].isConst() && args[2].getType().isRealOrInt()); + Assert(args[3].isConst() && args[3].getType().isRealOrInt()); std::uint64_t d = args[0].getConst<Rational>().getNumerator().toUnsignedInt(); Node t = args[1]; @@ -182,23 +179,21 @@ Node TranscendentalProofRuleChecker::checkInternal( TaylorGenerator tg; TaylorGenerator::ApproximationBounds bounds; tg.getPolynomialApproximationBounds(Kind::EXPONENTIAL, d / 2, bounds); - Node evall = Rewriter::rewrite( - bounds.d_upperNeg.substitute(tg.getTaylorVariable(), l)); - Node evalu = Rewriter::rewrite( - bounds.d_upperNeg.substitute(tg.getTaylorVariable(), u)); + Evaluator eval(nullptr); + Node evall = eval.eval(bounds.d_upperNeg, {tg.getTaylorVariable()}, {l}); + Node evalu = eval.eval(bounds.d_upperNeg, {tg.getTaylorVariable()}, {u}); Node evalsecant = mkSecant(t, l, u, evall, evalu); Node lem = nm->mkNode( Kind::IMPLIES, mkBounds(t, l, u), nm->mkNode(Kind::LEQ, nm->mkNode(Kind::EXPONENTIAL, t), evalsecant)); - return Rewriter::rewrite(lem); + return lem; } else if (id == PfRule::ARITH_TRANS_EXP_APPROX_BELOW) { Assert(children.empty()); Assert(args.size() == 2); - Assert(args[0].isConst() && args[0].getKind() == Kind::CONST_RATIONAL - && args[0].getConst<Rational>().isIntegral()); + Assert(args[0].isConst() && args[0].getType().isInteger()); Assert(args[1].getType().isReal()); std::uint64_t d = args[0].getConst<Rational>().getNumerator().toUnsignedInt(); @@ -206,10 +201,10 @@ Node TranscendentalProofRuleChecker::checkInternal( TaylorGenerator tg; TaylorGenerator::ApproximationBounds bounds; tg.getPolynomialApproximationBounds(Kind::EXPONENTIAL, d, bounds); - Node eval = - Rewriter::rewrite(bounds.d_lower.substitute(tg.getTaylorVariable(), t)); + Evaluator eval(nullptr); + Node evalt = eval.eval(bounds.d_lower, {tg.getTaylorVariable()}, {t}); return nm->mkNode( - Kind::GEQ, std::vector<Node>{nm->mkNode(Kind::EXPONENTIAL, t), eval}); + Kind::GEQ, std::vector<Node>{nm->mkNode(Kind::EXPONENTIAL, t), evalt}); } else if (id == PfRule::ARITH_TRANS_SINE_BOUNDS) { @@ -240,10 +235,7 @@ Node TranscendentalProofRuleChecker::checkInternal( x.eqNode( nm->mkNode(Kind::PLUS, y, - nm->mkNode(Kind::MULT, - nm->mkConst<Rational>(CONST_RATIONAL, 2), - s, - pi)))), + nm->mkNode(Kind::MULT, nm->mkConstReal(2), s, pi)))), nm->mkNode(Kind::SINE, y).eqNode(nm->mkNode(Kind::SINE, x))}); } else if (id == PfRule::ARITH_TRANS_SINE_SYMMETRY) @@ -252,8 +244,7 @@ Node TranscendentalProofRuleChecker::checkInternal( Assert(args.size() == 1); Assert(args[0].getType().isReal()); Node s1 = nm->mkNode(Kind::SINE, args[0]); - Node s2 = nm->mkNode( - Kind::SINE, Rewriter::rewrite(nm->mkNode(Kind::MULT, mone, args[0]))); + Node s2 = nm->mkNode(Kind::SINE, nm->mkNode(Kind::MULT, mone, args[0])); return nm->mkNode(PLUS, s1, s2).eqNode(zero); } else if (id == PfRule::ARITH_TRANS_SINE_TANGENT_ZERO) @@ -289,13 +280,12 @@ Node TranscendentalProofRuleChecker::checkInternal( { Assert(children.empty()); Assert(args.size() == 6); - Assert(args[0].isConst() && args[0].getKind() == Kind::CONST_RATIONAL - && args[0].getConst<Rational>().isIntegral()); + Assert(args[0].isConst() && args[0].getType().isInteger()); Assert(args[1].getType().isReal()); Assert(args[2].getType().isReal()); Assert(args[3].getType().isReal()); - Assert(args[4].isConst() && args[4].getKind() == Kind::CONST_RATIONAL); - Assert(args[5].isConst() && args[5].getKind() == Kind::CONST_RATIONAL); + Assert(args[4].isConst() && args[4].getType().isRealOrInt()); + Assert(args[5].isConst() && args[5].getType().isRealOrInt()); std::uint64_t d = args[0].getConst<Rational>().getNumerator().toUnsignedInt(); Node t = args[1]; @@ -306,23 +296,21 @@ Node TranscendentalProofRuleChecker::checkInternal( TaylorGenerator tg; TaylorGenerator::ApproximationBounds bounds; tg.getPolynomialApproximationBounds(Kind::SINE, d / 2, bounds); - Node evall = Rewriter::rewrite( - bounds.d_upperNeg.substitute(tg.getTaylorVariable(), l)); - Node evalu = Rewriter::rewrite( - bounds.d_upperNeg.substitute(tg.getTaylorVariable(), u)); + Evaluator eval(nullptr); + Node evall = eval.eval(bounds.d_upperNeg, {tg.getTaylorVariable()}, {l}); + Node evalu = eval.eval(bounds.d_upperNeg, {tg.getTaylorVariable()}, {u}); Node lem = nm->mkNode( Kind::IMPLIES, mkBounds(t, lb, ub), nm->mkNode( Kind::LEQ, nm->mkNode(Kind::SINE, t), mkSecant(t, lb, ub, l, u))); - return Rewriter::rewrite(lem); + return lem; } else if (id == PfRule::ARITH_TRANS_SINE_APPROX_ABOVE_POS) { Assert(children.empty()); Assert(args.size() == 5); - Assert(args[0].isConst() && args[0].getKind() == Kind::CONST_RATIONAL - && args[0].getConst<Rational>().isIntegral()); + Assert(args[0].isConst() && args[0].getType().isInteger()); Assert(args[1].getType().isReal()); Assert(args[2].getType().isReal()); Assert(args[3].getType().isReal()); @@ -335,24 +323,22 @@ Node TranscendentalProofRuleChecker::checkInternal( TaylorGenerator tg; TaylorGenerator::ApproximationBounds bounds; tg.getPolynomialApproximationBounds(Kind::SINE, d / 2, bounds); - Node eval = Rewriter::rewrite( - bounds.d_upperPos.substitute(tg.getTaylorVariable(), c)); - return Rewriter::rewrite( - nm->mkNode(Kind::IMPLIES, - mkBounds(t, lb, ub), - nm->mkNode(Kind::LEQ, nm->mkNode(Kind::SINE, t), eval))); + Evaluator eval(nullptr); + Node evalc = eval.eval(bounds.d_upperPos, {tg.getTaylorVariable()}, {c}); + return nm->mkNode(Kind::IMPLIES, + mkBounds(t, lb, ub), + nm->mkNode(Kind::LEQ, nm->mkNode(Kind::SINE, t), evalc)); } else if (id == PfRule::ARITH_TRANS_SINE_APPROX_BELOW_POS) { Assert(children.empty()); Assert(args.size() == 6); - Assert(args[0].isConst() && args[0].getKind() == Kind::CONST_RATIONAL - && args[0].getConst<Rational>().isIntegral()); + Assert(args[0].isConst() && args[0].getType().isInteger()); Assert(args[1].getType().isReal()); Assert(args[2].getType().isReal()); Assert(args[3].getType().isReal()); - Assert(args[4].isConst() && args[4].getKind() == Kind::CONST_RATIONAL); - Assert(args[5].isConst() && args[5].getKind() == Kind::CONST_RATIONAL); + Assert(args[4].isConst() && args[4].getType().isRealOrInt()); + Assert(args[5].isConst() && args[5].getType().isRealOrInt()); std::uint64_t d = args[0].getConst<Rational>().getNumerator().toUnsignedInt(); Node t = args[1]; @@ -363,23 +349,21 @@ Node TranscendentalProofRuleChecker::checkInternal( TaylorGenerator tg; TaylorGenerator::ApproximationBounds bounds; tg.getPolynomialApproximationBounds(Kind::SINE, d / 2, bounds); - Node evall = - Rewriter::rewrite(bounds.d_lower.substitute(tg.getTaylorVariable(), l)); - Node evalu = - Rewriter::rewrite(bounds.d_lower.substitute(tg.getTaylorVariable(), u)); + Evaluator eval(nullptr); + Node evall = eval.eval(bounds.d_lower, {tg.getTaylorVariable()}, {l}); + Node evalu = eval.eval(bounds.d_lower, {tg.getTaylorVariable()}, {u}); Node lem = nm->mkNode( Kind::IMPLIES, mkBounds(t, lb, ub), nm->mkNode( Kind::GEQ, nm->mkNode(Kind::SINE, t), mkSecant(t, lb, ub, l, u))); - return Rewriter::rewrite(lem); + return lem; } else if (id == PfRule::ARITH_TRANS_SINE_APPROX_BELOW_NEG) { Assert(children.empty()); Assert(args.size() == 5); - Assert(args[0].isConst() && args[0].getKind() == Kind::CONST_RATIONAL - && args[0].getConst<Rational>().isIntegral()); + Assert(args[0].isConst() && args[0].getType().isInteger()); Assert(args[1].getType().isReal()); Assert(args[2].getType().isReal()); Assert(args[3].getType().isReal()); @@ -392,12 +376,11 @@ Node TranscendentalProofRuleChecker::checkInternal( TaylorGenerator tg; TaylorGenerator::ApproximationBounds bounds; tg.getPolynomialApproximationBounds(Kind::SINE, d / 2, bounds); - Node eval = - Rewriter::rewrite(bounds.d_lower.substitute(tg.getTaylorVariable(), c)); - return Rewriter::rewrite( - nm->mkNode(Kind::IMPLIES, - mkBounds(t, lb, ub), - nm->mkNode(Kind::GEQ, nm->mkNode(Kind::SINE, t), eval))); + Evaluator eval(nullptr); + Node evalc = eval.eval(bounds.d_lower, {tg.getTaylorVariable()}, {c}); + return nm->mkNode(Kind::IMPLIES, + mkBounds(t, lb, ub), + nm->mkNode(Kind::GEQ, nm->mkNode(Kind::SINE, t), evalc)); } return Node::null(); } diff --git a/src/theory/arith/nl/transcendental/sine_solver.cpp b/src/theory/arith/nl/transcendental/sine_solver.cpp index b6b5c92c1..6c1bec647 100644 --- a/src/theory/arith/nl/transcendental/sine_solver.cpp +++ b/src/theory/arith/nl/transcendental/sine_solver.cpp @@ -75,13 +75,12 @@ void SineSolver::doPhaseShift(TNode a, TNode new_a, TNode y) nm->mkNode(Kind::ITE, mkValidPhase(a[0], d_data->d_pi), a[0].eqNode(y), - a[0].eqNode(nm->mkNode( - Kind::PLUS, - y, - nm->mkNode(Kind::MULT, - nm->mkConst(CONST_RATIONAL, Rational(2)), - shift, - d_data->d_pi)))), + a[0].eqNode(nm->mkNode(Kind::PLUS, + y, + nm->mkNode(Kind::MULT, + nm->mkConstReal(Rational(2)), + shift, + d_data->d_pi)))), new_a.eqNode(a)); CDProof* proof = nullptr; if (d_data->isProofEnabled()) @@ -143,7 +142,17 @@ void SineSolver::checkInitialRefine() if (d_data->isProofEnabled()) { proof = d_data->getProof(); - proof->addStep(lem, PfRule::ARITH_TRANS_SINE_SYMMETRY, {}, {t[0]}); + Node tmplem = + nm->mkNode(Kind::PLUS, + t, + nm->mkNode( + Kind::SINE, + nm->mkNode(Kind::MULT, d_data->d_neg_one, t[0]))) + .eqNode(d_data->d_zero); + proof->addStep( + tmplem, PfRule::ARITH_TRANS_SINE_SYMMETRY, {}, {t[0]}); + proof->addStep( + lem, PfRule::MACRO_SR_PRED_TRANSFORM, {tmplem}, {lem}); } d_data->d_im.addPendingLemma( lem, InferenceId::ARITH_NL_T_INIT_REFINE, proof); @@ -385,9 +394,6 @@ void SineSolver::doTangentLemma( e, poly_approx)); - Trace("nl-ext-sine") << "*** Tangent plane lemma (pre-rewrite): " << lem - << std::endl; - lem = rewrite(lem); Trace("nl-ext-sine") << "*** Tangent plane lemma : " << lem << std::endl; Assert(d_data->d_model.computeAbstractModelValue(lem) == d_data->d_false); // Figure 3 : line 9 @@ -402,7 +408,7 @@ void SineSolver::doTangentLemma( proof->addStep(lem, PfRule::ARITH_TRANS_SINE_APPROX_BELOW_NEG, {}, - {nm->mkConst(CONST_RATIONAL, Rational(2 * d)), + {nm->mkConstInt(Rational(2 * d)), e[0], c, regionToLowerBound(region), @@ -413,7 +419,7 @@ void SineSolver::doTangentLemma( proof->addStep(lem, PfRule::ARITH_TRANS_SINE_APPROX_BELOW_NEG, {}, - {nm->mkConst(CONST_RATIONAL, Rational(2 * d)), + {nm->mkConstInt(Rational(2 * d)), e[0], c, c, @@ -427,7 +433,7 @@ void SineSolver::doTangentLemma( proof->addStep(lem, PfRule::ARITH_TRANS_SINE_APPROX_ABOVE_POS, {}, - {nm->mkConst(CONST_RATIONAL, Rational(2 * d)), + {nm->mkConstInt(Rational(2 * d)), e[0], c, regionToLowerBound(region), @@ -438,7 +444,7 @@ void SineSolver::doTangentLemma( proof->addStep(lem, PfRule::ARITH_TRANS_SINE_APPROX_ABOVE_POS, {}, - {nm->mkConst(CONST_RATIONAL, Rational(2 * d)), + {nm->mkConstInt(Rational(2 * d)), e[0], c, c, diff --git a/src/theory/arith/nl/transcendental/taylor_generator.cpp b/src/theory/arith/nl/transcendental/taylor_generator.cpp index 2a231bc2b..c9e0015e2 100644 --- a/src/theory/arith/nl/transcendental/taylor_generator.cpp +++ b/src/theory/arith/nl/transcendental/taylor_generator.cpp @@ -17,6 +17,7 @@ #include "theory/arith/arith_utilities.h" #include "theory/arith/nl/nl_model.h" +#include "theory/evaluator.h" #include "theory/rewriter.h" using namespace cvc5::kind; @@ -28,8 +29,8 @@ namespace nl { namespace transcendental { TaylorGenerator::TaylorGenerator() - : d_nm(NodeManager::currentNM()), - d_taylor_real_fv(d_nm->mkBoundVar("x", d_nm->realType())) + : d_taylor_real_fv(NodeManager::currentNM()->mkBoundVar( + "x", NodeManager::currentNM()->realType())) { } @@ -50,7 +51,7 @@ std::pair<Node, Node> TaylorGenerator::getTaylor(Kind k, std::uint64_t n) // the current factorial `counter!` Integer factorial = 1; // the current variable power `x^counter` - Node varpow = nm->mkConst(CONST_RATIONAL, Rational(1)); + Node varpow = nm->mkConstReal(Rational(1)); std::vector<Node> sum; for (std::uint64_t counter = 1; counter <= n; ++counter) { @@ -59,9 +60,7 @@ std::pair<Node, Node> TaylorGenerator::getTaylor(Kind k, std::uint64_t n) // Maclaurin series for exponential: // \sum_{n=0}^\infty x^n / n! sum.push_back( - nm->mkNode(Kind::DIVISION, - varpow, - nm->mkConst<Rational>(CONST_RATIONAL, factorial))); + nm->mkNode(Kind::DIVISION, varpow, nm->mkConstReal(factorial))); } else if (k == Kind::SINE) { @@ -70,24 +69,19 @@ std::pair<Node, Node> TaylorGenerator::getTaylor(Kind k, std::uint64_t n) if (counter % 2 == 0) { int sign = (counter % 4 == 0 ? -1 : 1); - sum.push_back(nm->mkNode( - Kind::MULT, - nm->mkNode(Kind::DIVISION, - nm->mkConst<Rational>(CONST_RATIONAL, sign), - nm->mkConst<Rational>(CONST_RATIONAL, factorial)), - varpow)); + sum.push_back(nm->mkNode(Kind::MULT, + nm->mkNode(Kind::DIVISION, + nm->mkConstReal(sign), + nm->mkConstReal(factorial)), + varpow)); } } factorial *= counter; - varpow = - Rewriter::rewrite(nm->mkNode(Kind::MULT, d_taylor_real_fv, varpow)); + varpow = nm->mkNode(Kind::MULT, d_taylor_real_fv, varpow); } - Node taylor_sum = - Rewriter::rewrite(sum.size() == 1 ? sum[0] : nm->mkNode(Kind::PLUS, sum)); - Node taylor_rem = Rewriter::rewrite( - nm->mkNode(Kind::DIVISION, - varpow, - nm->mkConst<Rational>(CONST_RATIONAL, factorial))); + Node taylor_sum = (sum.size() == 1 ? sum[0] : nm->mkNode(Kind::PLUS, sum)); + Node taylor_rem = + nm->mkNode(Kind::DIVISION, varpow, nm->mkConstReal(factorial)); auto res = std::make_pair(taylor_sum, taylor_rem); @@ -118,19 +112,17 @@ void TaylorGenerator::getPolynomialApproximationBounds( if (k == Kind::EXPONENTIAL) { pbounds.d_lower = taylor_sum; - pbounds.d_upperNeg = - Rewriter::rewrite(nm->mkNode(Kind::PLUS, taylor_sum, ru)); - pbounds.d_upperPos = Rewriter::rewrite(nm->mkNode( - Kind::MULT, - taylor_sum, - nm->mkNode( - Kind::PLUS, nm->mkConst(CONST_RATIONAL, Rational(1)), ru))); + pbounds.d_upperNeg = nm->mkNode(Kind::PLUS, taylor_sum, ru); + pbounds.d_upperPos = + nm->mkNode(Kind::MULT, + taylor_sum, + nm->mkNode(Kind::PLUS, nm->mkConstReal(Rational(1)), ru)); } else { Assert(k == Kind::SINE); - Node l = Rewriter::rewrite(nm->mkNode(Kind::MINUS, taylor_sum, ru)); - Node u = Rewriter::rewrite(nm->mkNode(Kind::PLUS, taylor_sum, ru)); + Node l = nm->mkNode(Kind::MINUS, taylor_sum, ru); + Node u = nm->mkNode(Kind::PLUS, taylor_sum, ru); pbounds.d_lower = l; pbounds.d_upperNeg = u; pbounds.d_upperPos = u; @@ -160,6 +152,7 @@ std::uint64_t TaylorGenerator::getPolynomialApproximationBoundForArg( std::uint64_t ds = d; TNode ttrf = getTaylorVariable(); TNode tc = c; + Evaluator eval(nullptr); do { success = true; @@ -167,8 +160,7 @@ std::uint64_t TaylorGenerator::getPolynomialApproximationBoundForArg( std::pair<Node, Node> taylor = getTaylor(k, n); // check that 1-c^{n+1}/(n+1)! > 0 Node ru = taylor.second; - Node rus = ru.substitute(ttrf, tc); - rus = Rewriter::rewrite(rus); + Node rus = eval.eval(ru, {ttrf}, {tc}); Assert(rus.isConst()); if (rus.getConst<Rational>() > 1) { @@ -206,11 +198,11 @@ std::pair<Node, Node> TaylorGenerator::getTfModelBounds(Node tf, // at zero, its trivial if (k == Kind::SINE) { - Node zero = nm->mkConst(CONST_RATIONAL, Rational(0)); + Node zero = nm->mkConstReal(Rational(0)); return std::pair<Node, Node>(zero, zero); } Assert(k == Kind::EXPONENTIAL); - Node one = nm->mkConst(CONST_RATIONAL, Rational(1)); + Node one = nm->mkConstReal(Rational(1)); return std::pair<Node, Node>(one, one); } bool isNeg = csign == -1; @@ -221,6 +213,7 @@ std::pair<Node, Node> TaylorGenerator::getTfModelBounds(Node tf, std::vector<Node> bounds; TNode tfv = getTaylorVariable(); TNode tfs = tf[0]; + Evaluator eval(nullptr); for (unsigned d2 = 0; d2 < 2; d2++) { Node pab = (d2 == 0 ? pbounds.d_lower @@ -235,8 +228,7 @@ std::pair<Node, Node> TaylorGenerator::getTfModelBounds(Node tf, // M_A( x*x { x -> t } ) = M_A( t*t ) // where M_A denotes the abstract model. Node mtfs = model.computeAbstractModelValue(tfs); - pab = pab.substitute(tfv, mtfs); - pab = Rewriter::rewrite(pab); + pab = eval.eval(pab, {tfv}, {mtfs}); Assert(pab.isConst()); bounds.push_back(pab); } diff --git a/src/theory/arith/nl/transcendental/taylor_generator.h b/src/theory/arith/nl/transcendental/taylor_generator.h index df4cb128c..ea082d87b 100644 --- a/src/theory/arith/nl/transcendental/taylor_generator.h +++ b/src/theory/arith/nl/transcendental/taylor_generator.h @@ -104,7 +104,6 @@ class TaylorGenerator NlModel& model); private: - NodeManager* d_nm; const Node d_taylor_real_fv; /** diff --git a/src/theory/arith/nl/transcendental/transcendental_solver.cpp b/src/theory/arith/nl/transcendental/transcendental_solver.cpp index 9e204f582..25a5a511f 100644 --- a/src/theory/arith/nl/transcendental/transcendental_solver.cpp +++ b/src/theory/arith/nl/transcendental/transcendental_solver.cpp @@ -41,11 +41,11 @@ TranscendentalSolver::TranscendentalSolver(Env& env, InferenceManager& im, NlModel& m) : EnvObj(env), - d_tstate(im, m, env), + d_tstate(env, im, m), d_expSlv(env, &d_tstate), d_sineSlv(env, &d_tstate) { - d_taylor_degree = d_tstate.d_env.getOptions().arith.nlExtTfTaylorDegree; + d_taylor_degree = options().arith.nlExtTfTaylorDegree; } TranscendentalSolver::~TranscendentalSolver() {} @@ -187,7 +187,7 @@ void TranscendentalSolver::processSideEffect(const NlLemma& se) auto it = secant_points.find(d); if (it == secant_points.end()) { - it = secant_points.emplace(d, d_tstate.d_env.getUserContext()).first; + it = secant_points.emplace(d, userContext()).first; } it->second.push_back(c); } diff --git a/src/theory/arith/nl/transcendental/transcendental_state.cpp b/src/theory/arith/nl/transcendental/transcendental_state.cpp index 870eddc86..e32f336ac 100644 --- a/src/theory/arith/nl/transcendental/transcendental_state.cpp +++ b/src/theory/arith/nl/transcendental/transcendental_state.cpp @@ -30,16 +30,16 @@ namespace arith { namespace nl { namespace transcendental { -TranscendentalState::TranscendentalState(InferenceManager& im, - NlModel& model, - Env& env) - : d_im(im), d_model(model), d_env(env) +TranscendentalState::TranscendentalState(Env& env, + InferenceManager& im, + NlModel& model) + : EnvObj(env), d_im(im), d_model(model) { d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); - d_zero = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(0)); - d_one = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(1)); - d_neg_one = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(-1)); + d_zero = NodeManager::currentNM()->mkConstReal(Rational(0)); + d_one = NodeManager::currentNM()->mkConstReal(Rational(1)); + d_neg_one = NodeManager::currentNM()->mkConstReal(Rational(-1)); if (d_env.isTheoryProofProducing()) { d_proof.reset(new CDProofSet<CDProof>( @@ -204,21 +204,15 @@ void TranscendentalState::mkPi() if (d_pi.isNull()) { d_pi = nm->mkNullaryOperator(nm->realType(), Kind::PI); - d_pi_2 = Rewriter::rewrite( - nm->mkNode(Kind::MULT, - d_pi, - nm->mkConst(CONST_RATIONAL, Rational(1) / Rational(2)))); - d_pi_neg_2 = Rewriter::rewrite( - nm->mkNode(Kind::MULT, - d_pi, - nm->mkConst(CONST_RATIONAL, Rational(-1) / Rational(2)))); - d_pi_neg = Rewriter::rewrite(nm->mkNode( - Kind::MULT, d_pi, nm->mkConst(CONST_RATIONAL, Rational(-1)))); + d_pi_2 = rewrite(nm->mkNode( + Kind::MULT, d_pi, nm->mkConstReal(Rational(1) / Rational(2)))); + d_pi_neg_2 = rewrite(nm->mkNode( + Kind::MULT, d_pi, nm->mkConstReal(Rational(-1) / Rational(2)))); + d_pi_neg = + rewrite(nm->mkNode(Kind::MULT, d_pi, nm->mkConstReal(Rational(-1)))); // initialize bounds - d_pi_bound[0] = - nm->mkConst(CONST_RATIONAL, Rational(103993) / Rational(33102)); - d_pi_bound[1] = - nm->mkConst(CONST_RATIONAL, Rational(104348) / Rational(33215)); + d_pi_bound[0] = nm->mkConstReal(Rational(103993) / Rational(33102)); + d_pi_bound[1] = nm->mkConstReal(Rational(104348) / Rational(33215)); } } @@ -274,7 +268,7 @@ Node TranscendentalState::mkSecantPlane( { NodeManager* nm = NodeManager::currentNM(); // Figure 3: S_l( x ), S_u( x ) for s = 0,1 - Node rcoeff_n = Rewriter::rewrite(nm->mkNode(Kind::MINUS, lower, upper)); + Node rcoeff_n = rewrite(nm->mkNode(Kind::MINUS, lower, upper)); Assert(rcoeff_n.isConst()); Rational rcoeff = rcoeff_n.getConst<Rational>(); Assert(rcoeff.sgn() != 0); @@ -291,7 +285,7 @@ Node TranscendentalState::mkSecantPlane( Trace("nl-trans") << "\tfrom ( " << lower << " ; " << lval << " ) to ( " << upper << " ; " << uval << " )" << std::endl; Trace("nl-trans") << "\t" << res << std::endl; - Trace("nl-trans") << "\trewritten: " << Rewriter::rewrite(res) << std::endl; + Trace("nl-trans") << "\trewritten: " << rewrite(res) << std::endl; return res; } @@ -331,9 +325,6 @@ NlLemma TranscendentalState::mkSecantLemma(TNode lower, antec_n, nm->mkNode( convexity == Convexity::CONVEX ? Kind::LEQ : Kind::GEQ, tf, splane)); - Trace("nl-trans-lemma") << "*** Secant plane lemma (pre-rewrite) : " << lem - << std::endl; - lem = Rewriter::rewrite(lem); Trace("nl-trans-lemma") << "*** Secant plane lemma : " << lem << std::endl; Assert(d_model.computeAbstractModelValue(lem) == d_false); CDProof* proof = nullptr; @@ -347,44 +338,34 @@ NlLemma TranscendentalState::mkSecantLemma(TNode lower, proof->addStep(lem, PfRule::ARITH_TRANS_EXP_APPROX_ABOVE_POS, {}, - {nm->mkConst<Rational>(CONST_RATIONAL, 2 * actual_d), - tf[0], - lower, - upper}); + {nm->mkConstInt(2 * actual_d), tf[0], lower, upper}); } else { proof->addStep(lem, PfRule::ARITH_TRANS_EXP_APPROX_ABOVE_NEG, {}, - {nm->mkConst<Rational>(CONST_RATIONAL, 2 * actual_d), - tf[0], - lower, - upper}); + {nm->mkConstInt(2 * actual_d), tf[0], lower, upper}); } } else if (tf.getKind() == Kind::SINE) { if (convexity == Convexity::CONCAVE) { - proof->addStep(lem, - PfRule::ARITH_TRANS_SINE_APPROX_BELOW_POS, - {}, - {nm->mkConst<Rational>(CONST_RATIONAL, 2 * actual_d), - tf[0], - lower, - upper, - lapprox, - uapprox + proof->addStep( + lem, + PfRule::ARITH_TRANS_SINE_APPROX_BELOW_POS, + {}, + {nm->mkConstInt(2 * actual_d), tf[0], lower, upper, lapprox, uapprox - }); + }); } else { proof->addStep(lem, PfRule::ARITH_TRANS_SINE_APPROX_ABOVE_NEG, {}, - {nm->mkConst<Rational>(CONST_RATIONAL, 2 * actual_d), + {nm->mkConstInt(2 * actual_d), tf[0], lower, upper, @@ -419,8 +400,8 @@ void TranscendentalState::doSecantLemmas(const std::pair<Node, Node>& bounds, if (lower != center) { // Figure 3 : P(l), P(u), for s = 0 - Node lval = Rewriter::rewrite( - poly_approx.substitute(d_taylor.getTaylorVariable(), lower)); + Node lval = + rewrite(poly_approx.substitute(d_taylor.getTaylorVariable(), lower)); Node splane = mkSecantPlane(tf[0], lower, center, lval, cval); NlLemma nlem = mkSecantLemma( lower, center, lval, cval, csign, convexity, tf, splane, actual_d); @@ -438,8 +419,8 @@ void TranscendentalState::doSecantLemmas(const std::pair<Node, Node>& bounds, if (center != upper) { // Figure 3 : P(l), P(u), for s = 1 - Node uval = Rewriter::rewrite( - poly_approx.substitute(d_taylor.getTaylorVariable(), upper)); + Node uval = + rewrite(poly_approx.substitute(d_taylor.getTaylorVariable(), upper)); Node splane = mkSecantPlane(tf[0], center, upper, cval, uval); NlLemma nlem = mkSecantLemma( center, upper, cval, uval, csign, convexity, tf, splane, actual_d); diff --git a/src/theory/arith/nl/transcendental/transcendental_state.h b/src/theory/arith/nl/transcendental/transcendental_state.h index 77fcf57fb..ede8079a4 100644 --- a/src/theory/arith/nl/transcendental/transcendental_state.h +++ b/src/theory/arith/nl/transcendental/transcendental_state.h @@ -60,9 +60,9 @@ inline std::ostream& operator<<(std::ostream& os, Convexity c) { * This includes common lookups and caches as well as generic utilities for * secant plane lemmas and taylor approximations. */ -struct TranscendentalState +struct TranscendentalState : protected EnvObj { - TranscendentalState(InferenceManager& im, NlModel& model, Env& env); + TranscendentalState(Env& env, InferenceManager& im, NlModel& model); /** * Checks whether proofs are enabled. @@ -168,8 +168,6 @@ struct TranscendentalState InferenceManager& d_im; /** Reference to the non-linear model object */ NlModel& d_model; - /** Reference to the environment */ - Env& d_env; /** Utility to compute taylor approximations */ TaylorGenerator d_taylor; /** |