diff options
author | Dejan Jovanović <dejan.jovanovic@gmail.com> | 2012-05-09 21:25:17 +0000 |
---|---|---|
committer | Dejan Jovanović <dejan.jovanovic@gmail.com> | 2012-05-09 21:25:17 +0000 |
commit | 1ce0650dcf8ce30424b546deb540974cc510c215 (patch) | |
tree | 74a9985463234fc9adfed2de209c134ed7da359b /src/theory/uf | |
parent | 690fb2843d9845e405fee54eb2d8023eebbd5b72 (diff) |
* simplifying equality engine interface
* notifications are now through the interface subclass instead of a template
* notifications include constants being merged
* changed contextNotifyObj::notify to contextNotifyObj::contextNotifyPop so it's more descriptive and doesn't clutter methods when subclassed
* sat solver now has explicit methods to make true and false constants
* 0-level literals are removed from explanations of propagations
Diffstat (limited to 'src/theory/uf')
-rw-r--r-- | src/theory/uf/Makefile.am | 2 | ||||
-rw-r--r-- | src/theory/uf/equality_engine.cpp (renamed from src/theory/uf/equality_engine_impl.h) | 376 | ||||
-rw-r--r-- | src/theory/uf/equality_engine.h | 261 | ||||
-rw-r--r-- | src/theory/uf/theory_uf.cpp | 115 | ||||
-rw-r--r-- | src/theory/uf/theory_uf.h | 56 |
5 files changed, 456 insertions, 354 deletions
diff --git a/src/theory/uf/Makefile.am b/src/theory/uf/Makefile.am index f25e50ec9..9d95eaa22 100644 --- a/src/theory/uf/Makefile.am +++ b/src/theory/uf/Makefile.am @@ -11,7 +11,7 @@ libuf_la_SOURCES = \ theory_uf_type_rules.h \ theory_uf_rewriter.h \ equality_engine.h \ - equality_engine_impl.h \ + equality_engine.cpp \ symmetry_breaker.h \ symmetry_breaker.cpp diff --git a/src/theory/uf/equality_engine_impl.h b/src/theory/uf/equality_engine.cpp index be12e5f19..b78015c00 100644 --- a/src/theory/uf/equality_engine_impl.h +++ b/src/theory/uf/equality_engine.cpp @@ -17,15 +17,27 @@ ** \todo document this file **/ -#include "cvc4_private.h" - -#pragma once - #include "theory/uf/equality_engine.h" namespace CVC4 { namespace theory { -namespace uf { +namespace eq { + +/** + * Data used in the BFS search through the equality graph. + */ +struct BfsData { + // The current node + EqualityNodeId nodeId; + // The index of the edge we traversed + EqualityEdgeId edgeId; + // Index in the queue of the previous node. Shouldn't be too much of them, at most the size + // of the biggest equivalence class + size_t previousIndex; + + BfsData(EqualityNodeId nodeId = null_id, EqualityEdgeId edgeId = null_edge, size_t prev = 0) + : nodeId(nodeId), edgeId(edgeId), previousIndex(prev) {} +}; class ScopedBool { bool& watch; @@ -40,20 +52,63 @@ public: } }; -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::enqueue(const MergeCandidate& candidate) { +EqualityEngineNotifyNone EqualityEngine::s_notifyNone; + +void EqualityEngine::init() { + Debug("equality") << "EqualityEdge::EqualityEngine(): id_null = " << +null_id << std::endl; + Debug("equality") << "EqualityEdge::EqualityEngine(): edge_null = " << +null_edge << std::endl; + Debug("equality") << "EqualityEdge::EqualityEngine(): trigger_null = " << +null_trigger << std::endl; + d_true = NodeManager::currentNM()->mkConst<bool>(true); + d_false = NodeManager::currentNM()->mkConst<bool>(false); + addTerm(d_true); + addTerm(d_false); +} + + +EqualityEngine::EqualityEngine(context::Context* context, std::string name) +: ContextNotifyObj(context) +, d_context(context) +, d_performNotify(true) +, d_notify(s_notifyNone) +, d_applicationLookupsCount(context, 0) +, d_nodesCount(context, 0) +, d_assertedEqualitiesCount(context, 0) +, d_equalityTriggersCount(context, 0) +, d_individualTriggersSize(context, 0) +, d_constantRepresentativesSize(context, 0) +, d_stats(name) +{ + init(); +} + +EqualityEngine::EqualityEngine(EqualityEngineNotify& notify, context::Context* context, std::string name) +: ContextNotifyObj(context) +, d_context(context) +, d_performNotify(true) +, d_notify(notify) +, d_applicationLookupsCount(context, 0) +, d_nodesCount(context, 0) +, d_assertedEqualitiesCount(context, 0) +, d_equalityTriggersCount(context, 0) +, d_individualTriggersSize(context, 0) +, d_constantRepresentativesSize(context, 0) +, d_stats(name) +{ + init(); +} + +void EqualityEngine::enqueue(const MergeCandidate& candidate) { Debug("equality") << "EqualityEngine::enqueue(" << candidate.toString(*this) << ")" << std::endl; d_propagationQueue.push(candidate); } -template <typename NotifyClass> -EqualityNodeId EqualityEngine<NotifyClass>::newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2) { +EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2) { Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ")" << std::endl; ++ d_stats.functionTermsCount; // Get another id for this - EqualityNodeId funId = newNode(original, true); + EqualityNodeId funId = newNode(original); FunctionApplication funOriginal(t1, t2); // The function application we're creating EqualityNodeId t1ClassId = getEqualityNode(t1).getFind(); @@ -87,10 +142,9 @@ EqualityNodeId EqualityEngine<NotifyClass>::newApplicationNode(TNode original, E return funId; } -template <typename NotifyClass> -EqualityNodeId EqualityEngine<NotifyClass>::newNode(TNode node, bool isApplication) { +EqualityNodeId EqualityEngine::newNode(TNode node) { - Debug("equality") << "EqualityEngine::newNode(" << node << ", " << (isApplication ? "function" : "regular") << ")" << std::endl; + Debug("equality") << "EqualityEngine::newNode(" << node << ")" << std::endl; ++ d_stats.termsCount; @@ -107,20 +161,20 @@ EqualityNodeId EqualityEngine<NotifyClass>::newNode(TNode node, bool isApplicati d_equalityGraph.push_back(+null_edge); // Mark the no-individual trigger d_nodeIndividualTrigger.push_back(+null_id); + // Mark non-constant by default + d_constantRepresentative.push_back(node.isConst() ? newId : +null_id); // Add the equality node to the nodes d_equalityNodes.push_back(EqualityNode(newId)); // Increase the counters d_nodesCount = d_nodesCount + 1; - Debug("equality") << "EqualityEngine::newNode(" << node << ", " << (isApplication ? "function" : "regular") << ") => " << newId << std::endl; + Debug("equality") << "EqualityEngine::newNode(" << node << ") => " << newId << std::endl; return newId; } - -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::addTerm(TNode t) { +void EqualityEngine::addTerm(TNode t) { Debug("equality") << "EqualityEngine::addTerm(" << t << ")" << std::endl; @@ -148,47 +202,40 @@ void EqualityEngine<NotifyClass>::addTerm(TNode t) { } } else { // Otherwise we just create the new id - result = newNode(t, false); + result = newNode(t); } Debug("equality") << "EqualityEngine::addTerm(" << t << ") => " << result << std::endl; } -template <typename NotifyClass> -bool EqualityEngine<NotifyClass>::hasTerm(TNode t) const { +bool EqualityEngine::hasTerm(TNode t) const { return d_nodeIds.find(t) != d_nodeIds.end(); } -template <typename NotifyClass> -EqualityNodeId EqualityEngine<NotifyClass>::getNodeId(TNode node) const { +EqualityNodeId EqualityEngine::getNodeId(TNode node) const { Assert(hasTerm(node), node.toString().c_str()); return (*d_nodeIds.find(node)).second; } -template <typename NotifyClass> -EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(TNode t) { +EqualityNode& EqualityEngine::getEqualityNode(TNode t) { return getEqualityNode(getNodeId(t)); } -template <typename NotifyClass> -EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(EqualityNodeId nodeId) { +EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) { Assert(nodeId < d_equalityNodes.size()); return d_equalityNodes[nodeId]; } -template <typename NotifyClass> -const EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(TNode t) const { +const EqualityNode& EqualityEngine::getEqualityNode(TNode t) const { return getEqualityNode(getNodeId(t)); } -template <typename NotifyClass> -const EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(EqualityNodeId nodeId) const { +const EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) const { Assert(nodeId < d_equalityNodes.size()); return d_equalityNodes[nodeId]; } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::addEqualityInternal(TNode t1, TNode t2, TNode reason) { +void EqualityEngine::assertEqualityInternal(TNode t1, TNode t2, TNode reason) { Debug("equality") << "EqualityEngine::addEqualityInternal(" << t1 << "," << t2 << ")" << std::endl; @@ -204,55 +251,35 @@ void EqualityEngine<NotifyClass>::addEqualityInternal(TNode t1, TNode t2, TNode propagate(); } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::addPredicate(TNode t, bool polarity, TNode reason) { - +void EqualityEngine::assertPredicate(TNode t, bool polarity, TNode reason) { Debug("equality") << "EqualityEngine::addPredicate(" << t << "," << (polarity ? "true" : "false") << ")" << std::endl; - - addEqualityInternal(t, polarity ? d_true : d_false, reason); -} - -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::addEquality(TNode t1, TNode t2, TNode reason) { - - Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << ")" << std::endl; - - addEqualityInternal(t1, t2, reason); - - Node equality = t1.eqNode(t2); - addEqualityInternal(equality, d_true, reason); + Assert(t.getKind() != kind::EQUAL, "Use assertEquality instead"); + assertEqualityInternal(t, polarity ? d_true : d_false, reason); } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::addDisequality(TNode t1, TNode t2, TNode reason) { - - Debug("equality") << "EqualityEngine::addDisequality(" << t1 << "," << t2 << ")" << std::endl; - - Node equality1 = t1.eqNode(t2); - addEqualityInternal(equality1, d_false, reason); - - Node equality2 = t2.eqNode(t1); - addEqualityInternal(equality2, d_false, reason); +void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason) { + Debug("equality") << "EqualityEngine::addEquality(" << eq << "," << (polarity ? "true" : "false") << std::endl; + if (polarity) { + // Add equality between terms + assertEqualityInternal(eq[0], eq[1], reason); + // Add eq = true for dis-equality propagation + assertEqualityInternal(eq, d_true, reason); + } else { + assertEqualityInternal(eq, d_false, reason); + Node eqSymm = eq[1].eqNode(eq[0]); + assertEqualityInternal(eqSymm, d_false, reason); + } } - -template <typename NotifyClass> -TNode EqualityEngine<NotifyClass>::getRepresentative(TNode t) const { - +TNode EqualityEngine::getRepresentative(TNode t) const { Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ")" << std::endl; - Assert(hasTerm(t)); - - // Both following commands are semantically const EqualityNodeId representativeId = getEqualityNode(t).getFind(); - Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ") => " << d_nodes[representativeId] << std::endl; - return d_nodes[representativeId]; } -template <typename NotifyClass> -bool EqualityEngine<NotifyClass>::areEqual(TNode t1, TNode t2) const { +bool EqualityEngine::areEqual(TNode t1, TNode t2) const { Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl; Assert(hasTerm(t1)); @@ -267,8 +294,7 @@ bool EqualityEngine<NotifyClass>::areEqual(TNode t1, TNode t2) const { return rep1 == rep2; } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::merge(EqualityNode& class1, EqualityNode& class2, std::vector<TriggerId>& triggers) { +bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector<TriggerId>& triggers) { Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl; @@ -357,10 +383,30 @@ void EqualityEngine<NotifyClass>::merge(EqualityNode& class1, EqualityNode& clas // Now merge the lists class1.merge<true>(class2); - // Notfiy the triggers - EqualityNodeId class1triggerId = d_nodeIndividualTrigger[class1Id]; + // Check for constants + EqualityNodeId class2constId = d_constantRepresentative[class2Id]; + if (class2constId != +null_id) { + EqualityNodeId class1constId = d_constantRepresentative[class1Id]; + if (class1constId != +null_id) { + if (d_performNotify) { + TNode const1 = d_nodes[class1constId]; + TNode const2 = d_nodes[class2constId]; + if (!d_notify.eqNotifyConstantTermMerge(const1, const2)) { + return false; + } + } + } else { + // If the class we're merging in is constant, mark the representative as constant + d_constantRepresentative[class1Id] = d_constantRepresentative[class2Id]; + d_constantRepresentatives.push_back(class1Id); + d_constantRepresentativesSize = d_constantRepresentativesSize + 1; + } + } + + // Notify the trigger term merges EqualityNodeId class2triggerId = d_nodeIndividualTrigger[class2Id]; if (class2triggerId != +null_id) { + EqualityNodeId class1triggerId = d_nodeIndividualTrigger[class1Id]; if (class1triggerId == +null_id) { // If class1 is not an individual trigger, but class2 is, mark it d_nodeIndividualTrigger[class1Id] = class2triggerId; @@ -370,14 +416,18 @@ void EqualityEngine<NotifyClass>::merge(EqualityNode& class1, EqualityNode& clas } else { // Notify when done if (d_performNotify) { - d_notify.notify(d_nodes[class1triggerId], d_nodes[class2triggerId]); + if (!d_notify.eqNotifyTriggerTermEquality(d_nodes[class1triggerId], d_nodes[class2triggerId], true)) { + return false; + } } } } + + // Everything fine + return true; } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id) { +void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id) { Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << ")" << std::endl; @@ -409,8 +459,7 @@ void EqualityEngine<NotifyClass>::undoMerge(EqualityNode& class1, EqualityNode& } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::backtrack() { +void EqualityEngine::backtrack() { Debug("equality::backtrack") << "backtracking" << std::endl; @@ -453,6 +502,14 @@ void EqualityEngine<NotifyClass>::backtrack() { d_individualTriggers.resize(d_individualTriggersSize); } + if (d_constantRepresentatives.size() > d_constantRepresentativesSize) { + // Unset the constant representatives + for (int i = d_constantRepresentatives.size() - 1, i_end = d_constantRepresentativesSize; i >= i_end; -- i) { + d_constantRepresentative[d_constantRepresentatives[i]] = +null_id; + } + d_constantRepresentatives.resize(d_constantRepresentativesSize); + } + if (d_equalityTriggers.size() > d_equalityTriggersCount) { // Unlink the triggers from the lists for (int i = d_equalityTriggers.size() - 1, i_end = d_equalityTriggersCount; i >= i_end; -- i) { @@ -492,13 +549,13 @@ void EqualityEngine<NotifyClass>::backtrack() { d_applications.resize(d_nodesCount); d_nodeTriggers.resize(d_nodesCount); d_nodeIndividualTrigger.resize(d_nodesCount); + d_constantRepresentative.resize(d_nodesCount); d_equalityGraph.resize(d_nodesCount); d_equalityNodes.resize(d_nodesCount); } } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::addGraphEdge(EqualityNodeId t1, EqualityNodeId t2, MergeReasonType type, TNode reason) { +void EqualityEngine::addGraphEdge(EqualityNodeId t1, EqualityNodeId t2, MergeReasonType type, TNode reason) { Debug("equality") << "EqualityEngine::addGraphEdge(" << d_nodes[t1] << "," << d_nodes[t2] << "," << reason << ")" << std::endl; EqualityEdgeId edge = d_equalityEdges.size(); d_equalityEdges.push_back(EqualityEdge(t2, d_equalityGraph[t1], type, reason)); @@ -511,8 +568,7 @@ void EqualityEngine<NotifyClass>::addGraphEdge(EqualityNodeId t1, EqualityNodeId } } -template <typename NotifyClass> -std::string EqualityEngine<NotifyClass>::edgesToString(EqualityEdgeId edgeId) const { +std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const { std::stringstream out; bool first = true; if (edgeId == null_edge) { @@ -529,70 +585,52 @@ std::string EqualityEngine<NotifyClass>::edgesToString(EqualityEdgeId edgeId) co return out.str(); } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::explainEquality(TNode t1, TNode t2, std::vector<TNode>& equalities) { +void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, std::vector<TNode>& equalities) { Debug("equality") << "EqualityEngine::explainEquality(" << t1 << "," << t2 << ")" << std::endl; // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); + ScopedBool turnOffNotify(d_performNotify, false); // Add the terms (they might not be there) addTerm(t1); addTerm(t2); - Assert(getRepresentative(t1) == getRepresentative(t2), - "Cannot explain an equality, because the two terms are not equal!\n" - "The representative of %s\n" - " is %s\n" - "The representative of %s\n" - " is %s", - t1.toString().c_str(), getRepresentative(t1).toString().c_str(), - t2.toString().c_str(), getRepresentative(t2).toString().c_str()); - - // Get the explanation - EqualityNodeId t1Id = getNodeId(t1); - EqualityNodeId t2Id = getNodeId(t2); - getExplanation(t1Id, t2Id, equalities); - + if (polarity) { + // Get the explanation + EqualityNodeId t1Id = getNodeId(t1); + EqualityNodeId t2Id = getNodeId(t2); + getExplanation(t1Id, t2Id, equalities); + } else { + // Add the equality + Node equality = t1.eqNode(t2); + addTerm(equality); + + // Get the explanation + EqualityNodeId equalityId = getNodeId(equality); + EqualityNodeId falseId = getNodeId(d_false); + getExplanation(equalityId, falseId, equalities); + } } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::explainDisequality(TNode t1, TNode t2, std::vector<TNode>& equalities) { - Debug("equality") << "EqualityEngine::explainDisequality(" << t1 << "," << t2 << ")" << std::endl; +void EqualityEngine::explainPredicate(TNode p, bool polarity, std::vector<TNode>& assertions) { + Debug("equality") << "EqualityEngine::explainEquality(" << p << ")" << std::endl; // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); + ScopedBool turnOffNotify(d_performNotify, false); // Add the terms - addTerm(t1); - addTerm(t2); - - // Add the equality - Node equality = t1.eqNode(t2); - addTerm(equality); - - Assert(getRepresentative(equality) == getRepresentative(d_false), - "Cannot explain the dis-equality, because the two terms are not dis-equal!\n" - "The representative of %s\n" - " is %s\n" - "The representative of %s\n" - " is %s", - equality.toString().c_str(), getRepresentative(equality).toString().c_str(), - d_false.toString().c_str(), getRepresentative(d_false).toString().c_str()); - - // Get the explanation - EqualityNodeId equalityId = getNodeId(equality); - EqualityNodeId falseId = getNodeId(d_false); - getExplanation(equalityId, falseId, equalities); + addTerm(p); + // Get the explanation + getExplanation(getNodeId(p), getNodeId(polarity ? d_true : d_false), assertions); } - -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, std::vector<TNode>& equalities) const { +void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, std::vector<TNode>& equalities) const { Debug("equality") << "EqualityEngine::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl; + Assert(getEqualityNode(t1Id).getFind() == getEqualityNode(t2Id).getFind()); + // If the nodes are the same, we're done if (t1Id == t2Id) return; @@ -682,15 +720,28 @@ void EqualityEngine<NotifyClass>::getExplanation(EqualityNodeId t1Id, EqualityNo } } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::addTriggerDisequality(TNode t1, TNode t2, TNode trigger) { - Node equality = t1.eqNode(t2); - addTerm(equality); - addTriggerEquality(equality, d_false, trigger); +void EqualityEngine::addTriggerEquality(TNode eq) { + Assert(eq.getKind() == kind::EQUAL); + // Add the terms + addTerm(eq); + // Positive trigger + addTriggerEqualityInternal(eq[0], eq[1], eq, true); + // Negative trigger + addTriggerEqualityInternal(eq, d_false, eq, false); +} + +void EqualityEngine::addTriggerPredicate(TNode predicate) { + Assert(predicate.getKind() != kind::NOT && predicate.getKind() != kind::EQUAL); + Assert(d_congruenceKinds.tst(predicate.getKind()), "No point in adding non-congruence predicates"); + // Add the term + addTerm(predicate); + // Positive trigger + addTriggerEqualityInternal(predicate, d_true, predicate, true); + // Negative trigger + addTriggerEqualityInternal(predicate, d_false, predicate, false); } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::addTriggerEquality(TNode t1, TNode t2, TNode trigger) { +void EqualityEngine::addTriggerEqualityInternal(TNode t1, TNode t2, TNode trigger, bool polarity) { Debug("equality") << "EqualityEngine::addTrigger(" << t1 << ", " << t2 << ", " << trigger << ")" << std::endl; @@ -713,9 +764,9 @@ void EqualityEngine<NotifyClass>::addTriggerEquality(TNode t1, TNode t2, TNode t TriggerId t1NewTriggerId = d_equalityTriggers.size(); TriggerId t2NewTriggerId = t1NewTriggerId | 1; d_equalityTriggers.push_back(Trigger(t1classId, t1TriggerId)); - d_equalityTriggersOriginal.push_back(trigger); + d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity)); d_equalityTriggers.push_back(Trigger(t2classId, t2TriggerId)); - d_equalityTriggersOriginal.push_back(trigger); + d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity)); // Update the counters d_equalityTriggersCount = d_equalityTriggersCount + 2; @@ -728,7 +779,7 @@ void EqualityEngine<NotifyClass>::addTriggerEquality(TNode t1, TNode t2, TNode t if (t1classId == t2classId) { Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << "): triggered at setup time" << std::endl; if (d_performNotify) { - d_notify.notify(trigger); // Don't care about the return value + d_notify.eqNotifyTriggerEquality(trigger, polarity); // Don't care about the return value } } @@ -739,8 +790,7 @@ void EqualityEngine<NotifyClass>::addTriggerEquality(TNode t1, TNode t2, TNode t Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ") => (" << t1NewTriggerId << ", " << t2NewTriggerId << ")" << std::endl; } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::propagate() { +void EqualityEngine::propagate() { Debug("equality") << "EqualityEngine::propagate()" << std::endl; @@ -783,25 +833,29 @@ void EqualityEngine<NotifyClass>::propagate() { if (node2.getSize() > node1.getSize()) { Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t1Id]<< " into " << d_nodes[current.t2Id] << std::endl; d_assertedEqualities.push_back(Equality(t2classId, t1classId)); - merge(node2, node1, triggers); + done = !merge(node2, node1, triggers); } else { Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t2Id] << " into " << d_nodes[current.t1Id] << std::endl; d_assertedEqualities.push_back(Equality(t1classId, t2classId)); - merge(node1, node2, triggers); + done = !merge(node1, node2, triggers); } // Notify the triggers - if (d_performNotify) { - for (size_t trigger = 0, trigger_end = triggers.size(); trigger < trigger_end && !done; ++ trigger) { + if (d_performNotify && !done) { + for (size_t trigger_i = 0, trigger_end = triggers.size(); trigger_i < trigger_end && !done; ++ trigger_i) { + const TriggerInfo& triggerInfo = d_equalityTriggersOriginal[triggers[trigger_i]]; // Notify the trigger and exit if it fails - done = !d_notify.notify(d_equalityTriggersOriginal[triggers[trigger]]); + if (triggerInfo.trigger.getKind() == kind::EQUAL) { + done = !d_notify.eqNotifyTriggerEquality(triggerInfo.trigger, triggerInfo.polarity); + } else { + done = !d_notify.eqNotifyTriggerPredicate(triggerInfo.trigger, triggerInfo.polarity); + } } } } } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::debugPrintGraph() const { +void EqualityEngine::debugPrintGraph() const { for (EqualityNodeId nodeId = 0; nodeId < d_nodes.size(); ++ nodeId) { Debug("equality::graph") << d_nodes[nodeId] << " " << nodeId << "(" << getEqualityNode(nodeId).getFind() << "):"; @@ -817,11 +871,10 @@ void EqualityEngine<NotifyClass>::debugPrintGraph() const { } } -template <typename NotifyClass> -bool EqualityEngine<NotifyClass>::areEqual(TNode t1, TNode t2) +bool EqualityEngine::areEqual(TNode t1, TNode t2) { // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); + ScopedBool turnOffNotify(d_performNotify, false); // Add the terms addTerm(t1); @@ -832,17 +885,18 @@ bool EqualityEngine<NotifyClass>::areEqual(TNode t1, TNode t2) return equal; } -template <typename NotifyClass> -bool EqualityEngine<NotifyClass>::areDisequal(TNode t1, TNode t2) +bool EqualityEngine::areDisequal(TNode t1, TNode t2) { // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); + ScopedBool turnOffNotify(d_performNotify, false); // Add the terms addTerm(t1); addTerm(t2); // Check (t1 = t2) = false + // No need to check the symmetric version: we can only deduce a disequality from an existing + // diseqality, and each of those is asserted in the symmetric version also Node equality = t1.eqNode(t2); addTerm(equality); if (getEqualityNode(equality).getFind() == getEqualityNode(d_false).getFind()) { @@ -853,16 +907,14 @@ bool EqualityEngine<NotifyClass>::areDisequal(TNode t1, TNode t2) return false; } -template <typename NotifyClass> -size_t EqualityEngine<NotifyClass>::getSize(TNode t) +size_t EqualityEngine::getSize(TNode t) { // Add the term addTerm(t); return getEqualityNode(getEqualityNode(t).getFind()).getSize(); } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::addTriggerTerm(TNode t) +void EqualityEngine::addTriggerTerm(TNode t) { Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ")" << std::endl; @@ -877,7 +929,7 @@ void EqualityEngine<NotifyClass>::addTriggerTerm(TNode t) if (d_nodeIndividualTrigger[classId] != +null_id) { // No need to keep it, just propagate the existing individual triggers if (d_performNotify) { - d_notify.notify(t, d_nodes[d_nodeIndividualTrigger[classId]]); + d_notify.eqNotifyTriggerTermEquality(t, d_nodes[d_nodeIndividualTrigger[classId]], true); } } else { // Add it to the list for backtracking @@ -888,23 +940,20 @@ void EqualityEngine<NotifyClass>::addTriggerTerm(TNode t) } } -template <typename NotifyClass> -bool EqualityEngine<NotifyClass>::isTriggerTerm(TNode t) const { +bool EqualityEngine::isTriggerTerm(TNode t) const { if (!hasTerm(t)) return false; EqualityNodeId classId = getEqualityNode(t).getFind(); return d_nodeIndividualTrigger[classId] != +null_id; } -template <typename NotifyClass> -TNode EqualityEngine<NotifyClass>::getTriggerTermRepresentative(TNode t) const { +TNode EqualityEngine::getTriggerTermRepresentative(TNode t) const { Assert(isTriggerTerm(t)); EqualityNodeId classId = getEqualityNode(t).getFind(); return d_nodes[d_nodeIndividualTrigger[classId]]; } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::storeApplicationLookup(FunctionApplication& funNormalized, EqualityNodeId funId) { +void EqualityEngine::storeApplicationLookup(FunctionApplication& funNormalized, EqualityNodeId funId) { Assert(d_applicationLookup.find(funNormalized) == d_applicationLookup.end()); d_applicationLookup[funNormalized] = funId; d_applicationLookups.push_back(funNormalized); @@ -914,8 +963,7 @@ void EqualityEngine<NotifyClass>::storeApplicationLookup(FunctionApplication& fu Assert(d_applicationLookupsCount == d_applicationLookups.size()); } -template <typename NotifyClass> -void EqualityEngine<NotifyClass>::getUseListTerms(TNode t, std::set<TNode>& output) { +void EqualityEngine::getUseListTerms(TNode t, std::set<TNode>& output) { if (hasTerm(t)) { // Get the equivalence class EqualityNodeId classId = getEqualityNode(t).getFind(); diff --git a/src/theory/uf/equality_engine.h b/src/theory/uf/equality_engine.h index dccd5ba56..f9c10d1b6 100644 --- a/src/theory/uf/equality_engine.h +++ b/src/theory/uf/equality_engine.h @@ -35,7 +35,7 @@ namespace CVC4 { namespace theory { -namespace uf { +namespace eq { /** Id of the node */ typedef size_t EqualityNodeId; @@ -213,9 +213,74 @@ public: } }; -template <typename NotifyClass> +/** + * Interface for equality engine notifications. All the notifications + * are safe as TNodes, but not necessarily for negations. + */ +class EqualityEngineNotify { + + friend class EqualityEngine; + +public: + + virtual ~EqualityEngineNotify() {}; + + /** + * Notifies about a trigger equality that became true or false. + * + * @param eq the equality that became true or false + * @param value the value of the equality + */ + virtual bool eqNotifyTriggerEquality(TNode equality, bool value) = 0; + + /** + * Notifies about a trigger predicate that became true or false. + * + * @param predicate the trigger predicate that bacame true or false + * @param value the value of the predicate + */ + virtual bool eqNotifyTriggerPredicate(TNode predicate, bool value) = 0; + + /** + * Notifies about the merge of two trigger terms. + * + * @param t1 a term marked as trigger + * @param t2 a term marked as trigger + * @param value true if equal, false if dis-equal + */ + virtual bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) = 0; + + /** + * Notifies about the merge of two constant terms. + * + * @param t1 a constant term + * @param t2 a constnat term + */ + virtual bool eqNotifyConstantTermMerge(TNode t1, TNode t2) = 0; +}; + +/** + * Implementation of the notification interface that ignores all the + * notifications. + */ +class EqualityEngineNotifyNone : public EqualityEngineNotify { +public: + bool eqNotifyTriggerEquality(TNode equality, bool value) { return true; } + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { return true; } + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { return true; } + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { return true; } +}; + + +/** + * Class for keeping an incremental congurence closure over a set of terms. It provides + * notifications via an EqualityEngineNotify object. + */ class EqualityEngine : public context::ContextNotifyObj { + /** Default implementation of the notification object */ + static EqualityEngineNotifyNone s_notifyNone; + public: /** Statistics about the equality engine instance */ @@ -226,21 +291,26 @@ public: IntStat termsCount; /** Number of function terms managed by the system */ IntStat functionTermsCount; + /** Number of constant terms managed by the system */ + IntStat constantTermsCount; Statistics(std::string name) : mergesCount(name + "::mergesCount", 0), termsCount(name + "::termsCount", 0), - functionTermsCount(name + "::functionTermsCount", 0) + functionTermsCount(name + "::functionTermsCount", 0), + constantTermsCount(name + "::constantTermsCount", 0) { StatisticsRegistry::registerStat(&mergesCount); StatisticsRegistry::registerStat(&termsCount); StatisticsRegistry::registerStat(&functionTermsCount); + StatisticsRegistry::registerStat(&constantTermsCount); } ~Statistics() { StatisticsRegistry::unregisterStat(&mergesCount); StatisticsRegistry::unregisterStat(&termsCount); StatisticsRegistry::unregisterStat(&functionTermsCount); + StatisticsRegistry::unregisterStat(&constantTermsCount); } }; @@ -282,7 +352,7 @@ private: bool d_performNotify; /** The class to notify when a representative changes for a term */ - NotifyClass d_notify; + EqualityEngineNotify& d_notify; /** The map of kinds to be treated as function applications */ KindMap d_congruenceKinds; @@ -428,8 +498,11 @@ private: /** Returns the id of the node */ EqualityNodeId getNodeId(TNode node) const; - /** Merge the class2 into class1 */ - void merge(EqualityNode& class1, EqualityNode& class2, std::vector<TriggerId>& triggers); + /** + * Merge the class2 into class1 + * @return true if ok, false if to break out + */ + bool merge(EqualityNode& class1, EqualityNode& class2, std::vector<TriggerId>& triggers); /** Undo the mereg of class2 into class1 */ void undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id); @@ -438,28 +511,12 @@ private: void backtrack(); /** - * Data used in the BFS search through the equality graph. - */ - struct BfsData { - // The current node - EqualityNodeId nodeId; - // The index of the edge we traversed - EqualityEdgeId edgeId; - // Index in the queue of the previous node. Shouldn't be too much of them, at most the size - // of the biggest equivalence class - size_t previousIndex; - - BfsData(EqualityNodeId nodeId = null_id, EqualityEdgeId edgeId = null_edge, size_t prev = 0) - : nodeId(nodeId), edgeId(edgeId), previousIndex(prev) {} - }; - - /** * Trigger that will be updated */ struct Trigger { /** The current class id of the LHS of the trigger */ EqualityNodeId classId; - /** Next trigger for class 1 */ + /** Next trigger for class */ TriggerId nextTrigger; Trigger(EqualityNodeId classId = null_id, TriggerId nextTrigger = null_trigger) @@ -473,10 +530,20 @@ private: */ std::vector<Trigger> d_equalityTriggers; + struct TriggerInfo { + /** The trigger itself */ + Node trigger; + /** Polarity of the trigger */ + bool polarity; + TriggerInfo() {} + TriggerInfo(Node trigger, bool polarity) + : trigger(trigger), polarity(polarity) {} + }; + /** * Vector of original equalities of the triggers. */ - std::vector<Node> d_equalityTriggersOriginal; + std::vector<TriggerInfo> d_equalityTriggersOriginal; /** * Context dependent count of triggers @@ -505,6 +572,19 @@ private: std::vector<EqualityNodeId> d_nodeIndividualTrigger; /** + * Map from ids to the id of the constant that is the representative. + */ + std::vector<EqualityNodeId> d_constantRepresentative; + + /** + * Size of the constant representatives list. + */ + context::CDO<unsigned> d_constantRepresentativesSize; + + /** The list of representatives that became constant. */ + std::vector<EqualityNodeId> d_constantRepresentatives; + + /** * Adds the trigger with triggerId to the beginning of the trigger list of the node with id nodeId. */ void addTriggerToList(EqualityNodeId nodeId, TriggerId triggerId); @@ -516,7 +596,7 @@ private: EqualityNodeId newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2); /** Add a new node to the database */ - EqualityNodeId newNode(TNode t, bool isApplication); + EqualityNodeId newNode(TNode t); struct MergeCandidate { EqualityNodeId t1Id, t2Id; @@ -561,44 +641,41 @@ private: /** * Adds an equality of terms t1 and t2 to the database. */ - void addEqualityInternal(TNode t1, TNode t2, TNode reason); + void assertEqualityInternal(TNode t1, TNode t2, TNode reason); -public: + /** + * Adds a trigger equality to the database with the trigger node and polarity for notification. + */ + void addTriggerEqualityInternal(TNode t1, TNode t2, TNode trigger, bool polarity); /** - * Initialize the equality engine, given the owning class. This will initialize the notifier with - * the owner information. - */ - EqualityEngine(NotifyClass& notify, context::Context* context, std::string name) - : ContextNotifyObj(context), - d_context(context), - d_performNotify(true), - d_notify(notify), - d_applicationLookupsCount(context, 0), - d_nodesCount(context, 0), - d_assertedEqualitiesCount(context, 0), - d_equalityTriggersCount(context, 0), - d_individualTriggersSize(context, 0), - d_stats(name) - { - Debug("equality") << "EqualityEdge::EqualityEngine(): id_null = " << +null_id << std::endl; - Debug("equality") << "EqualityEdge::EqualityEngine(): edge_null = " << +null_edge << std::endl; - Debug("equality") << "EqualityEdge::EqualityEngine(): trigger_null = " << +null_trigger << std::endl; - d_true = NodeManager::currentNM()->mkConst<bool>(true); - d_false = NodeManager::currentNM()->mkConst<bool>(false); + * This method gets called on backtracks from the context manager. + */ + void contextNotifyPop() { + backtrack(); } /** - * Just a destructor. + * Constructor initialization stuff. */ - virtual ~EqualityEngine() throw(AssertionException) {} + void init(); + +public: /** - * This method gets called on backtracks from the context manager. + * Initialize the equality engine, given the notification class. */ - void notify() { - backtrack(); - } + EqualityEngine(EqualityEngineNotify& notify, context::Context* context, std::string name); + + /** + * Initialize the equality engine with no notification class. + */ + EqualityEngine(context::Context* context, std::string name); + + /** + * Just a destructor. + */ + virtual ~EqualityEngine() throw(AssertionException) {} /** * Adds a term to the term database. @@ -629,77 +706,91 @@ public: bool hasTerm(TNode t) const; /** - * Adds aa predicate t with given polarity + * Adds a predicate p with given polarity. The predicate asserted + * should be in the coungruence closure kinds (otherwise it's + * useless. + * + * @param p the (non-negated) predicate + * @param polarity true if asserting the predicate, false if + * asserting the negated predicate + * @param the reason to keep for building explanations */ - void addPredicate(TNode t, bool polarity, TNode reason); + void assertPredicate(TNode p, bool polarity, TNode reason); /** - * Adds an equality t1 = t2 to the database. + * Adds an equality eq with the given polarity to the database. + * + * @param eq the (non-negated) equality + * @param polarity true if asserting the equality, false if + * asserting the negated equality + * @param the reason to keep for building explanations */ - void addEquality(TNode t1, TNode t2, TNode reason); + void assertEquality(TNode eq, bool polarity, TNode reason); /** - * Adds an dis-equality t1 != t2 to the database. - */ - void addDisequality(TNode t1, TNode t2, TNode reason); - - /** - * Returns the representative of the term t. + * Returns the current representative of the term t. */ TNode getRepresentative(TNode t) const; /** - * Add all the terms where the given term appears in (directly or implicitly). + * Add all the terms where the given term appears as a first child + * (directly or implicitly). */ void getUseListTerms(TNode t, std::set<TNode>& output); /** - * Returns true if the two nodes are in the same class. + * Returns true if the two nodes are in the same equivalence class. */ bool areEqual(TNode t1, TNode t2) const; /** - * Get an explanation of the equality t1 = t2. Returns the asserted equalities that - * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere - * else. + * Get an explanation of the equality t1 = t2 begin true of false. + * Returns the reasons (added when asserting) that imply it + * in the assertions vector. */ - void explainEquality(TNode t1, TNode t2, std::vector<TNode>& equalities); + void explainEquality(TNode t1, TNode t2, bool polarity, std::vector<TNode>& assertions); /** - * Get an explanation of the equality t1 = t2. Returns the asserted equalities that - * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere - * else. + * Get an explanation of the predicate being true or false. + * Returns the reasons (added when asserting) that imply imply it + * in the assertions vector. */ - void explainDisequality(TNode t1, TNode t2, std::vector<TNode>& equalities); + void explainPredicate(TNode p, bool polarity, std::vector<TNode>& assertions); /** - * Add term to the trigger terms. The notify class will get notified when two - * trigger terms become equal. Thihs will only happen on trigger term - * representatives. + * Add term to the trigger terms. The notify class will get notified + * when two trigger terms become equal or dis-equal. The notification + * will not happen on all the terms, but only on the ones that are + * represent the class. */ void addTriggerTerm(TNode t); /** - * Returns true if t is a trigger term or equal to some other trigger term. + * Returns true if t is a trigger term or in the same equivalence + * class as some other trigger term. */ bool isTriggerTerm(TNode t) const; /** - * Returns the representative trigger term (isTriggerTerm(t)) should be true. + * Returns the representative trigger term of the given term. + * + * @param t the term to check where isTriggerTerm(t) should be true */ TNode getTriggerTermRepresentative(TNode t) const; /** - * Adds a notify trigger for equality t1 = t2, i.e. when t1 = t2 the notify will be called with - * trigger. + * Adds a notify trigger for equality. When equality becomes true eqNotifyTriggerEquality + * will be called with value = true, and when equality becomes false eqNotifyTriggerEquality + * will be called with value = false. */ - void addTriggerEquality(TNode t1, TNode t2, TNode trigger); + void addTriggerEquality(TNode equality); /** - * Adds a notify trigger for dis-equality t1 != t2, i.e. when t1 != t2 the notify will be called with - * trigger. + * Adds a notify trigger for the predicate p. When the predicate becomes true + * eqNotifyTriggerPredicate will be called with value = true, and when equality becomes false + * eqNotifyTriggerPredicate will be called with value = false. */ - void addTriggerDisequality(TNode t1, TNode t2, TNode trigger); + void addTriggerPredicate(TNode predicate); /** * Check whether the two terms are equal. @@ -712,7 +803,7 @@ public: bool areDisequal(TNode t1, TNode t2); /** - * Return the number of nodes in the equivalence class contianing t + * Return the number of nodes in the equivalence class containing t * Adds t if not already there. */ size_t getSize(TNode t); diff --git a/src/theory/uf/theory_uf.cpp b/src/theory/uf/theory_uf.cpp index ec28dad75..cddd01b07 100644 --- a/src/theory/uf/theory_uf.cpp +++ b/src/theory/uf/theory_uf.cpp @@ -18,13 +18,11 @@ **/ #include "theory/uf/theory_uf.h" -#include "theory/uf/equality_engine_impl.h" using namespace std; - -namespace CVC4 { -namespace theory { -namespace uf { +using namespace CVC4; +using namespace CVC4::theory; +using namespace CVC4::theory::uf; /** Constructs a new instance of TheoryUF w.r.t. the provided context.*/ TheoryUF::TheoryUF(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo) : @@ -40,12 +38,6 @@ TheoryUF::TheoryUF(context::Context* c, context::UserContext* u, OutputChannel& d_equalityEngine.addFunctionKind(kind::APPLY_UF); d_equalityEngine.addFunctionKind(kind::EQUAL); - // The boolean constants - d_true = NodeManager::currentNM()->mkConst<bool>(true); - d_false = NodeManager::currentNM()->mkConst<bool>(false); - d_equalityEngine.addTerm(d_true); - d_equalityEngine.addTerm(d_false); - d_equalityEngine.addTriggerEquality(d_true, d_false, d_false); }/* TheoryUF::TheoryUF() */ static Node mkAnd(const std::vector<TNode>& conjunctions) { @@ -91,23 +83,12 @@ void TheoryUF::check(Effort level) { } // Do the work - switch (fact.getKind()) { - case kind::EQUAL: - d_equalityEngine.addEquality(fact[0], fact[1], fact); - break; - case kind::APPLY_UF: - d_equalityEngine.addPredicate(fact, true, fact); - break; - case kind::NOT: - if (fact[0].getKind() == kind::APPLY_UF) { - d_equalityEngine.addPredicate(fact[0], false, fact); - } else { - // Assert the dis-equality - d_equalityEngine.addDisequality(fact[0][0], fact[0][1], fact); - } - break; - default: - Unreachable(); + bool polarity = fact.getKind() != kind::NOT; + TNode atom = polarity ? fact : fact[0]; + if (atom.getKind() == kind::EQUAL) { + d_equalityEngine.assertEquality(atom, polarity, fact); + } else { + d_equalityEngine.assertPredicate(atom, polarity, fact); } } @@ -139,10 +120,8 @@ void TheoryUF::propagate(Effort level) { Debug("uf") << "TheoryUF::propagate(): in conflict, normalized = " << normalized << std::endl; Node negatedLiteral; std::vector<TNode> assumptions; - if (normalized != d_false) { - negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); - assumptions.push_back(negatedLiteral); - } + negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); + assumptions.push_back(negatedLiteral); explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; @@ -157,21 +136,17 @@ void TheoryUF::preRegisterTerm(TNode node) { switch (node.getKind()) { case kind::EQUAL: - // Add the terms - d_equalityEngine.addTerm(node[0]); - d_equalityEngine.addTerm(node[1]); // Add the trigger for equality - d_equalityEngine.addTriggerEquality(node[0], node[1], node); - d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode()); + d_equalityEngine.addTriggerEquality(node); break; case kind::APPLY_UF: - // Function applications/predicates - d_equalityEngine.addTerm(node); // Maybe it's a predicate if (node.getType().isBoolean()) { // Get triggered for both equal and dis-equal - d_equalityEngine.addTriggerEquality(node, d_true, node); - d_equalityEngine.addTriggerEquality(node, d_false, node.notNode()); + d_equalityEngine.addTriggerPredicate(node); + } else { + // Function applications/predicates + d_equalityEngine.addTerm(node); } // Remember the function and predicate terms d_functionsTerms.push_back(node); @@ -194,26 +169,20 @@ bool TheoryUF::propagate(TNode literal) { // See if the literal has been asserted already Node normalized = Rewriter::rewrite(literal); - bool satValue = false; - bool isAsserted = normalized == d_false || d_valuation.hasSatValue(normalized, satValue); - // If asserted, we're done or in conflict - if (isAsserted) { - if (!satValue) { + // If asserted and is false, we're done or in conflict + // Note that even trivial equalities have a SAT value (i.e. 1 = 2 -> false) + bool satValue = false; + if (d_valuation.hasSatValue(normalized, satValue) && !satValue) { Debug("uf") << "TheoryUF::propagate(" << literal << ", normalized = " << normalized << ") => conflict" << std::endl; std::vector<TNode> assumptions; Node negatedLiteral; - if (normalized != d_false) { - negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); - assumptions.push_back(negatedLiteral); - } + negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); + assumptions.push_back(negatedLiteral); explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; return false; - } - // Propagate even if already known in SAT - could be a new equation between shared terms - // (terms that weren't shared when the literal was asserted!) } // Nothing, just enqueue it for propagation and mark it as asserted already @@ -224,36 +193,14 @@ bool TheoryUF::propagate(TNode literal) { }/* TheoryUF::propagate(TNode) */ void TheoryUF::explain(TNode literal, std::vector<TNode>& assumptions) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::APPLY_UF: - lhs = literal; - rhs = d_true; - break; - case kind::NOT: - if (literal[0].getKind() == kind::EQUAL) { - // Disequalities - d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions); - return; - } else { - // Predicates - lhs = literal[0]; - rhs = d_false; - break; - } - case kind::CONST_BOOLEAN: - // we get to explain true = false, since we set false to be the trigger of this - lhs = d_true; - rhs = d_false; - break; - default: - Unreachable(); + // Do the work + bool polarity = literal.getKind() != kind::NOT; + TNode atom = polarity ? literal : literal[0]; + if (atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) { + d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + } else { + d_equalityEngine.explainPredicate(atom, polarity, assumptions); } - d_equalityEngine.explainEquality(lhs, rhs, assumptions); } Node TheoryUF::explain(TNode literal) { @@ -508,7 +455,3 @@ void TheoryUF::computeCareGraph() { } } }/* TheoryUF::computeCareGraph() */ - -}/* CVC4::theory::uf namespace */ -}/* CVC4::theory namespace */ -}/* CVC4 namespace */ diff --git a/src/theory/uf/theory_uf.h b/src/theory/uf/theory_uf.h index 6956390f5..9017928b7 100644 --- a/src/theory/uf/theory_uf.h +++ b/src/theory/uf/theory_uf.h @@ -39,21 +39,46 @@ namespace uf { class TheoryUF : public Theory { public: - class NotifyClass { + class NotifyClass : public eq::EqualityEngineNotify { TheoryUF& d_uf; public: NotifyClass(TheoryUF& uf): d_uf(uf) {} - bool notify(TNode propagation) { - Debug("uf") << "NotifyClass::notify(" << propagation << ")" << std::endl; - // Just forward to uf - return d_uf.propagate(propagation); + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Debug("uf") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_uf.propagate(equality); + } else { + // We use only literal triggers so taking not is safe + return d_uf.propagate(equality.notNode()); + } } - - void notify(TNode t1, TNode t2) { - Debug("uf") << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl; - Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2)); - d_uf.propagate(equality); + + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Debug("uf") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_uf.propagate(predicate); + } else { + return d_uf.propagate(predicate.notNode()); + } + } + + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + Debug("uf") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << std::endl; + if (value) { + return d_uf.propagate(t1.eqNode(t2)); + } else { + return d_uf.propagate(t1.eqNode(t2).notNode()); + } + } + + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + Debug("uf") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl; + if (Theory::theoryOf(t1) == THEORY_BOOL) { + return d_uf.propagate(t1.iffNode(t2)); + } else { + return d_uf.propagate(t1.eqNode(t2)); + } } }; @@ -63,7 +88,7 @@ private: NotifyClass d_notify; /** Equaltity engine */ - EqualityEngine<NotifyClass> d_equalityEngine; + eq::EqualityEngine d_equalityEngine; /** Are we in conflict */ context::CDO<bool> d_conflict; @@ -72,7 +97,8 @@ private: Node d_conflictNode; /** - * Should be called to propagate the literal. + * Should be called to propagate the literal. We use a node here + * since some of the propagated literals are not kept anywhere. */ bool propagate(TNode literal); @@ -90,12 +116,6 @@ private: /** All the function terms that the theory has seen */ context::CDList<TNode> d_functionsTerms; - /** True node for predicates = true */ - Node d_true; - - /** True node for predicates = false */ - Node d_false; - /** Symmetry analyzer */ SymmetryBreaker d_symb; |