diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2019-04-17 16:35:51 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-04-17 16:35:51 -0500 |
commit | d9a103f371cd800615b37fa378ad9d8b7681ee1c (patch) | |
tree | 95d338f7e6ca8e760adaaf154f48a190008b6909 /src/theory/uf/equality_engine.cpp | |
parent | d0c44a9e048558887ab75aaec4c493696c67b456 (diff) |
Cache explanations in the equality engine (#2937)
Diffstat (limited to 'src/theory/uf/equality_engine.cpp')
-rw-r--r-- | src/theory/uf/equality_engine.cpp | 157 |
1 files changed, 117 insertions, 40 deletions
diff --git a/src/theory/uf/equality_engine.cpp b/src/theory/uf/equality_engine.cpp index d1fc8341c..148a5e427 100644 --- a/src/theory/uf/equality_engine.cpp +++ b/src/theory/uf/equality_engine.cpp @@ -929,9 +929,9 @@ std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const { void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, std::vector<TNode>& equalities, EqProof* eqp) const { - Debug("equality") << d_name << "::eq::explainEquality(" << t1 << ", " << t2 - << ", " << (polarity ? "true" : "false") << ")" - << ", proof = " << (eqp ? "ON" : "OFF") << std::endl; + Debug("pf::ee") << d_name << "::eq::explainEquality(" << t1 << ", " << t2 + << ", " << (polarity ? "true" : "false") << ")" + << ", proof = " << (eqp ? "ON" : "OFF") << std::endl; // The terms must be there already Assert(hasTerm(t1) && hasTerm(t2));; @@ -940,9 +940,10 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, EqualityNodeId t1Id = getNodeId(t1); EqualityNodeId t2Id = getNodeId(t2); + std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*> cache; if (polarity) { // Get the explanation - getExplanation(t1Id, t2Id, equalities, eqp); + getExplanation(t1Id, t2Id, equalities, cache, eqp); } else { if (eqp) { eqp->d_id = eq::MERGED_THROUGH_TRANS; @@ -964,12 +965,15 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, eqpc = std::make_shared<EqProof>(); } - getExplanation(toExplain.first, toExplain.second, equalities, eqpc.get()); + getExplanation( + toExplain.first, toExplain.second, equalities, cache, eqpc.get()); if (eqpc) { - Debug("pf::ee") << "Child proof is:" << std::endl; - eqpc->debug_print("pf::ee", 1); - + if (Debug.isOn("pf::ee")) + { + Debug("pf::ee") << "Child proof is:" << std::endl; + eqpc->debug_print("pf::ee", 1); + } if (eqpc->d_id == eq::MERGED_THROUGH_TRANS) { std::vector<std::shared_ptr<EqProof>> orderedChildren; bool nullCongruenceFound = false; @@ -987,8 +991,13 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, if (nullCongruenceFound) { eqpc->d_children = orderedChildren; - Debug("pf::ee") << "Child proof's children have been reordered. It is now:" << std::endl; - eqpc->debug_print("pf::ee", 1); + if (Debug.isOn("pf::ee")) + { + Debug("pf::ee") + << "Child proof's children have been reordered. It is now:" + << std::endl; + eqpc->debug_print("pf::ee", 1); + } } } @@ -1011,8 +1020,11 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, *eqp = *temp; } - Debug("pf::ee") << "Disequality explanation final proof: " << std::endl; - eqp->debug_print("pf::ee", 1); + if (Debug.isOn("pf::ee")) + { + Debug("pf::ee") << "Disequality explanation final proof: " << std::endl; + eqp->debug_print("pf::ee", 1); + } } } } @@ -1024,15 +1036,51 @@ void EqualityEngine::explainPredicate(TNode p, bool polarity, << std::endl; // Must have the term Assert(hasTerm(p)); + std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*> cache; // Get the explanation - getExplanation(getNodeId(p), polarity ? d_trueId : d_falseId, assertions, - eqp); + getExplanation( + getNodeId(p), polarity ? d_trueId : d_falseId, assertions, cache, eqp); } -void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, - std::vector<TNode>& equalities, - EqProof* eqp) const { - Debug("equality") << d_name << "::eq::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl; +void EqualityEngine::getExplanation( + EqualityNodeId t1Id, + EqualityNodeId t2Id, + std::vector<TNode>& equalities, + std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*>& cache, + EqProof* eqp) const +{ + Trace("eq-exp") << d_name << "::eq::getExplanation(" << d_nodes[t1Id] << "," + << d_nodes[t2Id] << ") size = " << cache.size() << std::endl; + + // We order the ids, since explaining t1 = t2 is the same as explaining + // t2 = t1. + std::pair<EqualityNodeId, EqualityNodeId> cacheKey = std::minmax(t1Id, t2Id); + std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*>::iterator it = + cache.find(cacheKey); + if (it != cache.end()) + { + // copy one level + if (eqp) + { + if (it->second) + { + eqp->d_node = it->second->d_node; + eqp->d_id = it->second->d_id; + eqp->d_children.insert(eqp->d_children.end(), + it->second->d_children.begin(), + it->second->d_children.end()); + } + else + { + // We may have cached null in its place, create the trivial proof now. + Assert(d_nodes[t1Id] == d_nodes[t2Id]); + Assert(eqp->d_id == MERGED_THROUGH_REFLEXIVITY); + eqp->d_node = d_nodes[t1Id]; + } + } + return; + } + cache[cacheKey] = eqp; // We can only explain the nodes that got merged #ifdef CVC4_ASSERTIONS @@ -1136,11 +1184,11 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, Debug("equality") << "Explaining left hand side equalities" << std::endl; std::shared_ptr<EqProof> eqpc1 = eqpc ? std::make_shared<EqProof>() : nullptr; - getExplanation(f1.a, f2.a, equalities, eqpc1.get()); + getExplanation(f1.a, f2.a, equalities, cache, eqpc1.get()); Debug("equality") << "Explaining right hand side equalities" << std::endl; std::shared_ptr<EqProof> eqpc2 = eqpc ? std::make_shared<EqProof>() : nullptr; - getExplanation(f1.b, f2.b, equalities, eqpc2.get()); + getExplanation(f1.b, f2.b, equalities, cache, eqpc2.get()); if( eqpc ){ eqpc->d_children.push_back( eqpc1 ); eqpc->d_children.push_back( eqpc2 ); @@ -1185,7 +1233,7 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, Debug("equality") << push; std::shared_ptr<EqProof> eqpc1 = eqpc ? std::make_shared<EqProof>() : nullptr; - getExplanation(eq.a, eq.b, equalities, eqpc1.get()); + getExplanation(eq.a, eq.b, equalities, cache, eqpc1.get()); if( eqpc ){ eqpc->d_children.push_back( eqpc1 ); } @@ -1211,13 +1259,20 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, Assert(isConstant(childId)); std::shared_ptr<EqProof> eqpcc = eqpc ? std::make_shared<EqProof>() : nullptr; - getExplanation(childId, getEqualityNode(childId).getFind(), - equalities, eqpcc.get()); + getExplanation(childId, + getEqualityNode(childId).getFind(), + equalities, + cache, + eqpcc.get()); if( eqpc ) { eqpc->d_children.push_back( eqpcc ); - - Debug("pf::ee") << "MERGED_THROUGH_CONSTANTS. Dumping the child proof" << std::endl; - eqpc->debug_print("pf::ee", 1); + if (Debug.isOn("pf::ee")) + { + Debug("pf::ee") + << "MERGED_THROUGH_CONSTANTS. Dumping the child proof" + << std::endl; + eqpc->debug_print("pf::ee", 1); + } } } @@ -1255,7 +1310,6 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, } eqpc->d_id = reasonType; } - equalities.push_back(reason); break; } @@ -1288,8 +1342,10 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, eqp->d_children.insert( eqp->d_children.end(), eqp_trans.begin(), eqp_trans.end() ); eqp->d_node = NodeManager::currentNM()->mkNode(kind::EQUAL, d_nodes[t1Id], d_nodes[t2Id]); } - - eqp->debug_print("pf::ee", 1); + if (Debug.isOn("pf::ee")) + { + eqp->debug_print("pf::ee", 1); + } } // Done @@ -2236,27 +2292,48 @@ bool EqClassIterator::isFinished() const { } void EqProof::debug_print(const char* c, unsigned tb, PrettyPrinter* prettyPrinter) const { - for(unsigned i=0; i<tb; i++) { Debug( c ) << " "; } + std::stringstream ss; + debug_print(ss, tb, prettyPrinter); + Debug(c) << ss.str(); +} +void EqProof::debug_print(std::ostream& os, + unsigned tb, + PrettyPrinter* prettyPrinter) const +{ + for (unsigned i = 0; i < tb; i++) + { + os << " "; + } if (prettyPrinter) - Debug( c ) << prettyPrinter->printTag(d_id); + { + os << prettyPrinter->printTag(d_id); + } else - Debug( c ) << d_id; + { + os << d_id; + } - Debug( c ) << "("; + os << "("; if( !d_children.empty() || !d_node.isNull() ){ if( !d_node.isNull() ){ - Debug( c ) << std::endl; - for( unsigned i=0; i<tb+1; i++ ) { Debug( c ) << " "; } - Debug( c ) << d_node; + os << std::endl; + for (unsigned i = 0; i < tb + 1; i++) + { + os << " "; + } + os << d_node; } for( unsigned i=0; i<d_children.size(); i++ ){ - if( i>0 || !d_node.isNull() ) Debug( c ) << ","; - Debug( c ) << std::endl; - d_children[i]->debug_print( c, tb+1, prettyPrinter ); + if (i > 0 || !d_node.isNull()) + { + os << ","; + } + os << std::endl; + d_children[i]->debug_print(os, tb + 1, prettyPrinter); } } - Debug( c ) << ")" << std::endl; + os << ")" << std::endl; } } // Namespace uf |