diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2019-12-11 11:58:53 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-12-11 11:58:53 -0600 |
commit | 23eb6c0ab05b6607c14ee33b5c0101381aa0bc41 (patch) | |
tree | ef91882b2bf83f66daa324428b8449bea146020a /src | |
parent | b12f67c710d359cd57d09dbff67f13bf26e10834 (diff) |
Do not substitute beneath arithmetic terms in the non-linear solver (#3324)
Diffstat (limited to 'src')
-rw-r--r-- | src/theory/arith/arith_utilities.cpp | 78 | ||||
-rw-r--r-- | src/theory/arith/arith_utilities.h | 10 | ||||
-rw-r--r-- | src/theory/arith/nl_model.cpp | 19 | ||||
-rw-r--r-- | src/theory/arith/nonlinear_extension.cpp | 3 |
4 files changed, 97 insertions, 13 deletions
diff --git a/src/theory/arith/arith_utilities.cpp b/src/theory/arith/arith_utilities.cpp index 3d3078d99..65aaceb80 100644 --- a/src/theory/arith/arith_utilities.cpp +++ b/src/theory/arith/arith_utilities.cpp @@ -191,6 +191,84 @@ void printRationalApprox(const char* c, Node cr, unsigned prec) } } +Node arithSubstitute(Node n, std::vector<Node>& vars, std::vector<Node>& subs) +{ + Assert(vars.size() == subs.size()); + NodeManager* nm = NodeManager::currentNM(); + std::unordered_map<TNode, Node, TNodeHashFunction> visited; + std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it; + std::vector<Node>::iterator itv; + std::vector<TNode> visit; + TNode cur; + Kind ck; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + + if (it == visited.end()) + { + visited[cur] = Node::null(); + ck = cur.getKind(); + itv = std::find(vars.begin(), vars.end(), cur); + if (itv != vars.end()) + { + visited[cur] = subs[std::distance(vars.begin(), itv)]; + } + else if (cur.getNumChildren() == 0) + { + visited[cur] = cur; + } + else + { + TheoryId ctid = theory::kindToTheoryId(ck); + if (ctid != THEORY_ARITH && ctid != THEORY_BOOL + && ctid != THEORY_BUILTIN) + { + // do not traverse beneath applications that belong to another theory + visited[cur] = cur; + } + else + { + visit.push_back(cur); + for (const Node& cn : cur) + { + visit.push_back(cn); + } + } + } + } + else if (it->second.isNull()) + { + Node ret = cur; + bool childChanged = false; + std::vector<Node> children; + if (cur.getMetaKind() == kind::metakind::PARAMETERIZED) + { + children.push_back(cur.getOperator()); + } + for (const Node& cn : cur) + { + it = visited.find(cn); + Assert(it != visited.end()); + Assert(!it->second.isNull()); + childChanged = childChanged || cn != it->second; + children.push_back(it->second); + } + if (childChanged) + { + ret = nm->mkNode(cur.getKind(), children); + } + visited[cur] = ret; + } + } while (!visit.empty()); + Assert(visited.find(n) != visited.end()); + Assert(!visited.find(n)->second.isNull()); + return visited[n]; +} + } // namespace arith } // namespace theory } // namespace CVC4 diff --git a/src/theory/arith/arith_utilities.h b/src/theory/arith/arith_utilities.h index d737fefeb..f87a908b4 100644 --- a/src/theory/arith/arith_utilities.h +++ b/src/theory/arith/arith_utilities.h @@ -325,6 +325,16 @@ Node getApproximateConstant(Node c, bool isLower, unsigned prec); /** print rational approximation of cr with precision prec on trace c */ void printRationalApprox(const char* c, Node cr, unsigned prec = 5); +/** Arithmetic substitute + * + * This computes the substitution n { vars -> subs }, but with the caveat + * that subterms of n that belong to a theory other than arithmetic are + * not traversed. In other words, terms that belong to other theories are + * treated as atomic variables. For example: + * (5*f(x) + 7*x ){ x -> 3 } returns 5*f(x) + 7*3. + */ +Node arithSubstitute(Node n, std::vector<Node>& vars, std::vector<Node>& subs); + }/* CVC4::theory::arith namespace */ }/* CVC4::theory namespace */ }/* CVC4 namespace */ diff --git a/src/theory/arith/nl_model.cpp b/src/theory/arith/nl_model.cpp index fe756e5f7..3274867bb 100644 --- a/src/theory/arith/nl_model.cpp +++ b/src/theory/arith/nl_model.cpp @@ -284,10 +284,7 @@ bool NlModel::checkModel(const std::vector<Node>& assertions, // apply the substitution to a if (!d_check_model_vars.empty()) { - av = av.substitute(d_check_model_vars.begin(), - d_check_model_vars.end(), - d_check_model_subs.begin(), - d_check_model_subs.end()); + av = arithSubstitute(av, d_check_model_vars, d_check_model_subs); av = Rewriter::rewrite(av); } // simple check literal @@ -360,10 +357,14 @@ bool NlModel::addCheckModelSubstitution(TNode v, TNode s) return false; } } + std::vector<Node> varsTmp; + varsTmp.push_back(v); + std::vector<Node> subsTmp; + subsTmp.push_back(s); for (unsigned i = 0, size = d_check_model_subs.size(); i < size; i++) { Node ms = d_check_model_subs[i]; - Node mss = ms.substitute(v, s); + Node mss = arithSubstitute(ms, varsTmp, subsTmp); if (mss != ms) { mss = Rewriter::rewrite(mss); @@ -430,10 +431,7 @@ bool NlModel::solveEqualitySimple(Node eq, Node seq = eq; if (!d_check_model_vars.empty()) { - seq = eq.substitute(d_check_model_vars.begin(), - d_check_model_vars.end(), - d_check_model_subs.begin(), - d_check_model_subs.end()); + seq = arithSubstitute(eq, d_check_model_vars, d_check_model_subs); seq = Rewriter::rewrite(seq); if (seq.isConst()) { @@ -866,8 +864,7 @@ bool NlModel::simpleCheckModelLit(Node lit) for (unsigned r = 0; r < 2; r++) { qsubs.push_back(boundn[r]); - Node ts = t.substitute( - qvars.begin(), qvars.end(), qsubs.begin(), qsubs.end()); + Node ts = arithSubstitute(t, qvars, qsubs); tcmpn[r] = Rewriter::rewrite(ts); qsubs.pop_back(); } diff --git a/src/theory/arith/nonlinear_extension.cpp b/src/theory/arith/nonlinear_extension.cpp index 6e8e7623d..ff2ec412b 100644 --- a/src/theory/arith/nonlinear_extension.cpp +++ b/src/theory/arith/nonlinear_extension.cpp @@ -772,8 +772,7 @@ bool NonlinearExtension::checkModel(const std::vector<Node>& assertions, Node pa = a; if (!pvars.empty()) { - pa = - pa.substitute(pvars.begin(), pvars.end(), psubs.begin(), psubs.end()); + pa = arithSubstitute(pa, pvars, psubs); pa = Rewriter::rewrite(pa); } if (!pa.isConst() || !pa.getConst<bool>()) |