summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2019-04-17 16:35:51 -0500
committerGitHub <noreply@github.com>2019-04-17 16:35:51 -0500
commitd9a103f371cd800615b37fa378ad9d8b7681ee1c (patch)
tree95d338f7e6ca8e760adaaf154f48a190008b6909
parentd0c44a9e048558887ab75aaec4c493696c67b456 (diff)
Cache explanations in the equality engine (#2937)
-rw-r--r--src/theory/uf/equality_engine.cpp157
-rw-r--r--src/theory/uf/equality_engine.h32
2 files changed, 145 insertions, 44 deletions
diff --git a/src/theory/uf/equality_engine.cpp b/src/theory/uf/equality_engine.cpp
index d1fc8341c..148a5e427 100644
--- a/src/theory/uf/equality_engine.cpp
+++ b/src/theory/uf/equality_engine.cpp
@@ -929,9 +929,9 @@ std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const {
void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
std::vector<TNode>& equalities,
EqProof* eqp) const {
- Debug("equality") << d_name << "::eq::explainEquality(" << t1 << ", " << t2
- << ", " << (polarity ? "true" : "false") << ")"
- << ", proof = " << (eqp ? "ON" : "OFF") << std::endl;
+ Debug("pf::ee") << d_name << "::eq::explainEquality(" << t1 << ", " << t2
+ << ", " << (polarity ? "true" : "false") << ")"
+ << ", proof = " << (eqp ? "ON" : "OFF") << std::endl;
// The terms must be there already
Assert(hasTerm(t1) && hasTerm(t2));;
@@ -940,9 +940,10 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
EqualityNodeId t1Id = getNodeId(t1);
EqualityNodeId t2Id = getNodeId(t2);
+ std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*> cache;
if (polarity) {
// Get the explanation
- getExplanation(t1Id, t2Id, equalities, eqp);
+ getExplanation(t1Id, t2Id, equalities, cache, eqp);
} else {
if (eqp) {
eqp->d_id = eq::MERGED_THROUGH_TRANS;
@@ -964,12 +965,15 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
eqpc = std::make_shared<EqProof>();
}
- getExplanation(toExplain.first, toExplain.second, equalities, eqpc.get());
+ getExplanation(
+ toExplain.first, toExplain.second, equalities, cache, eqpc.get());
if (eqpc) {
- Debug("pf::ee") << "Child proof is:" << std::endl;
- eqpc->debug_print("pf::ee", 1);
-
+ if (Debug.isOn("pf::ee"))
+ {
+ Debug("pf::ee") << "Child proof is:" << std::endl;
+ eqpc->debug_print("pf::ee", 1);
+ }
if (eqpc->d_id == eq::MERGED_THROUGH_TRANS) {
std::vector<std::shared_ptr<EqProof>> orderedChildren;
bool nullCongruenceFound = false;
@@ -987,8 +991,13 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
if (nullCongruenceFound) {
eqpc->d_children = orderedChildren;
- Debug("pf::ee") << "Child proof's children have been reordered. It is now:" << std::endl;
- eqpc->debug_print("pf::ee", 1);
+ if (Debug.isOn("pf::ee"))
+ {
+ Debug("pf::ee")
+ << "Child proof's children have been reordered. It is now:"
+ << std::endl;
+ eqpc->debug_print("pf::ee", 1);
+ }
}
}
@@ -1011,8 +1020,11 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
*eqp = *temp;
}
- Debug("pf::ee") << "Disequality explanation final proof: " << std::endl;
- eqp->debug_print("pf::ee", 1);
+ if (Debug.isOn("pf::ee"))
+ {
+ Debug("pf::ee") << "Disequality explanation final proof: " << std::endl;
+ eqp->debug_print("pf::ee", 1);
+ }
}
}
}
@@ -1024,15 +1036,51 @@ void EqualityEngine::explainPredicate(TNode p, bool polarity,
<< std::endl;
// Must have the term
Assert(hasTerm(p));
+ std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*> cache;
// Get the explanation
- getExplanation(getNodeId(p), polarity ? d_trueId : d_falseId, assertions,
- eqp);
+ getExplanation(
+ getNodeId(p), polarity ? d_trueId : d_falseId, assertions, cache, eqp);
}
-void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
- std::vector<TNode>& equalities,
- EqProof* eqp) const {
- Debug("equality") << d_name << "::eq::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl;
+void EqualityEngine::getExplanation(
+ EqualityNodeId t1Id,
+ EqualityNodeId t2Id,
+ std::vector<TNode>& equalities,
+ std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*>& cache,
+ EqProof* eqp) const
+{
+ Trace("eq-exp") << d_name << "::eq::getExplanation(" << d_nodes[t1Id] << ","
+ << d_nodes[t2Id] << ") size = " << cache.size() << std::endl;
+
+ // We order the ids, since explaining t1 = t2 is the same as explaining
+ // t2 = t1.
+ std::pair<EqualityNodeId, EqualityNodeId> cacheKey = std::minmax(t1Id, t2Id);
+ std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*>::iterator it =
+ cache.find(cacheKey);
+ if (it != cache.end())
+ {
+ // copy one level
+ if (eqp)
+ {
+ if (it->second)
+ {
+ eqp->d_node = it->second->d_node;
+ eqp->d_id = it->second->d_id;
+ eqp->d_children.insert(eqp->d_children.end(),
+ it->second->d_children.begin(),
+ it->second->d_children.end());
+ }
+ else
+ {
+ // We may have cached null in its place, create the trivial proof now.
+ Assert(d_nodes[t1Id] == d_nodes[t2Id]);
+ Assert(eqp->d_id == MERGED_THROUGH_REFLEXIVITY);
+ eqp->d_node = d_nodes[t1Id];
+ }
+ }
+ return;
+ }
+ cache[cacheKey] = eqp;
// We can only explain the nodes that got merged
#ifdef CVC4_ASSERTIONS
@@ -1136,11 +1184,11 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
Debug("equality") << "Explaining left hand side equalities" << std::endl;
std::shared_ptr<EqProof> eqpc1 =
eqpc ? std::make_shared<EqProof>() : nullptr;
- getExplanation(f1.a, f2.a, equalities, eqpc1.get());
+ getExplanation(f1.a, f2.a, equalities, cache, eqpc1.get());
Debug("equality") << "Explaining right hand side equalities" << std::endl;
std::shared_ptr<EqProof> eqpc2 =
eqpc ? std::make_shared<EqProof>() : nullptr;
- getExplanation(f1.b, f2.b, equalities, eqpc2.get());
+ getExplanation(f1.b, f2.b, equalities, cache, eqpc2.get());
if( eqpc ){
eqpc->d_children.push_back( eqpc1 );
eqpc->d_children.push_back( eqpc2 );
@@ -1185,7 +1233,7 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
Debug("equality") << push;
std::shared_ptr<EqProof> eqpc1 =
eqpc ? std::make_shared<EqProof>() : nullptr;
- getExplanation(eq.a, eq.b, equalities, eqpc1.get());
+ getExplanation(eq.a, eq.b, equalities, cache, eqpc1.get());
if( eqpc ){
eqpc->d_children.push_back( eqpc1 );
}
@@ -1211,13 +1259,20 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
Assert(isConstant(childId));
std::shared_ptr<EqProof> eqpcc =
eqpc ? std::make_shared<EqProof>() : nullptr;
- getExplanation(childId, getEqualityNode(childId).getFind(),
- equalities, eqpcc.get());
+ getExplanation(childId,
+ getEqualityNode(childId).getFind(),
+ equalities,
+ cache,
+ eqpcc.get());
if( eqpc ) {
eqpc->d_children.push_back( eqpcc );
-
- Debug("pf::ee") << "MERGED_THROUGH_CONSTANTS. Dumping the child proof" << std::endl;
- eqpc->debug_print("pf::ee", 1);
+ if (Debug.isOn("pf::ee"))
+ {
+ Debug("pf::ee")
+ << "MERGED_THROUGH_CONSTANTS. Dumping the child proof"
+ << std::endl;
+ eqpc->debug_print("pf::ee", 1);
+ }
}
}
@@ -1255,7 +1310,6 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
}
eqpc->d_id = reasonType;
}
-
equalities.push_back(reason);
break;
}
@@ -1288,8 +1342,10 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
eqp->d_children.insert( eqp->d_children.end(), eqp_trans.begin(), eqp_trans.end() );
eqp->d_node = NodeManager::currentNM()->mkNode(kind::EQUAL, d_nodes[t1Id], d_nodes[t2Id]);
}
-
- eqp->debug_print("pf::ee", 1);
+ if (Debug.isOn("pf::ee"))
+ {
+ eqp->debug_print("pf::ee", 1);
+ }
}
// Done
@@ -2236,27 +2292,48 @@ bool EqClassIterator::isFinished() const {
}
void EqProof::debug_print(const char* c, unsigned tb, PrettyPrinter* prettyPrinter) const {
- for(unsigned i=0; i<tb; i++) { Debug( c ) << " "; }
+ std::stringstream ss;
+ debug_print(ss, tb, prettyPrinter);
+ Debug(c) << ss.str();
+}
+void EqProof::debug_print(std::ostream& os,
+ unsigned tb,
+ PrettyPrinter* prettyPrinter) const
+{
+ for (unsigned i = 0; i < tb; i++)
+ {
+ os << " ";
+ }
if (prettyPrinter)
- Debug( c ) << prettyPrinter->printTag(d_id);
+ {
+ os << prettyPrinter->printTag(d_id);
+ }
else
- Debug( c ) << d_id;
+ {
+ os << d_id;
+ }
- Debug( c ) << "(";
+ os << "(";
if( !d_children.empty() || !d_node.isNull() ){
if( !d_node.isNull() ){
- Debug( c ) << std::endl;
- for( unsigned i=0; i<tb+1; i++ ) { Debug( c ) << " "; }
- Debug( c ) << d_node;
+ os << std::endl;
+ for (unsigned i = 0; i < tb + 1; i++)
+ {
+ os << " ";
+ }
+ os << d_node;
}
for( unsigned i=0; i<d_children.size(); i++ ){
- if( i>0 || !d_node.isNull() ) Debug( c ) << ",";
- Debug( c ) << std::endl;
- d_children[i]->debug_print( c, tb+1, prettyPrinter );
+ if (i > 0 || !d_node.isNull())
+ {
+ os << ",";
+ }
+ os << std::endl;
+ d_children[i]->debug_print(os, tb + 1, prettyPrinter);
}
}
- Debug( c ) << ")" << std::endl;
+ os << ")" << std::endl;
}
} // Namespace uf
diff --git a/src/theory/uf/equality_engine.h b/src/theory/uf/equality_engine.h
index b93ff6d6d..73d8bd4e9 100644
--- a/src/theory/uf/equality_engine.h
+++ b/src/theory/uf/equality_engine.h
@@ -516,11 +516,24 @@ private:
bool d_inPropagate;
/**
- * Get an explanation of the equality t1 = t2. Returns the asserted equalities that
- * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere
- * else.
+ * Get an explanation of the equality t1 = t2. Returns the asserted equalities
+ * that imply t1 = t2. Returns TNodes as the assertion equalities should be
+ * hashed somewhere else.
+ *
+ * This call refers to terms t1 and t2 by their ids t1Id and t2Id.
+ *
+ * If eqp is non-null, then this method populates eqp's information and
+ * children such that it is a proof of t1 = t2.
+ *
+ * We cache results of this call in cache, where cache[t1Id][t2Id] stores
+ * a proof of t1 = t2.
*/
- void getExplanation(EqualityEdgeId t1Id, EqualityNodeId t2Id, std::vector<TNode>& equalities, EqProof* eqp) const;
+ void getExplanation(
+ EqualityEdgeId t1Id,
+ EqualityNodeId t2Id,
+ std::vector<TNode>& equalities,
+ std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*>& cache,
+ EqProof* eqp) const;
/**
* Print the equality graph.
@@ -941,8 +954,19 @@ public:
unsigned d_id;
Node d_node;
std::vector<std::shared_ptr<EqProof>> d_children;
+ /**
+ * Debug print this proof on debug trace c with tabulation tb and pretty
+ * printer prettyPrinter.
+ */
void debug_print(const char* c, unsigned tb = 0,
PrettyPrinter* prettyPrinter = nullptr) const;
+ /**
+ * Debug print this proof on output stream os with tabulation tb and pretty
+ * printer prettyPrinter.
+ */
+ void debug_print(std::ostream& os,
+ unsigned tb = 0,
+ PrettyPrinter* prettyPrinter = nullptr) const;
};/* class EqProof */
} // Namespace eq
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback