summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2018-09-05 12:41:47 -0500
committerGitHub <noreply@github.com>2018-09-05 12:41:47 -0500
commit4e4068f1d29ddc1ffe0bde8e6f2cf3094fd6bd40 (patch)
tree49c496d78643921bbb0cdbf436b49965ce5a5161 /src
parent1752aeb263a986bf437bb04029474b41987450d2 (diff)
Finer-grained inference of substitutions in incremental mode (#2403)
Diffstat (limited to 'src')
-rw-r--r--src/expr/node_algorithm.cpp36
-rw-r--r--src/expr/node_algorithm.h19
-rw-r--r--src/smt/smt_engine.cpp95
3 files changed, 121 insertions, 29 deletions
diff --git a/src/expr/node_algorithm.cpp b/src/expr/node_algorithm.cpp
index 5443a3a2a..9240e4a8e 100644
--- a/src/expr/node_algorithm.cpp
+++ b/src/expr/node_algorithm.cpp
@@ -166,5 +166,41 @@ bool hasFreeVar(TNode n)
return false;
}
+void getSymbols(TNode n, std::unordered_set<Node, NodeHashFunction>& syms)
+{
+ std::unordered_set<TNode, TNodeHashFunction> visited;
+ getSymbols(n, syms);
+}
+
+void getSymbols(TNode n,
+ std::unordered_set<Node, NodeHashFunction>& syms,
+ std::unordered_set<TNode, TNodeHashFunction>& visited)
+{
+ std::vector<TNode> visit;
+ TNode cur;
+ visit.push_back(n);
+ do
+ {
+ cur = visit.back();
+ visit.pop_back();
+ if (visited.find(cur) == visited.end())
+ {
+ visited.insert(cur);
+ if (cur.isVar() && cur.getKind() != kind::BOUND_VARIABLE)
+ {
+ syms.insert(cur);
+ }
+ if (cur.hasOperator())
+ {
+ visit.push_back(cur.getOperator());
+ }
+ for (TNode cn : cur)
+ {
+ visit.push_back(cn);
+ }
+ }
+ } while (!visit.empty());
+}
+
} // namespace expr
} // namespace CVC4
diff --git a/src/expr/node_algorithm.h b/src/expr/node_algorithm.h
index 61e81c4c2..7453bc292 100644
--- a/src/expr/node_algorithm.h
+++ b/src/expr/node_algorithm.h
@@ -39,20 +39,33 @@ namespace expr {
bool hasSubterm(TNode n, TNode t, bool strict = false);
/**
- * Returns true iff the node n contains a bound variable. This bound
- * variable may or may not be free.
+ * Returns true iff the node n contains a bound variable, that is a node of
+ * kind BOUND_VARIABLE. This bound variable may or may not be free.
* @param n The node under investigation
* @return true iff this node contains a bound variable
*/
bool hasBoundVar(TNode n);
/**
- * Returns true iff the node n contains a free variable.
+ * Returns true iff the node n contains a free variable, that is, a node
+ * of kind BOUND_VARIABLE that is not bound in n.
* @param n The node under investigation
* @return true iff this node contains a free variable.
*/
bool hasFreeVar(TNode n);
+/**
+ * For term n, this function collects the symbols that occur as a subterms
+ * of n. A symbol is a variable that does not have kind BOUND_VARIABLE.
+ * @param n The node under investigation
+ * @param syms The set which the symbols of n are added to
+ */
+void getSymbols(TNode n, std::unordered_set<Node, NodeHashFunction>& syms);
+/** Same as above, with a visited cache */
+void getSymbols(TNode n,
+ std::unordered_set<Node, NodeHashFunction>& syms,
+ std::unordered_set<TNode, TNodeHashFunction>& visited);
+
} // namespace expr
} // namespace CVC4
diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp
index cdd5ab3e0..17edaad41 100644
--- a/src/smt/smt_engine.cpp
+++ b/src/smt/smt_engine.cpp
@@ -45,6 +45,7 @@
#include "expr/kind.h"
#include "expr/metakind.h"
#include "expr/node.h"
+#include "expr/node_algorithm.h"
#include "expr/node_builder.h"
#include "expr/node_self_iterator.h"
#include "options/arith_options.h"
@@ -448,6 +449,7 @@ class SmtEnginePrivate : public NodeManagerListener {
typedef unordered_map<Node, Node, NodeHashFunction> NodeToNodeHashMap;
typedef unordered_map<Node, bool, NodeHashFunction> NodeToBoolHashMap;
+ typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
/**
* Manager for limiting time and abstract resource usage.
@@ -504,6 +506,13 @@ class SmtEnginePrivate : public NodeManagerListener {
SubstitutionMap d_abstractValueMap;
/**
+ * The (user-context-dependent) set of symbols that occur in at least one
+ * assertion in the current user context. This is used by the
+ * nonClausalSimplify pass.
+ */
+ NodeSet d_symsInAssertions;
+
+ /**
* A mapping of all abstract values (actual value |-> abstract) that
* we've handed out. This is necessary to ensure that we give the
* same AbstractValues for the same real constants. Only used if
@@ -545,6 +554,13 @@ class SmtEnginePrivate : public NodeManagerListener {
*/
bool nonClausalSimplify();
+ /** record symbols in assertions
+ *
+ * This method is called when a set of assertions is finalized. It adds
+ * the symbols to d_symsInAssertions that occur in assertions.
+ */
+ void recordSymbolsInAssertions(const std::vector<Node>& assertions);
+
/**
* Helper function to fix up assertion list to restore invariants needed after
* ite removal.
@@ -579,6 +595,7 @@ class SmtEnginePrivate : public NodeManagerListener {
d_assertionsProcessed(smt.d_userContext, false),
d_fakeContext(),
d_abstractValueMap(&d_fakeContext),
+ d_symsInAssertions(smt.d_userContext),
d_abstractValues(),
d_simplifyAssertionsDepth(0),
// d_needsExpandDefs(true), //TODO?
@@ -833,7 +850,6 @@ class SmtEnginePrivate : public NodeManagerListener {
}
}
//------------------------------- end expression names
-
};/* class SmtEnginePrivate */
}/* namespace CVC4::smt */
@@ -3126,35 +3142,42 @@ bool SmtEnginePrivate::nonClausalSimplify() {
<< assertion << endl;
}
- // If in incremental mode, add substitutions to the list of assertions
- if (substs_index > 0)
+ // add substitutions to model, or as assertions if needed (when incremental)
+ TheoryModel* m = d_smt.d_theoryEngine->getModel();
+ Assert(m != nullptr);
+ NodeManager* nm = NodeManager::currentNM();
+ NodeBuilder<> substitutionsBuilder(kind::AND);
+ for (pos = newSubstitutions.begin(); pos != newSubstitutions.end(); ++pos)
{
- NodeBuilder<> substitutionsBuilder(kind::AND);
- substitutionsBuilder << d_assertions[substs_index];
- pos = newSubstitutions.begin();
- for (; pos != newSubstitutions.end(); ++pos) {
- // Add back this substitution as an assertion
- TNode lhs = (*pos).first, rhs = newSubstitutions.apply((*pos).second);
- Node n = NodeManager::currentNM()->mkNode(kind::EQUAL, lhs, rhs);
- substitutionsBuilder << n;
- Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): will notify SAT layer of substitution: " << n << endl;
- }
- if (substitutionsBuilder.getNumChildren() > 1) {
- d_assertions.replace(substs_index,
- Rewriter::rewrite(Node(substitutionsBuilder)));
+ Node lhs = (*pos).first;
+ Node rhs = newSubstitutions.apply((*pos).second);
+ // If using incremental, we must check whether this variable has occurred
+ // before now. If it hasn't we can add this as a substitution.
+ if (substs_index == 0
+ || d_symsInAssertions.find(lhs) == d_symsInAssertions.end())
+ {
+ Trace("simplify")
+ << "SmtEnginePrivate::nonClausalSimplify(): substitute: " << lhs
+ << " " << rhs << endl;
+ m->addSubstitution(lhs, rhs);
}
- } else {
- // If not in incremental mode, must add substitutions to model
- TheoryModel* m = d_smt.d_theoryEngine->getModel();
- if(m != NULL) {
- for(pos = newSubstitutions.begin(); pos != newSubstitutions.end(); ++pos) {
- Node n = (*pos).first;
- Node v = newSubstitutions.apply((*pos).second);
- Trace("model") << "Add substitution : " << n << " " << v << endl;
- m->addSubstitution( n, v );
- }
+ else
+ {
+ // if it has, the substitution becomes an assertion
+ Node eq = nm->mkNode(kind::EQUAL, lhs, rhs);
+ Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
+ "substitute: will notify SAT layer of substitution: "
+ << eq << endl;
+ substitutionsBuilder << eq;
}
}
+ // add to the last assertion if necessary
+ if (substitutionsBuilder.getNumChildren() > 0)
+ {
+ substitutionsBuilder << d_assertions[substs_index];
+ d_assertions.replace(substs_index,
+ Rewriter::rewrite(Node(substitutionsBuilder)));
+ }
NodeBuilder<> learnedBuilder(kind::AND);
Assert(d_assertions.getRealAssertionsEnd() <= d_assertions.size());
@@ -3415,6 +3438,20 @@ void SmtEnginePrivate::collectSkolems(TNode n, set<TNode>& skolemSet, unordered_
cache[n] = true;
}
+void SmtEnginePrivate::recordSymbolsInAssertions(
+ const std::vector<Node>& assertions)
+{
+ std::unordered_set<TNode, TNodeHashFunction> visited;
+ std::unordered_set<Node, NodeHashFunction> syms;
+ for (TNode cn : assertions)
+ {
+ expr::getSymbols(cn, syms, visited);
+ }
+ for (const Node& s : syms)
+ {
+ d_symsInAssertions.insert(s);
+ }
+}
bool SmtEnginePrivate::checkForBadSkolems(TNode n, TNode skolem, unordered_map<Node, bool, NodeHashFunction>& cache)
{
@@ -3831,6 +3868,12 @@ void SmtEnginePrivate::processAssertions() {
Trace("smt-proc") << "SmtEnginePrivate::processAssertions() end" << endl;
dumpAssertions("post-everything", d_assertions);
+ // if incremental, compute which variables are assigned
+ if (options::incrementalSolving())
+ {
+ recordSymbolsInAssertions(d_assertions.ref());
+ }
+
// Push the formula to SAT
{
Chat() << "converting to CNF..." << endl;
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback