summaryrefslogtreecommitdiff
path: root/src/theory
diff options
context:
space:
mode:
authorGereon Kremer <nafur42@gmail.com>2021-10-05 13:06:53 -0700
committerGitHub <noreply@github.com>2021-10-05 20:06:53 +0000
commite64a4bc87d2d98e04e8450d4ad9856bce3494c78 (patch)
tree339fd9f21009acdb7190516f62d38e247d042195 /src/theory
parent35de43d011528e34f42363c1d201b10bf254d386 (diff)
First round of refactoring on NlModel (#7255)
This PR performs a first refactoring on the NlModel class. It improves model value computation, comparison and stores the model substitutions in a map (instead of two vectors).
Diffstat (limited to 'src/theory')
-rw-r--r--src/theory/arith/arith_utilities.cpp19
-rw-r--r--src/theory/arith/arith_utilities.h5
-rw-r--r--src/theory/arith/nl/cad_solver.cpp4
-rw-r--r--src/theory/arith/nl/ext/monomial_check.cpp4
-rw-r--r--src/theory/arith/nl/nl_lemma_utils.cpp2
-rw-r--r--src/theory/arith/nl/nl_model.cpp294
-rw-r--r--src/theory/arith/nl/nl_model.h99
-rw-r--r--src/theory/arith/nl/nonlinear_extension.cpp9
-rw-r--r--src/theory/arith/nl/poly_conversion.h4
-rw-r--r--src/theory/arith/nl/transcendental/transcendental_solver.cpp16
10 files changed, 211 insertions, 245 deletions
diff --git a/src/theory/arith/arith_utilities.cpp b/src/theory/arith/arith_utilities.cpp
index 75edc49f5..5645542d0 100644
--- a/src/theory/arith/arith_utilities.cpp
+++ b/src/theory/arith/arith_utilities.cpp
@@ -203,31 +203,26 @@ void printRationalApprox(const char* c, Node cr, unsigned prec)
}
}
-Node arithSubstitute(Node n, std::vector<Node>& vars, std::vector<Node>& subs)
+Node arithSubstitute(Node n, const Subs& sub)
{
- Assert(vars.size() == subs.size());
NodeManager* nm = NodeManager::currentNM();
std::unordered_map<TNode, Node> visited;
- std::unordered_map<TNode, Node>::iterator it;
- std::vector<Node>::iterator itv;
std::vector<TNode> visit;
- TNode cur;
- Kind ck;
visit.push_back(n);
do
{
- cur = visit.back();
+ TNode cur = visit.back();
visit.pop_back();
- it = visited.find(cur);
+ auto it = visited.find(cur);
if (it == visited.end())
{
visited[cur] = Node::null();
- ck = cur.getKind();
- itv = std::find(vars.begin(), vars.end(), cur);
- if (itv != vars.end())
+ Kind ck = cur.getKind();
+ auto s = sub.find(cur);
+ if (s)
{
- visited[cur] = subs[std::distance(vars.begin(), itv)];
+ visited[cur] = *s;
}
else if (cur.getNumChildren() == 0)
{
diff --git a/src/theory/arith/arith_utilities.h b/src/theory/arith/arith_utilities.h
index b842ae58e..0d7f214d7 100644
--- a/src/theory/arith/arith_utilities.h
+++ b/src/theory/arith/arith_utilities.h
@@ -24,6 +24,7 @@
#include "context/cdhashset.h"
#include "expr/node.h"
+#include "expr/subs.h"
#include "theory/arith/arithvar.h"
#include "util/dense_map.h"
#include "util/integer.h"
@@ -313,13 +314,13 @@ void printRationalApprox(const char* c, Node cr, unsigned prec = 5);
/** Arithmetic substitute
*
- * This computes the substitution n { vars -> subs }, but with the caveat
+ * This computes the substitution n { subs }, but with the caveat
* that subterms of n that belong to a theory other than arithmetic are
* not traversed. In other words, terms that belong to other theories are
* treated as atomic variables. For example:
* (5*f(x) + 7*x ){ x -> 3 } returns 5*f(x) + 7*3.
*/
-Node arithSubstitute(Node n, std::vector<Node>& vars, std::vector<Node>& subs);
+Node arithSubstitute(Node n, const Subs& sub);
/** Make the node u >= a ^ a >= l */
Node mkBounded(Node l, Node a, Node u);
diff --git a/src/theory/arith/nl/cad_solver.cpp b/src/theory/arith/nl/cad_solver.cpp
index ebaeb9d61..132cb9795 100644
--- a/src/theory/arith/nl/cad_solver.cpp
+++ b/src/theory/arith/nl/cad_solver.cpp
@@ -180,11 +180,11 @@ bool CadSolver::constructModelIfAvailable(std::vector<Node>& assertions)
Node value = value_to_node(d_CAC.getModel().get(v), d_ranVariable);
if (value.isConst())
{
- d_model.addCheckModelSubstitution(variable, value);
+ d_model.addSubstitution(variable, value);
}
else
{
- d_model.addCheckModelWitness(variable, value);
+ d_model.addWitness(variable, value);
}
Trace("nl-cad") << "-> " << v << " = " << value << std::endl;
}
diff --git a/src/theory/arith/nl/ext/monomial_check.cpp b/src/theory/arith/nl/ext/monomial_check.cpp
index 330cd57a3..b077dcfd0 100644
--- a/src/theory/arith/nl/ext/monomial_check.cpp
+++ b/src/theory/arith/nl/ext/monomial_check.cpp
@@ -402,7 +402,7 @@ bool MonomialCheck::compareMonomial(
if (a_index == vla.size() && b_index == vlb.size())
{
// finished, compare absolute value of abstract model values
- int modelStatus = d_data->d_model.compare(oa, ob, false, true) * -2;
+ int modelStatus = d_data->d_model.compare(oa, ob, false, true) * 2;
Trace("nl-ext-comp") << "...finished comparison with " << oa << " <"
<< status << "> " << ob
<< ", model status = " << modelStatus << std::endl;
@@ -677,7 +677,7 @@ void MonomialCheck::assignOrderIds(std::vector<Node>& vars,
{
Node vv = d_data->d_model.computeModelValue(
d_order_points[order_index], isConcrete);
- if (d_data->d_model.compareValue(v, vv, isAbsolute) <= 0)
+ if (d_data->d_model.compareValue(v, vv, isAbsolute) >= 0)
{
counter++;
Trace("nl-ext-mvo") << "O[" << d_order_points[order_index]
diff --git a/src/theory/arith/nl/nl_lemma_utils.cpp b/src/theory/arith/nl/nl_lemma_utils.cpp
index 3e2ebe87e..18e296da7 100644
--- a/src/theory/arith/nl/nl_lemma_utils.cpp
+++ b/src/theory/arith/nl/nl_lemma_utils.cpp
@@ -45,7 +45,7 @@ bool SortNlModel::operator()(Node i, Node j)
{
return i < j;
}
- return d_reverse_order ? cv < 0 : cv > 0;
+ return d_reverse_order ? cv > 0 : cv < 0;
}
bool SortNonlinearDegree::operator()(Node i, Node j)
diff --git a/src/theory/arith/nl/nl_model.cpp b/src/theory/arith/nl/nl_model.cpp
index ca75a1a06..427d203ea 100644
--- a/src/theory/arith/nl/nl_model.cpp
+++ b/src/theory/arith/nl/nl_model.cpp
@@ -43,18 +43,12 @@ NlModel::NlModel() : d_used_approx(false)
NlModel::~NlModel() {}
-void NlModel::reset(TheoryModel* m, std::map<Node, Node>& arithModel)
+void NlModel::reset(TheoryModel* m, const std::map<Node, Node>& arithModel)
{
d_model = m;
- d_mv[0].clear();
- d_mv[1].clear();
- d_arithVal.clear();
- // process arithModel
- std::map<Node, Node>::iterator it;
- for (const std::pair<const Node, Node>& m2 : arithModel)
- {
- d_arithVal[m2.first] = m2.second;
- }
+ d_concreteModelCache.clear();
+ d_abstractModelCache.clear();
+ d_arithVal = arithModel;
}
void NlModel::resetCheck()
@@ -63,46 +57,42 @@ void NlModel::resetCheck()
d_check_model_solved.clear();
d_check_model_bounds.clear();
d_check_model_witnesses.clear();
- d_check_model_vars.clear();
- d_check_model_subs.clear();
+ d_substitutions.clear();
}
-Node NlModel::computeConcreteModelValue(Node n)
+Node NlModel::computeConcreteModelValue(TNode n)
{
return computeModelValue(n, true);
}
-Node NlModel::computeAbstractModelValue(Node n)
+Node NlModel::computeAbstractModelValue(TNode n)
{
return computeModelValue(n, false);
}
-Node NlModel::computeModelValue(Node n, bool isConcrete)
+Node NlModel::computeModelValue(TNode n, bool isConcrete)
{
- unsigned index = isConcrete ? 0 : 1;
- std::map<Node, Node>::iterator it = d_mv[index].find(n);
- if (it != d_mv[index].end())
+ auto& cache = isConcrete ? d_concreteModelCache : d_abstractModelCache;
+ if (auto it = cache.find(n); it != cache.end())
{
return it->second;
}
- Trace("nl-ext-mv-debug") << "computeModelValue " << n << ", index=" << index
- << std::endl;
+ Trace("nl-ext-mv-debug") << "computeModelValue " << n
+ << ", isConcrete=" << isConcrete << std::endl;
Node ret;
- Kind nk = n.getKind();
if (n.isConst())
{
ret = n;
}
- else if (!isConcrete && hasTerm(n))
+ else if (!isConcrete && hasLinearModelValue(n, ret))
{
// use model value for abstraction
- ret = getRepresentative(n);
}
else if (n.getNumChildren() == 0)
{
// we are interested in the exact value of PI, which cannot be computed.
// hence, we return PI itself when asked for the concrete value.
- if (nk == PI)
+ if (n.getKind() == PI)
{
ret = n;
}
@@ -114,7 +104,7 @@ Node NlModel::computeModelValue(Node n, bool isConcrete)
else
{
// otherwise, compute true value
- TheoryId ctid = theory::kindToTheoryId(nk);
+ TheoryId ctid = theory::kindToTheoryId(n.getKind());
if (ctid != THEORY_ARITH && ctid != THEORY_BOOL && ctid != THEORY_BUILTIN)
{
// we directly look up terms not belonging to arithmetic
@@ -125,65 +115,28 @@ Node NlModel::computeModelValue(Node n, bool isConcrete)
std::vector<Node> children;
if (n.getMetaKind() == metakind::PARAMETERIZED)
{
- children.push_back(n.getOperator());
+ children.emplace_back(n.getOperator());
}
- for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
+ for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
{
- Node mc = computeModelValue(n[i], isConcrete);
- children.push_back(mc);
+ children.emplace_back(computeModelValue(n[i], isConcrete));
}
- ret = NodeManager::currentNM()->mkNode(nk, children);
+ ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
ret = Rewriter::rewrite(ret);
}
}
- Trace("nl-ext-mv-debug") << "computed " << (index == 0 ? "M" : "M_A") << "["
+ Trace("nl-ext-mv-debug") << "computed " << (isConcrete ? "M" : "M_A") << "["
<< n << "] = " << ret << std::endl;
- d_mv[index][n] = ret;
+ cache[n] = ret;
return ret;
}
-bool NlModel::hasTerm(Node n) const
+int NlModel::compare(TNode i, TNode j, bool isConcrete, bool isAbsolute)
{
- return d_arithVal.find(n) != d_arithVal.end();
-}
-
-Node NlModel::getRepresentative(Node n) const
-{
- if (n.isConst())
- {
- return n;
- }
- std::map<Node, Node>::const_iterator it = d_arithVal.find(n);
- if (it != d_arithVal.end())
- {
- AlwaysAssert(it->second.isConst());
- return it->second;
- }
- return d_model->getRepresentative(n);
-}
-
-Node NlModel::getValueInternal(Node n)
-{
- if (n.isConst())
+ if (i == j)
{
- return n;
+ return 0;
}
- std::map<Node, Node>::const_iterator it = d_arithVal.find(n);
- if (it != d_arithVal.end())
- {
- AlwaysAssert(it->second.isConst());
- return it->second;
- }
- // It is unconstrained in the model, return 0. We additionally add it
- // to mapping from the linear solver. This ensures that if the nonlinear
- // solver assumes that n = 0, then this assumption is recorded in the overall
- // model.
- d_arithVal[n] = d_zero;
- return d_zero;
-}
-
-int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute)
-{
Node ci = computeModelValue(i, isConcrete);
Node cj = computeModelValue(j, isConcrete);
if (ci.isConst())
@@ -197,27 +150,24 @@ int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute)
return cj.isConst() ? -1 : 0;
}
-int NlModel::compareValue(Node i, Node j, bool isAbsolute) const
+int NlModel::compareValue(TNode i, TNode j, bool isAbsolute) const
{
Assert(i.isConst() && j.isConst());
- int ret;
if (i == j)
{
- ret = 0;
+ return 0;
}
- else if (!isAbsolute)
+ if (!isAbsolute)
{
- ret = i.getConst<Rational>() < j.getConst<Rational>() ? 1 : -1;
+ return i.getConst<Rational>() < j.getConst<Rational>() ? -1 : 1;
}
- else
+ Rational iabs = i.getConst<Rational>().abs();
+ Rational jabs = j.getConst<Rational>().abs();
+ if (iabs == jabs)
{
- ret = (i.getConst<Rational>().abs() == j.getConst<Rational>().abs()
- ? 0
- : (i.getConst<Rational>().abs() < j.getConst<Rational>().abs()
- ? 1
- : -1));
+ return 0;
}
- return ret;
+ return iabs < jabs ? -1 : 1;
}
bool NlModel::checkModel(const std::vector<Node>& assertions,
@@ -262,7 +212,7 @@ bool NlModel::checkModel(const std::vector<Node>& assertions,
&& !isTranscendentalKind(k))
{
// if we have not set an approximate bound for it
- if (!hasCheckModelAssignment(cur))
+ if (!hasAssignment(cur))
{
// set its exact model value in the substitution
Node curv = computeConcreteModelValue(cur);
@@ -273,7 +223,7 @@ bool NlModel::checkModel(const std::vector<Node>& assertions,
printRationalApprox("nl-ext-cm", curv);
Trace("nl-ext-cm") << std::endl;
}
- bool ret = addCheckModelSubstitution(cur, curv);
+ bool ret = addSubstitution(cur, curv);
AlwaysAssert(ret);
}
}
@@ -294,10 +244,9 @@ bool NlModel::checkModel(const std::vector<Node>& assertions,
{
Node av = a;
// apply the substitution to a
- if (!d_check_model_vars.empty())
+ if (!d_substitutions.empty())
{
- av = arithSubstitute(av, d_check_model_vars, d_check_model_subs);
- av = Rewriter::rewrite(av);
+ av = Rewriter::rewrite(arithSubstitute(av, d_substitutions));
}
// simple check literal
if (!simpleCheckModelLit(av))
@@ -321,14 +270,13 @@ bool NlModel::checkModel(const std::vector<Node>& assertions,
return true;
}
-bool NlModel::addCheckModelSubstitution(TNode v, TNode s)
+bool NlModel::addSubstitution(TNode v, TNode s)
{
// should not substitute the same variable twice
Trace("nl-ext-model") << "* check model substitution : " << v << " -> " << s
<< std::endl;
// should not set exact bound more than once
- if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
- != d_check_model_vars.end())
+ if (d_substitutions.contains(v))
{
Trace("nl-ext-model") << "...ERROR: already has value." << std::endl;
// this should never happen since substitutions should be applied eagerly
@@ -352,37 +300,31 @@ bool NlModel::addCheckModelSubstitution(TNode v, TNode s)
Assert(d_check_model_witnesses.find(v) == d_check_model_witnesses.end())
<< "We tried to add a substitution where we already had a witness term."
<< std::endl;
- std::vector<Node> varsTmp;
- varsTmp.push_back(v);
- std::vector<Node> subsTmp;
- subsTmp.push_back(s);
- for (unsigned i = 0, size = d_check_model_subs.size(); i < size; i++)
+ Subs tmp;
+ tmp.add(v, s);
+ for (auto& sub : d_substitutions.d_subs)
{
- Node ms = d_check_model_subs[i];
- Node mss = arithSubstitute(ms, varsTmp, subsTmp);
- if (mss != ms)
+ Node ms = arithSubstitute(sub, tmp);
+ if (ms != sub)
{
- mss = Rewriter::rewrite(mss);
+ sub = Rewriter::rewrite(ms);
}
- d_check_model_subs[i] = mss;
}
- d_check_model_vars.push_back(v);
- d_check_model_subs.push_back(s);
+ d_substitutions.add(v, s);
return true;
}
-bool NlModel::addCheckModelBound(TNode v, TNode l, TNode u)
+bool NlModel::addBound(TNode v, TNode l, TNode u)
{
Trace("nl-ext-model") << "* check model bound : " << v << " -> [" << l << " "
<< u << "]" << std::endl;
if (l == u)
{
// bound is exact, can add as substitution
- return addCheckModelSubstitution(v, l);
+ return addSubstitution(v, l);
}
// should not set a bound for a value that is exact
- if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
- != d_check_model_vars.end())
+ if (d_substitutions.contains(v))
{
Trace("nl-ext-model")
<< "...ERROR: setting bound for variable that already has exact value."
@@ -405,13 +347,12 @@ bool NlModel::addCheckModelBound(TNode v, TNode l, TNode u)
return true;
}
-bool NlModel::addCheckModelWitness(TNode v, TNode w)
+bool NlModel::addWitness(TNode v, TNode w)
{
Trace("nl-ext-model") << "* check model witness : " << v << " -> " << w
<< std::endl;
// should not set a witness for a value that is already set
- if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
- != d_check_model_vars.end())
+ if (d_substitutions.contains(v))
{
Trace("nl-ext-model") << "...ERROR: setting witness for variable that "
"already has a constant value."
@@ -423,20 +364,6 @@ bool NlModel::addCheckModelWitness(TNode v, TNode w)
return true;
}
-bool NlModel::hasCheckModelAssignment(Node v) const
-{
- if (d_check_model_bounds.find(v) != d_check_model_bounds.end())
- {
- return true;
- }
- if (d_check_model_witnesses.find(v) != d_check_model_witnesses.end())
- {
- return true;
- }
- return std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
- != d_check_model_vars.end();
-}
-
void NlModel::setUsedApproximate() { d_used_approx = true; }
bool NlModel::usedApproximate() const { return d_used_approx; }
@@ -446,9 +373,9 @@ bool NlModel::solveEqualitySimple(Node eq,
std::vector<NlLemma>& lemmas)
{
Node seq = eq;
- if (!d_check_model_vars.empty())
+ if (!d_substitutions.empty())
{
- seq = arithSubstitute(eq, d_check_model_vars, d_check_model_subs);
+ seq = arithSubstitute(eq, d_substitutions);
seq = Rewriter::rewrite(seq);
if (seq.isConst())
{
@@ -545,7 +472,7 @@ bool NlModel::solveEqualitySimple(Node eq,
{
Trace("nl-ext-cm-debug") << "check subs var : " << uv << std::endl;
// cannot already have a bound
- if (uv.isVar() && !hasCheckModelAssignment(uv))
+ if (uv.isVar() && !hasAssignment(uv))
{
Node slv;
Node veqc;
@@ -560,7 +487,7 @@ bool NlModel::solveEqualitySimple(Node eq,
{
Trace("nl-ext-cm")
<< "check-model-subs : " << uv << " -> " << slv << std::endl;
- bool ret = addCheckModelSubstitution(uv, slv);
+ bool ret = addSubstitution(uv, slv);
if (ret)
{
Trace("nl-ext-cms") << "...success, model substitution " << uv
@@ -577,7 +504,7 @@ bool NlModel::solveEqualitySimple(Node eq,
{
Trace("nl-ext-cm-debug") << "check set var : " << uvf << std::endl;
// cannot already have a bound
- if (uvf.isVar() && !hasCheckModelAssignment(uvf))
+ if (uvf.isVar() && !hasAssignment(uvf))
{
Node uvfv = computeConcreteModelValue(uvf);
if (Trace.isOn("nl-ext-cm"))
@@ -586,7 +513,7 @@ bool NlModel::solveEqualitySimple(Node eq,
printRationalApprox("nl-ext-cm", uvfv);
Trace("nl-ext-cm") << std::endl;
}
- bool ret = addCheckModelSubstitution(uvf, uvfv);
+ bool ret = addSubstitution(uvf, uvfv);
// recurse
return ret ? solveEqualitySimple(eq, d, lemmas) : false;
}
@@ -618,7 +545,7 @@ bool NlModel::solveEqualitySimple(Node eq,
printRationalApprox("nl-ext-cm", val);
Trace("nl-ext-cm") << std::endl;
}
- bool ret = addCheckModelSubstitution(var, val);
+ bool ret = addSubstitution(var, val);
if (ret)
{
Trace("nl-ext-cms") << "...success, solved linear." << std::endl;
@@ -647,7 +574,7 @@ bool NlModel::solveEqualitySimple(Node eq,
Trace("nl-ext-cms") << "...fail due to negative discriminant." << std::endl;
return false;
}
- if (hasCheckModelAssignment(var))
+ if (hasAssignment(var))
{
Trace("nl-ext-cms") << "...fail due to bounds on variable to solve for."
<< std::endl;
@@ -730,8 +657,7 @@ bool NlModel::solveEqualitySimple(Node eq,
printRationalApprox("nl-ext-cm", bounds[r_use_index][1]);
Trace("nl-ext-cm") << std::endl;
}
- bool ret =
- addCheckModelBound(var, bounds[r_use_index][0], bounds[r_use_index][1]);
+ bool ret = addBound(var, bounds[r_use_index][0], bounds[r_use_index][1]);
if (ret)
{
d_check_model_solved[eq] = var;
@@ -829,8 +755,7 @@ bool NlModel::simpleCheckModelLit(Node lit)
? vs_invalid[0]
: nm->mkNode(PLUS, vs_invalid));
// substitution to try
- std::vector<Node> qvars;
- std::vector<Node> qsubs;
+ Subs qsub;
for (const Node& v : vs)
{
// is it a valid variable?
@@ -882,7 +807,7 @@ bool NlModel::simpleCheckModelLit(Node lit)
Assert(boundn[0].getConst<Rational>()
<= boundn[1].getConst<Rational>());
Node s;
- qvars.push_back(v);
+ qsub.add(v, Node());
if (cmp[0] != cmp[1])
{
Assert(!cmp[0] && cmp[1]);
@@ -899,10 +824,9 @@ bool NlModel::simpleCheckModelLit(Node lit)
Node tcmpn[2];
for (unsigned r = 0; r < 2; r++)
{
- qsubs.push_back(boundn[r]);
- Node ts = arithSubstitute(t, qvars, qsubs);
+ qsub.d_subs.back() = boundn[r];
+ Node ts = arithSubstitute(t, qsub);
tcmpn[r] = Rewriter::rewrite(ts);
- qsubs.pop_back();
}
Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]);
Trace("nl-ext-cms-debug")
@@ -932,16 +856,15 @@ bool NlModel::simpleCheckModelLit(Node lit)
s = boundn[bindex_use];
}
Assert(!s.isNull());
- qsubs.push_back(s);
+ qsub.d_subs.back() = s;
Trace("nl-ext-cms") << "* set bound based on quadratic : " << v
<< " -> " << s << std::endl;
}
}
}
- if (!qvars.empty())
+ if (!qsub.empty())
{
- Assert(qvars.size() == qsubs.size());
- Node slit = arithSubstitute(lit, qvars, qsubs);
+ Node slit = arithSubstitute(lit, qsub);
slit = Rewriter::rewrite(slit);
return simpleCheckModelLit(slit);
}
@@ -1242,21 +1165,26 @@ void NlModel::printModelValue(const char* c, Node n, unsigned prec) const
if (Trace.isOn(c))
{
Trace(c) << " " << n << " -> ";
- for (int i = 1; i >= 0; --i)
+ const Node& aval = d_abstractModelCache.at(n);
+ if (aval.isConst())
{
- std::map<Node, Node>::const_iterator it = d_mv[i].find(n);
- Assert(it != d_mv[i].end());
- if (it->second.isConst())
- {
- printRationalApprox(c, it->second, prec);
- }
- else
- {
- Trace(c) << "?";
- }
- Trace(c) << (i == 1 ? " [actual: " : " ]");
+ printRationalApprox(c, aval, prec);
+ }
+ else
+ {
+ Trace(c) << "?";
+ }
+ Trace(c) << " [actual: ";
+ const Node& cval = d_concreteModelCache.at(n);
+ if (cval.isConst())
+ {
+ printRationalApprox(c, cval, prec);
+ }
+ else
+ {
+ Trace(c) << "?";
}
- Trace(c) << std::endl;
+ Trace(c) << " ]" << std::endl;
}
}
@@ -1316,13 +1244,12 @@ void NlModel::getModelValueRepair(
// special kind approximation of the form (witness x. x = exact_value).
// Notice that the above term gets rewritten such that the choice function
// is eliminated.
- for (size_t i = 0, num = d_check_model_vars.size(); i < num; i++)
+ for (size_t i = 0; i < d_substitutions.size(); ++i)
{
- Node v = d_check_model_vars[i];
- Node s = d_check_model_subs[i];
// overwrite
- arithModel[v] = s;
- Trace("nl-model") << v << " solved is " << s << std::endl;
+ arithModel[d_substitutions.d_vars[i]] = d_substitutions.d_subs[i];
+ Trace("nl-model") << d_substitutions.d_vars[i] << " solved is "
+ << d_substitutions.d_subs[i] << std::endl;
}
// multiplication terms should not be given values; their values are
@@ -1341,6 +1268,49 @@ void NlModel::getModelValueRepair(
}
}
+Node NlModel::getValueInternal(TNode n)
+{
+ if (n.isConst())
+ {
+ return n;
+ }
+ if (auto it = d_arithVal.find(n); it != d_arithVal.end())
+ {
+ AlwaysAssert(it->second.isConst());
+ return it->second;
+ }
+ // It is unconstrained in the model, return 0. We additionally add it
+ // to mapping from the linear solver. This ensures that if the nonlinear
+ // solver assumes that n = 0, then this assumption is recorded in the overall
+ // model.
+ d_arithVal[n] = d_zero;
+ return d_zero;
+}
+
+bool NlModel::hasAssignment(Node v) const
+{
+ if (d_check_model_bounds.find(v) != d_check_model_bounds.end())
+ {
+ return true;
+ }
+ if (d_check_model_witnesses.find(v) != d_check_model_witnesses.end())
+ {
+ return true;
+ }
+ return (d_substitutions.contains(v));
+}
+
+bool NlModel::hasLinearModelValue(TNode v, Node& val) const
+{
+ auto it = d_arithVal.find(v);
+ if (it != d_arithVal.end())
+ {
+ val = it->second;
+ return true;
+ }
+ return false;
+}
+
} // namespace nl
} // namespace arith
} // namespace theory
diff --git a/src/theory/arith/nl/nl_model.h b/src/theory/arith/nl/nl_model.h
index 526a93934..b3b841eab 100644
--- a/src/theory/arith/nl/nl_model.h
+++ b/src/theory/arith/nl/nl_model.h
@@ -22,6 +22,7 @@
#include "expr/kind.h"
#include "expr/node.h"
+#include "expr/subs.h"
namespace cvc5 {
@@ -59,7 +60,7 @@ class NlModel
* where m is the model of the theory of arithmetic. This method resets the
* cache of computed model values.
*/
- void reset(TheoryModel* m, std::map<Node, Node>& arithModel);
+ void reset(TheoryModel* m, const std::map<Node, Node>& arithModel);
/**
* This method is called when the non-linear arithmetic solver restarts
* its computation of lemmas and models during a last call effort check.
@@ -87,9 +88,9 @@ class NlModel
* whereas:
* computeModelValue( a*b, false ) = 5
*/
- Node computeConcreteModelValue(Node n);
- Node computeAbstractModelValue(Node n);
- Node computeModelValue(Node n, bool isConcrete);
+ Node computeConcreteModelValue(TNode n);
+ Node computeAbstractModelValue(TNode n);
+ Node computeModelValue(TNode n, bool isConcrete);
/**
* Compare arithmetic terms i and j based an ordering.
@@ -101,10 +102,10 @@ class NlModel
* otherwise, we consider their abstract model values. For definitions of
* concrete vs abstract model values, see NlModel::computeModelValue.
*
- * If isAbsolute is true, we compare the absolute value of thee above
+ * If isAbsolute is true, we compare the absolute value of the above
* values.
*/
- int compare(Node i, Node j, bool isConcrete, bool isAbsolute);
+ int compare(TNode i, TNode j, bool isConcrete, bool isAbsolute);
/**
* Compare arithmetic terms i and j based an ordering.
*
@@ -113,38 +114,31 @@ class NlModel
*
* If isAbsolute is true, we compare the absolute value of i and j
*/
- int compareValue(Node i, Node j, bool isAbsolute) const;
+ int compareValue(TNode i, TNode j, bool isAbsolute) const;
//------------------------------ recording model substitutions and bounds
/**
* Adds the model substitution v -> s. This applies the substitution
- * { v -> s } to each term in d_check_model_subs and adds v,s to
- * d_check_model_vars and d_check_model_subs respectively.
+ * { v -> s } to each term in d_substitutions and then adds v,s to
+ * d_substitutions.
* If this method returns false, then the substitution v -> s is inconsistent
* with the current substitution and bounds.
*/
- bool addCheckModelSubstitution(TNode v, TNode s);
+ bool addSubstitution(TNode v, TNode s);
/**
* Adds the bound x -> < l, u > to the map above, and records the
* approximation ( x, l <= x <= u ) in the model. This method returns false
* if the bound is inconsistent with the current model substitution or
* bounds.
*/
- bool addCheckModelBound(TNode v, TNode l, TNode u);
+ bool addBound(TNode v, TNode l, TNode u);
/**
* Adds a model witness v -> w to the underlying theory model.
* The witness should only contain a single variable v and evaluate to true
* for exactly one value of v. The variable v is then (implicitly,
* declaratively) assigned to this single value that satisfies the witness w.
*/
- bool addCheckModelWitness(TNode v, TNode w);
- /**
- * Have we assigned v in the current checkModel(...) call?
- *
- * This method returns true if variable v is in the domain of
- * d_check_model_bounds or if it occurs in d_check_model_vars.
- */
- bool hasCheckModelAssignment(Node v) const;
+ bool addWitness(TNode v, TNode w);
/**
* Checks the current model based on solving for equalities, and using error
* bounds on the Taylor approximation.
@@ -198,21 +192,53 @@ class NlModel
bool witnessToValue);
private:
+ /** Cache for concrete model values */
+ std::map<Node, Node> d_concreteModelCache;
+ /** Cache for abstract model values */
+ std::map<Node, Node> d_abstractModelCache;
+
/** The current model */
TheoryModel* d_model;
+
+ /**
+ * The values that the arithmetic theory solver assigned in the model. This
+ * corresponds to the set of equalities that linear solver (via TheoryArith)
+ * is currently sending to TheoryModel during collectModelValues, plus
+ * additional entries x -> 0 for variables that were unassigned by the linear
+ * solver.
+ */
+ std::map<Node, Node> d_arithVal;
+
+ /**
+ * A substitution from variables that appear in assertions to a solved form
+ * term.
+ */
+ Subs d_substitutions;
+
/** Get the model value of n from the model object above */
- Node getValueInternal(Node n);
- /** Does the equality engine of the model have term n? */
- bool hasTerm(Node n) const;
- /** Get the representative of n in the model */
- Node getRepresentative(Node n) const;
+ Node getValueInternal(TNode n);
+
+ /**
+ * Have we assigned v in the current checkModel(...) call?
+ *
+ * This method returns true if variable v is in the domain of
+ * d_check_model_bounds or if it occurs in d_substitutions.
+ */
+ bool hasAssignment(Node v) const;
+
+ /**
+ * Checks whether we have a linear model value for v, i.e. whether v is
+ * contained in d_arithVal. If so, we also store the value that v is mapped
+ * to in val.
+ */
+ bool hasLinearModelValue(TNode v, Node& val) const;
//---------------------------check model
/**
* This method is used during checkModel(...). It takes as input an
* equality eq. If it returns true, then eq is correct-by-construction based
* on the information stored in our model representation (see
- * d_check_model_vars, d_check_model_subs, d_check_model_bounds), and eq
+ * d_substitutions, d_check_model_bounds), and eq
* is added to d_check_model_solved. The equality eq may involve any
* number of variables, and monomials of arbitrary degree. If this method
* returns false, then we did not show that the equality was true in the
@@ -268,29 +294,6 @@ class NlModel
Node d_false;
Node d_null;
/**
- * The values that the arithmetic theory solver assigned in the model. This
- * corresponds to the set of equalities that linear solver (via TheoryArith)
- * is currently sending to TheoryModel during collectModelValues, plus
- * additional entries x -> 0 for variables that were unassigned by the linear
- * solver.
- */
- std::map<Node, Node> d_arithVal;
- /**
- * cache of model values
- *
- * Stores the the concrete/abstract model values. This is a cache of the
- * computeModelValue method.
- */
- std::map<Node, Node> d_mv[2];
- /**
- * A substitution from variables that appear in assertions to a solved form
- * term. These vectors are ordered in the form:
- * x_1 -> t_1 ... x_n -> t_n
- * where x_i is not in the free variables of t_j for j>=i.
- */
- std::vector<Node> d_check_model_vars;
- std::vector<Node> d_check_model_subs;
- /**
* lower and upper bounds for check model
*
* For each term t in the domain of this map, if this stores the pair
diff --git a/src/theory/arith/nl/nonlinear_extension.cpp b/src/theory/arith/nl/nonlinear_extension.cpp
index f80717b57..207907fcc 100644
--- a/src/theory/arith/nl/nonlinear_extension.cpp
+++ b/src/theory/arith/nl/nonlinear_extension.cpp
@@ -96,17 +96,16 @@ void NonlinearExtension::processSideEffect(const NlLemma& se)
void NonlinearExtension::computeRelevantAssertions(
const std::vector<Node>& assertions, std::vector<Node>& keep)
{
- Trace("nl-ext-rlv") << "Compute relevant assertions..." << std::endl;
- Valuation v = d_containing.getValuation();
+ const Valuation& v = d_containing.getValuation();
for (const Node& a : assertions)
{
if (v.isRelevant(a))
{
- keep.push_back(a);
+ keep.emplace_back(a);
}
}
- Trace("nl-ext-rlv") << "...keep " << keep.size() << "/" << assertions.size()
- << " assertions" << std::endl;
+ Trace("nl-ext-rlv") << "...relevant assertions: " << keep.size() << "/"
+ << assertions.size() << std::endl;
}
void NonlinearExtension::getAssertions(std::vector<Node>& assertions)
diff --git a/src/theory/arith/nl/poly_conversion.h b/src/theory/arith/nl/poly_conversion.h
index db64320d5..dddde3c0f 100644
--- a/src/theory/arith/nl/poly_conversion.h
+++ b/src/theory/arith/nl/poly_conversion.h
@@ -98,8 +98,8 @@ std::pair<poly::Polynomial, poly::SignCondition> as_poly_constraint(
/**
* Transforms a real algebraic number to a node suitable for putting it into a
* model. The resulting node can be either a constant (suitable for
- * addCheckModelSubstitution) or a witness term (suitable for
- * addCheckModelWitness).
+ * addSubstitution) or a witness term (suitable for
+ * addWitness).
*/
Node ran_to_node(const RealAlgebraicNumber& ran, const Node& ran_variable);
diff --git a/src/theory/arith/nl/transcendental/transcendental_solver.cpp b/src/theory/arith/nl/transcendental/transcendental_solver.cpp
index c7bb14b3f..978823a22 100644
--- a/src/theory/arith/nl/transcendental/transcendental_solver.cpp
+++ b/src/theory/arith/nl/transcendental/transcendental_solver.cpp
@@ -83,12 +83,10 @@ void TranscendentalSolver::initLastCall(const std::vector<Node>& xts)
bool TranscendentalSolver::preprocessAssertionsCheckModel(
std::vector<Node>& assertions)
{
- std::vector<Node> pvars;
- std::vector<Node> psubs;
- for (const std::pair<const Node, Node>& tb : d_tstate.d_trMaster)
+ Subs subs;
+ for (const auto& sub : d_tstate.d_trMaster)
{
- pvars.push_back(tb.first);
- psubs.push_back(tb.second);
+ subs.add(sub.first, sub.second);
}
// initialize representation of assertions
@@ -97,9 +95,9 @@ bool TranscendentalSolver::preprocessAssertionsCheckModel(
{
Node pa = a;
- if (!pvars.empty())
+ if (!subs.empty())
{
- pa = arithSubstitute(pa, pvars, psubs);
+ pa = arithSubstitute(pa, subs);
pa = Rewriter::rewrite(pa);
}
if (!pa.isConst() || !pa.getConst<bool>())
@@ -145,8 +143,8 @@ bool TranscendentalSolver::preprocessAssertionsCheckModel(
Trace("nl-ext-cm")
<< "...bound for " << stf << " : [" << bounds.first << ", "
<< bounds.second << "]" << std::endl;
- success = d_tstate.d_model.addCheckModelBound(
- stf, bounds.first, bounds.second);
+ success =
+ d_tstate.d_model.addBound(stf, bounds.first, bounds.second);
}
}
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback