diff options
author | Morgan Deters <mdeters@gmail.com> | 2010-11-17 01:39:37 +0000 |
---|---|---|
committer | Morgan Deters <mdeters@gmail.com> | 2010-11-17 01:39:37 +0000 |
commit | c7a70635797fe4205b27d29546dd4fe763220794 (patch) | |
tree | 715eb2c43beebaaa725a3064597761f60975fea6 /src/util | |
parent | bb2a0e0e12f39a1b4dea8fb0c990decba4708a1c (diff) |
The "UF engineering issues" release, after much profiling.
* swap in backtracking data structures (that use and maintain their own
undo stack) in some places instead of the usual Context-aware
paradigm (MUCH faster).
* cosmetic changes to UF, CC modules.
Diffstat (limited to 'src/util')
-rw-r--r-- | src/util/congruence_closure.h | 172 |
1 files changed, 99 insertions, 73 deletions
diff --git a/src/util/congruence_closure.h b/src/util/congruence_closure.h index db9d5bc65..90ab11f9f 100644 --- a/src/util/congruence_closure.h +++ b/src/util/congruence_closure.h @@ -34,6 +34,7 @@ #include "context/cdset.h" #include "context/cdlist_context_memory.h" #include "util/exception.h" +#include "theory/uf/morgan/stacking_map.h" namespace CVC4 { @@ -103,7 +104,7 @@ class CongruenceClosure { OutputChannel* d_out; // typedef all of these so that iterators are easy to define - typedef context::CDMap<Node, Node, NodeHashFunction> RepresentativeMap; + typedef theory::uf::morgan::StackingMap <Node, NodeHashFunction> RepresentativeMap; typedef context::CDList<TNode, context::ContextMemoryAllocator<TNode> > ClassList; typedef context::CDMap<Node, ClassList*, NodeHashFunction> ClassLists; typedef context::CDList<TNode, context::ContextMemoryAllocator<TNode> > UseList; @@ -115,6 +116,13 @@ class CongruenceClosure { typedef context::CDMap<Node, Node, NodeHashFunction> ProofMap; typedef context::CDMap<Node, Node, NodeHashFunction> ProofLabel; + // Simple, NON-context-dependent pending list, union find and "seen + // set" types for constructing explanations and + // nearestCommonAncestor(); see explain(). + typedef std::list<std::pair<Node, Node> > PendingProofList_t; + typedef __gnu_cxx::hash_map<TNode, TNode, TNodeHashFunction> UnionFind_t; + typedef __gnu_cxx::hash_set<TNode, TNodeHashFunction> SeenSet_t; + RepresentativeMap d_representative; ClassLists d_classList; UseLists d_useList; @@ -210,7 +218,7 @@ public: } } - TNode proofRewrite(TNode pfStep) { + TNode proofRewrite(TNode pfStep) const { ProofMap::const_iterator i = d_proofRewrite.find(pfStep); if(i == d_proofRewrite.end()) { return pfStep; @@ -254,18 +262,17 @@ public: * Find the EC representative for a term t in the current context. */ inline TNode find(TNode t) const throw(AssertionException) { - context::CDMap<Node, Node, NodeHashFunction>::iterator i = - d_representative.find(t); - return (i == d_representative.end()) ? t : TNode((*i).second); + TNode rep1 = d_representative.find(t); + return rep1.isNull() ? t : rep1; } - void explainAlongPath(TNode a, TNode c, std::list<std::pair<Node, Node> >& pending, __gnu_cxx::hash_map<Node, Node, NodeHashFunction>& unionFind, std::list<Node>& pf) + void explainAlongPath(TNode a, TNode c, PendingProofList_t& pending, UnionFind_t& unionFind, std::list<Node>& pf) throw(AssertionException); - Node highestNode(TNode a, __gnu_cxx::hash_map<Node, Node, NodeHashFunction>& unionFind) + Node highestNode(TNode a, UnionFind_t& unionFind) const throw(AssertionException); - Node nearestCommonAncestor(TNode a, TNode b) + Node nearestCommonAncestor(TNode a, TNode b, UnionFind_t& unionFind) throw(AssertionException); /** @@ -311,7 +318,7 @@ private: * Internal lookup mapping from tuples to equalities. */ inline TNode lookup(TNode a) const { - context::CDMap<Node, Node, NodeHashFunction>::iterator i = d_lookup.find(a); + LookupMap::iterator i = d_lookup.find(a); if(i == d_lookup.end()) { return TNode::null(); } else { @@ -343,7 +350,7 @@ private: * Append equality "eq" to uselist of "of". */ inline void appendToUseList(TNode of, TNode eq) { - Debug("cc") << "adding " << eq << " to use list of " << of << std::endl; + Trace("cc") << "adding " << eq << " to use list of " << of << std::endl; Assert(eq.getKind() == kind::EQUAL || eq.getKind() == kind::IFF); Assert(of == find(of)); @@ -406,14 +413,14 @@ template <class OutputChannel> void CongruenceClosure<OutputChannel>::addEq(TNode eq, TNode inputEq) { d_proofRewrite[eq] = inputEq; - if(Debug.isOn("cc")) { - Debug("cc") << "CC addEq[" << d_context->getLevel() << "]: " << eq << std::endl; + if(Trace.isOn("cc")) { + Trace("cc") << "CC addEq[" << d_context->getLevel() << "]: " << eq << std::endl; } Assert(eq.getKind() == kind::EQUAL || eq.getKind() == kind::IFF); Assert(eq[1].getKind() != kind::APPLY_UF); if(areCongruent(eq[0], eq[1])) { - Debug("cc") << "CC -- redundant, ignoring...\n"; + Trace("cc") << "CC -- redundant, ignoring...\n"; return; } @@ -421,10 +428,10 @@ void CongruenceClosure<OutputChannel>::addEq(TNode eq, TNode inputEq) { Assert(s != t); - Debug("cc:detail") << "CC " << s << " == " << t << std::endl; + Trace("cc:detail") << "CC " << s << " == " << t << std::endl; // change from paper: do this whether or not s, t are applications - Debug("cc:detail") << "CC propagating the eq" << std::endl; + Trace("cc:detail") << "CC propagating the eq" << std::endl; if(s.getKind() != kind::APPLY_UF) { // s, t are constants @@ -433,26 +440,24 @@ void CongruenceClosure<OutputChannel>::addEq(TNode eq, TNode inputEq) { // s is an apply, t is a constant Node ap = buildRepresentativesOfApply(s); - Debug("cc:detail") << "CC rewrLHS " << "op_and_args_a == " << t << std::endl; - Debug("cc") << "CC ap is " << ap << std::endl; + Trace("cc:detail") << "CC rewrLHS " << "op_and_args_a == " << t << std::endl; + Trace("cc") << "CC ap is " << ap << std::endl; TNode l = lookup(ap); - Debug("cc:detail") << "CC lookup(ap): " << l << std::endl; + Trace("cc:detail") << "CC lookup(ap): " << l << std::endl; if(!l.isNull()) { // ensure a hard Node link exists to this during the call Node pending = NodeManager::currentNM()->mkNode(kind::TUPLE, eq, l); - Debug("cc:detail") << "pending1 " << pending << std::endl; + Trace("cc:detail") << "pending1 " << pending << std::endl; propagate(pending); } else { - Debug("cc") << "CC lookup(ap) setting to " << eq << std::endl; + Trace("cc") << "CC lookup(ap) setting to " << eq << std::endl; setLookup(ap, eq); for(Node::iterator i = ap.begin(); i != ap.end(); ++i) { appendToUseList(*i, eq); } } } - - Debug("cc") << *this; }/* addEq() */ @@ -462,6 +467,7 @@ Node CongruenceClosure<OutputChannel>::buildRepresentativesOfApply(TNode apply, throw(AssertionException) { Assert(apply.getKind() == kind::APPLY_UF); NodeBuilder<> argspb(kindToBuild); + // FIXME probably don't have to do find() of operator argspb << find(apply.getOperator()); for(TNode::iterator i = apply.begin(); i != apply.end(); ++i) { argspb << find(*i); @@ -472,18 +478,18 @@ Node CongruenceClosure<OutputChannel>::buildRepresentativesOfApply(TNode apply, template <class OutputChannel> void CongruenceClosure<OutputChannel>::propagate(TNode seed) { - Debug("cc:detail") << "=== doing a round of propagation ===" << std::endl + Trace("cc:detail") << "=== doing a round of propagation ===" << std::endl << "the \"seed\" propagation is: " << seed << std::endl; std::list<Node> pending; - Debug("cc") << "seed propagation with: " << seed << std::endl; + Trace("cc") << "seed propagation with: " << seed << std::endl; pending.push_back(seed); do { // while(!pending.empty()) Node e = pending.front(); pending.pop_front(); - Debug("cc") << "=== top of propagate loop ===" << std::endl + Trace("cc") << "=== top of propagate loop ===" << std::endl << "=== e is " << e << " ===" << std::endl; TNode a, b; @@ -494,11 +500,11 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) { a = e[0]; b = e[1]; - Debug("cc:detail") << "propagate equality: " << a << " == " << b << std::endl; + Trace("cc:detail") << "propagate equality: " << a << " == " << b << std::endl; } else { // e is a tuple ( apply f A... = a , apply f B... = b ) - Debug("cc") << "propagate tuple: " << e << "\n"; + Trace("cc") << "propagate tuple: " << e << "\n"; Assert(e.getKind() == kind::TUPLE); @@ -515,11 +521,11 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) { Assert(a.getKind() != kind::APPLY_UF); Assert(b.getKind() != kind::APPLY_UF); - Debug("cc") << " ( " << a << " , " << b << " )" << std::endl; + Trace("cc") << " ( " << a << " , " << b << " )" << std::endl; } if(Debug.isOn("cc")) { - Debug("cc:detail") << "=====at start=====" << std::endl + Trace("cc:detail") << "=====at start=====" << std::endl << "a :" << a << std::endl << "NORMALIZE a:" << normalize(a) << std::endl << "b :" << b << std::endl @@ -532,7 +538,7 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) { Node ap = find(a), bp = find(b); if(ap != bp) { - Debug("cc:detail") << "EC[a] == " << ap << std::endl + Trace("cc:detail") << "EC[a] == " << ap << std::endl << "EC[b] == " << bp << std::endl; { // w.l.o.g., |classList ap| <= |classList bp| @@ -540,9 +546,9 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) { ClassLists::iterator cl_bpi = d_classList.find(bp); unsigned sizeA = (cl_api == d_classList.end() ? 0 : (*cl_api).second->size()); unsigned sizeB = (cl_bpi == d_classList.end() ? 0 : (*cl_bpi).second->size()); - Debug("cc") << "sizeA == " << sizeA << " sizeB == " << sizeB << std::endl; + Trace("cc") << "sizeA == " << sizeA << " sizeB == " << sizeB << std::endl; if(sizeA > sizeB) { - Debug("cc") << "swapping..\n"; + Trace("cc") << "swapping..\n"; TNode tmp = ap; ap = bp; bp = tmp; tmp = a; a = b; b = tmp; } @@ -555,17 +561,17 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) { cl_bp = new(d_context->getCMM()) ClassList(true, d_context, false, context::ContextMemoryAllocator<TNode>(d_context->getCMM())); d_classList.insertDataFromContextMemory(bp, cl_bp); - Debug("cc:detail") << "CC in prop alloc classlist for " << bp << std::endl; + Trace("cc:detail") << "CC in prop alloc classlist for " << bp << std::endl; } else { cl_bp = (*cl_bpi).second; } // we don't store 'ap' in its own class list; so process it here - Debug("cc:detail") << "calling mergeproof/merge1 " << ap + Trace("cc:detail") << "calling mergeproof/merge1 " << ap << " AND " << bp << std::endl; mergeProof(a, b, e); merge(ap, bp); - Debug("cc") << " adding ap == " << ap << std::endl + Trace("cc") << " adding ap == " << ap << std::endl << " to class list of " << bp << std::endl; cl_bp->push_back(ap); ClassLists::const_iterator cli = d_classList.find(ap); @@ -581,15 +587,15 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) { Debug("cc") << " it's find ptr is: " << find(c) << std::endl; } Assert(find(c) == ap); - Debug("cc:detail") << "calling merge2 " << c << bp << std::endl; + Trace("cc:detail") << "calling merge2 " << c << bp << std::endl; merge(c, bp); // move c from classList(ap) to classlist(bp); //i = cl.erase(i);// FIXME do we need to? - Debug("cc") << " adding c to class list of " << bp << std::endl; + Trace("cc") << " adding c to class list of " << bp << std::endl; cl_bp->push_back(c); } } - Debug("cc:detail") << "bottom\n"; + Trace("cc:detail") << "bottom\n"; } { // use list handling @@ -606,7 +612,7 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) { i != ul->end(); ++i) { TNode eq = *i; - Debug("cc") << "CC -- useList: " << eq << std::endl; + Trace("cc") << "CC -- useList: " << eq << std::endl; Assert(eq.getKind() == kind::EQUAL || eq.getKind() == kind::IFF); // change from paper @@ -623,21 +629,21 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) { // if lookup(c1',c2') is some f(d1,d2)=d then TNode n = lookup(cp); - Debug("cc:detail") << "CC -- c' is " << cp << std::endl; + Trace("cc:detail") << "CC -- c' is " << cp << std::endl; if(!n.isNull()) { - Debug("cc:detail") << "CC -- lookup(c') is " << n << std::endl; + Trace("cc:detail") << "CC -- lookup(c') is " << n << std::endl; // add (f(c1,c2)=c,f(d1,d2)=d) to pending Node tuple = NodeManager::currentNM()->mkNode(kind::TUPLE, eq, n); - Debug("cc") << "CC add tuple to pending: " << tuple << std::endl; + Trace("cc") << "CC add tuple to pending: " << tuple << std::endl; pending.push_back(tuple); // remove f(c1,c2)=c from UseList(ap) - Debug("cc:detail") << "supposed to remove " << eq << std::endl + Trace("cc:detail") << "supposed to remove " << eq << std::endl << " from UseList of " << ap << std::endl; //i = ul.erase(i);// FIXME do we need to? } else { - Debug("cc") << "CC -- lookup(c') is null" << std::endl; - Debug("cc") << "CC -- setlookup(c') to " << eq << std::endl; + Trace("cc") << "CC -- lookup(c') is null" << std::endl; + Trace("cc") << "CC -- setlookup(c') to " << eq << std::endl; // set lookup(c1',c2') to f(c1,c2)=c setLookup(cp, eq); // move f(c1,c2)=c from UseList(ap) to UseList(b') @@ -648,9 +654,9 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) { } } }/* use lists */ - Debug("cc:detail") << "CC in prop done with useList of " << ap << std::endl; + Trace("cc:detail") << "CC in prop done with useList of " << ap << std::endl; } else { - Debug("cc:detail") << "CCs the same ( == " << ap << "), do nothing." << std::endl; + Trace("cc:detail") << "CCs the same ( == " << ap << "), do nothing." << std::endl; } if(Debug.isOn("cc")) { @@ -679,8 +685,8 @@ void CongruenceClosure<OutputChannel>::merge(TNode ec1, TNode ec2) { } */ - Debug("cc") << "CC setting rep of " << ec1 << std::endl; - Debug("cc") << "CC to " << ec2 << std::endl; + Trace("cc") << "CC setting rep of " << ec1 << std::endl; + Trace("cc") << "CC to " << ec2 << std::endl; /* can now be applications Assert(ec1.getKind() != kind::APPLY_UF); @@ -691,7 +697,7 @@ void CongruenceClosure<OutputChannel>::merge(TNode ec1, TNode ec2) { //Assert(find(ec1) == ec1); Assert(find(ec2) == ec2); - d_representative[ec1] = ec2; + d_representative.set(ec1, ec2); if(d_careSet.find(ec1) != d_careSet.end()) { d_careSet.insert(ec2); @@ -702,18 +708,18 @@ void CongruenceClosure<OutputChannel>::merge(TNode ec1, TNode ec2) { template <class OutputChannel> void CongruenceClosure<OutputChannel>::mergeProof(TNode a, TNode b, TNode e) { - Debug("cc") << " -- merge-proofing " << a << "\n" + Trace("cc") << " -- merge-proofing " << a << "\n" << " and " << b << "\n" << " with " << e << "\n"; // proof forest gets a -> b labeled with e // first reverse all the edges in proof forest to root of this proof tree - Debug("cc") << "CC PROOF reversing proof tree\n"; + Trace("cc") << "CC PROOF reversing proof tree\n"; // c and p are child and parent in (old) proof tree Node c = a, p = d_proof[a], edgePf = d_proofLabel[a]; // when we hit null p, we're at the (former) root - Debug("cc") << "CC PROOF start at c == " << c << std::endl + Trace("cc") << "CC PROOF start at c == " << c << std::endl << " p == " << p << std::endl << " edgePf == " << edgePf << std::endl; while(!p.isNull()) { @@ -728,7 +734,7 @@ void CongruenceClosure<OutputChannel>::mergeProof(TNode a, TNode b, TNode e) { c = p; p = pParSave; edgePf = pLabelSave; - Debug("cc") << "CC PROOF now at c == " << c << std::endl + Trace("cc") << "CC PROOF now at c == " << c << std::endl << " p == " << p << std::endl << " edgePf == " << edgePf << std::endl; } @@ -742,10 +748,10 @@ void CongruenceClosure<OutputChannel>::mergeProof(TNode a, TNode b, TNode e) { template <class OutputChannel> Node CongruenceClosure<OutputChannel>::normalize(TNode t) const throw(AssertionException) { - Debug("cc:detail") << "normalize " << t << std::endl; + Trace("cc:detail") << "normalize " << t << std::endl; if(t.getKind() != kind::APPLY_UF) {// t is a constant t = find(t); - Debug("cc:detail") << " find " << t << std::endl; + Trace("cc:detail") << " find " << t << std::endl; return t; } else {// t is an apply NodeBuilder<> apb(kind::TUPLE); @@ -761,7 +767,7 @@ Node CongruenceClosure<OutputChannel>::normalize(TNode t) const } Node ap = apb; - Debug("cc:detail") << " got ap " << ap << std::endl; + Trace("cc:detail") << " got ap " << ap << std::endl; Node theLookup = lookup(ap); if(allConstants && !theLookup.isNull()) { @@ -783,10 +789,13 @@ Node CongruenceClosure<OutputChannel>::normalize(TNode t) const }/* normalize() */ +// This is the find() operation for the auxiliary union-find. This +// union-find is not context-dependent, as it's used only during +// explain(). It does path compression. template <class OutputChannel> -Node CongruenceClosure<OutputChannel>::highestNode(TNode a, __gnu_cxx::hash_map<Node, Node, NodeHashFunction>& unionFind) +Node CongruenceClosure<OutputChannel>::highestNode(TNode a, UnionFind_t& unionFind) const throw(AssertionException) { - __gnu_cxx::hash_map<Node, Node, NodeHashFunction>::iterator i = unionFind.find(a); + UnionFind_t::iterator i = unionFind.find(a); if(i == unionFind.end()) { return a; } else { @@ -796,7 +805,7 @@ Node CongruenceClosure<OutputChannel>::highestNode(TNode a, __gnu_cxx::hash_map< template <class OutputChannel> -void CongruenceClosure<OutputChannel>::explainAlongPath(TNode a, TNode c, std::list<std::pair<Node, Node> >& pending, __gnu_cxx::hash_map<Node, Node, NodeHashFunction>& unionFind, std::list<Node>& pf) +void CongruenceClosure<OutputChannel>::explainAlongPath(TNode a, TNode c, PendingProofList_t& pending, UnionFind_t& unionFind, std::list<Node>& pf) throw(AssertionException) { a = highestNode(a, unionFind); @@ -829,10 +838,27 @@ void CongruenceClosure<OutputChannel>::explainAlongPath(TNode a, TNode c, std::l template <class OutputChannel> -Node CongruenceClosure<OutputChannel>::nearestCommonAncestor(TNode a, TNode b) +Node CongruenceClosure<OutputChannel>::nearestCommonAncestor(TNode a, TNode b, UnionFind_t& unionFind) throw(AssertionException) { + SeenSet_t seen; + Assert(find(a) == find(b)); - return find(a); // FIXME + + do { + a = highestNode(a, unionFind); + seen.insert(a); + a = d_proof[a]; + } while(!a.isNull()); + + for(;;) { + b = highestNode(b, unionFind); + if(seen.find(b) != seen.end()) { + return b; + } + b = d_proof[b]; + + Assert(!b.isNull()); + } }/* nearestCommonAncestor() */ @@ -854,13 +880,13 @@ Node CongruenceClosure<OutputChannel>::explain(Node a, Node b) b = replace(flatten(b)); } - std::list<std::pair<Node, Node> > pending; - __gnu_cxx::hash_map<Node, Node, NodeHashFunction> unionFind; + PendingProofList_t pending; + UnionFind_t unionFind; std::list<Node> terms; pending.push_back(std::make_pair(a, b)); - Debug("cc") << "CC EXPLAINING " << a << " == " << b << std::endl; + Trace("cc") << "CC EXPLAINING " << a << " == " << b << std::endl; do {// while(!pending.empty()) std::pair<Node, Node> eq = pending.front(); @@ -869,29 +895,29 @@ Node CongruenceClosure<OutputChannel>::explain(Node a, Node b) a = eq.first; b = eq.second; - Node c = nearestCommonAncestor(a, b); + Node c = nearestCommonAncestor(a, b, unionFind); explainAlongPath(a, c, pending, unionFind, terms); explainAlongPath(b, c, pending, unionFind, terms); } while(!pending.empty()); - if(Debug.isOn("cc")) { - Debug("cc") << "CC EXPLAIN final proof has size " << terms.size() << std::endl; + if(Trace.isOn("cc")) { + Trace("cc") << "CC EXPLAIN final proof has size " << terms.size() << std::endl; } NodeBuilder<> pf(kind::AND); while(!terms.empty()) { Node p = terms.front(); terms.pop_front(); - Debug("cc") << "CC EXPLAIN " << p << std::endl; + Trace("cc") << "CC EXPLAIN " << p << std::endl; p = proofRewrite(p); - Debug("cc") << " rewrite " << p << std::endl; + Trace("cc") << " rewrite " << p << std::endl; if(!p.isNull()) { pf << p; } } - Debug("cc") << "CC EXPLAIN done" << std::endl; + Trace("cc") << "CC EXPLAIN done" << std::endl; Assert(pf.getNumChildren() > 0); @@ -908,10 +934,10 @@ std::ostream& operator<<(std::ostream& out, const CongruenceClosure<OutputChannel>& cc) { out << "==============================================" << std::endl; - out << "Representatives:" << std::endl; + /*out << "Representatives:" << std::endl; for(typename CongruenceClosure<OutputChannel>::RepresentativeMap::const_iterator i = cc.d_representative.begin(); i != cc.d_representative.end(); ++i) { out << " " << (*i).first << " => " << (*i).second << std::endl; - } + }*/ out << "ClassLists:" << std::endl; for(typename CongruenceClosure<OutputChannel>::ClassLists::const_iterator i = cc.d_classList.begin(); i != cc.d_classList.end(); ++i) { |