summaryrefslogtreecommitdiff
path: root/src/theory/uf/equality_engine.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/uf/equality_engine.cpp')
-rw-r--r--src/theory/uf/equality_engine.cpp170
1 files changed, 135 insertions, 35 deletions
diff --git a/src/theory/uf/equality_engine.cpp b/src/theory/uf/equality_engine.cpp
index 5093e5a7b..4cd54a2bf 100644
--- a/src/theory/uf/equality_engine.cpp
+++ b/src/theory/uf/equality_engine.cpp
@@ -62,14 +62,14 @@ void EqualityEngine::init() {
d_true = NodeManager::currentNM()->mkConst<bool>(true);
d_false = NodeManager::currentNM()->mkConst<bool>(false);
+ d_triggerDatabaseAllocatedSize = 100000;
+ d_triggerDatabase = (char*) malloc(d_triggerDatabaseAllocatedSize);
+
addTerm(d_true);
addTerm(d_false);
d_trueId = getNodeId(d_true);
d_falseId = getNodeId(d_false);
-
- d_triggerDatabaseAllocatedSize = 100000;
- d_triggerDatabase = (char*) malloc(d_triggerDatabaseAllocatedSize);
}
EqualityEngine::~EqualityEngine() throw(AssertionException) {
@@ -114,7 +114,8 @@ EqualityEngine::EqualityEngine(EqualityEngineNotify& notify, context::Context* c
}
void EqualityEngine::enqueue(const MergeCandidate& candidate) {
- d_propagationQueue.push(candidate);
+ Debug("equality") << "EqualityEngine::enqueue(" << d_nodes[candidate.t1Id] << ", " << d_nodes[candidate.t2Id] << ", " << candidate.type << ")" << std::endl;
+ d_propagationQueue.push(candidate);
}
EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2, bool isEquality) {
@@ -144,7 +145,15 @@ EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId
if (isEquality && d_isConstant[t1ClassId] && d_isConstant[t2ClassId]) {
if (t1ClassId != t2ClassId) {
Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): got constants" << std::endl;
+ Assert(d_nodes[funId].getKind() == kind::EQUAL);
enqueue(MergeCandidate(funId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null()));
+ // Also enqueue the symmetric one
+ TNode eq = d_nodes[funId];
+ Node symmetricEq = eq[1].eqNode(eq[0]);
+ if (hasTerm(symmetricEq)) {
+ EqualityNodeId symmFunId = getNodeId(symmetricEq);
+ enqueue(MergeCandidate(symmFunId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null()));
+ }
}
}
} else {
@@ -154,8 +163,8 @@ EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId
}
// Add to the use lists
- d_equalityNodes[t1ClassId].usedIn<USE_LIST_FUNCTIONS>(funId, d_useListNodes);
- d_equalityNodes[t2ClassId].usedIn<USE_LIST_FUNCTIONS>(funId, d_useListNodes);
+ d_equalityNodes[t1].usedIn(funId, d_useListNodes);
+ d_equalityNodes[t2].usedIn(funId, d_useListNodes);
// Return the new id
Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ") => " << funId << std::endl;
@@ -238,6 +247,20 @@ void EqualityEngine::addTerm(TNode t) {
// We set this here as this only applies to actual terms, not the
// intermediate application terms
d_isBoolean[result] = true;
+ } else if (t.isConst()) {
+ // Non-Boolean constants are trigger terms for all tags
+ EqualityNodeId tId = getNodeId(t);
+ d_newSetTags = 0;
+ d_newSetTriggersSize = THEORY_LAST;
+ for (TheoryId currentTheory = THEORY_FIRST; currentTheory != THEORY_LAST; ++ currentTheory) {
+ d_newSetTags = Theory::setInsert(currentTheory, d_newSetTags);
+ d_newSetTriggers[currentTheory] = tId;
+ }
+ // Add it to the list for backtracking
+ d_triggerTermSetUpdates.push_back(TriggerSetUpdate(tId, null_set_id));
+ d_triggerTermSetUpdatesSize = d_triggerTermSetUpdatesSize + 1;
+ // Mark the the new set as a trigger
+ d_nodeIndividualTrigger[tId] = newTriggerTermSet();
}
propagate();
@@ -319,15 +342,20 @@ void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason) {
assertEqualityInternal(eq, d_false, reason);
propagate();
assertEqualityInternal(eq[1].eqNode(eq[0]), d_false, reason);
- propagate();
+ propagate();
if (d_done) {
return;
}
- // If we are adding a disequality, notify of the shared term representatives
+ // If both have constant representatives, we don't notify anyone
EqualityNodeId a = getNodeId(eq[0]);
EqualityNodeId b = getNodeId(eq[1]);
+ if (isConstant(a) && isConstant(b)) {
+ return;
+ }
+
+ // If we are adding a disequality, notify of the shared term representatives
EqualityNodeId eqId = getNodeId(eq);
TriggerTermSetRef aTriggerRef = d_nodeIndividualTrigger[a];
TriggerTermSetRef bTriggerRef = d_nodeIndividualTrigger[b];
@@ -356,6 +384,7 @@ void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason) {
d_deducedDisequalityReasons.push_back(EqualityPair(bSharedId, b));
d_deducedDisequalityReasons.push_back(EqualityPair(eqId, d_falseId));
storePropagatedDisequality(d_nodes[aSharedId], d_nodes[bSharedId], 3);
+
// We notify even if the it's already been sent (they are not
// disequal at assertion, and we need to notify for each tag)
if (!d_notify.eqNotifyTriggerTermEquality(aTag, d_nodes[aSharedId], d_nodes[bSharedId], false)) {
@@ -383,7 +412,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl;
Assert(triggersFired.empty());
-
+
++ d_stats.mergesCount;
EqualityNodeId class1Id = class1.getFind();
@@ -391,6 +420,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
// Check for constant merges
bool isConstant = d_isConstant[class1Id];
+ Assert(isConstant || !d_isConstant[class2Id]);
// Update class2 representative information
Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating class " << class2Id << std::endl;
@@ -438,13 +468,13 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of node " << currentId << std::endl;
// Go through the uselist and check for congruences
- UseListNodeId currentUseId = currentNode.getUseList<USE_LIST_FUNCTIONS>();
+ UseListNodeId currentUseId = currentNode.getUseList();
while (currentUseId != null_uselist_id) {
// Get the node of the use list
UseListNode& useNode = d_useListNodes[currentUseId];
// Get the function application
EqualityNodeId funId = useNode.getApplicationId();
- Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << currentId << " in " << d_nodes[funId] << std::endl;
+ Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << d_nodes[currentId] << " in " << d_nodes[funId] << std::endl;
const FunctionApplication& fun = d_applications[useNode.getApplicationId()].normalized;
// Check if there is an application with find arguments
EqualityNodeId aNormalized = getEqualityNode(fun.a).getFind();
@@ -460,11 +490,16 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
// There is no representative, so we can add one, we remove this when backtracking
storeApplicationLookup(funNormalized, funId);
// Now, if we're constant and it's an equality, check if the other guy is also a constant
- if (isConstant && funNormalized.isEquality) {
- if (d_isConstant[funNormalized.a] && d_isConstant[funNormalized.b]) {
- // both constants
- if (funNormalized.a != funNormalized.b) {
- enqueue(MergeCandidate(funId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null()));
+ if (funNormalized.isEquality) {
+ if (d_isConstant[aNormalized] && d_isConstant[bNormalized] && aNormalized != bNormalized) {
+ Assert(d_nodes[funId].getKind() == kind::EQUAL);
+ enqueue(MergeCandidate(funId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null()));
+ // Also enqueue the symmetric one
+ TNode eq = d_nodes[funId];
+ Node symmetricEq = eq[1].eqNode(eq[0]);
+ if (hasTerm(symmetricEq)) {
+ EqualityNodeId symmFunId = getNodeId(symmetricEq);
+ enqueue(MergeCandidate(symmFunId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null()));
}
}
}
@@ -482,6 +517,17 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
// Now merge the lists
class1.merge<true>(class2);
+ // Notify of the constants merge
+ bool constantMerge = false;
+ if (isConstant && d_isConstant[class2Id]) {
+ constantMerge = true;
+ if (d_performNotify) {
+ if (!d_notify.eqNotifyConstantTermMerge(d_nodes[class1Id], d_nodes[class2Id])) {
+ return false;
+ }
+ }
+ }
+
// Go through the triggers and store the dis-equalities
unsigned i = 0, j = 0;
for (; i < triggersFired.size();) {
@@ -543,7 +589,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
// copy tag1
EqualityNodeId tag1id = d_newSetTriggers[d_newSetTriggersSize++] = class1triggers.triggers[i1++];
// since they are both tagged notify of merge
- if (d_performNotify) {
+ if (d_performNotify && !constantMerge) {
EqualityNodeId tag2id = class2triggers.triggers[i2++];
if (!d_notify.eqNotifyTriggerTermEquality(tag1, d_nodes[tag1id], d_nodes[tag2id], true)) {
return false;
@@ -566,15 +612,6 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
}
}
- // Notify of the constants merge
- if (isConstant && d_isConstant[class2Id]) {
- if (d_performNotify) {
- if (!d_notify.eqNotifyConstantTermMerge(d_nodes[class1Id], d_nodes[class2Id])) {
- return false;
- }
- }
- }
-
// Everything fine
return true;
}
@@ -680,12 +717,12 @@ void EqualityEngine::backtrack() {
Debug("equality") << "EqualityEngine::backtrack(): removing node " << d_nodes[i] << std::endl;
d_nodeIds.erase(d_nodes[i]);
- const FunctionApplication& app = d_applications[i].normalized;
+ const FunctionApplication& app = d_applications[i].original;
if (app.isApplication()) {
// Remove b from use-list
- getEqualityNode(app.b).removeTopFromUseList<USE_LIST_FUNCTIONS>(d_useListNodes);
+ getEqualityNode(app.b).removeTopFromUseList(d_useListNodes);
// Remove a from use-list
- getEqualityNode(app.a).removeTopFromUseList<USE_LIST_FUNCTIONS>(d_useListNodes);
+ getEqualityNode(app.a).removeTopFromUseList(d_useListNodes);
}
}
@@ -818,8 +855,10 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, st
// Go through the equality edges of this node
EqualityEdgeId currentEdge = d_equalityGraph[currentNode];
- Debug("equality") << "EqualityEngine::getExplanation(): edgesId = " << currentEdge << std::endl;
- Debug("equality") << "EqualityEngine::getExplanation(): edges = " << edgesToString(currentEdge) << std::endl;
+ if (Debug.isOn("equality")) {
+ Debug("equality") << "EqualityEngine::getExplanation(): edgesId = " << currentEdge << std::endl;
+ Debug("equality") << "EqualityEngine::getExplanation(): edges = " << edgesToString(currentEdge) << std::endl;
+ }
while (currentEdge != null_edge) {
// Get the edge
@@ -871,8 +910,10 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, st
Debug("equality") << push;
// Explain why a is a constant
+ Assert(isConstant(eq.a));
getExplanation(eq.a, getEqualityNode(eq.a).getFind(), equalities);
// Explain why b is a constant
+ Assert(isConstant(eq.b));
getExplanation(eq.b, getEqualityNode(eq.b).getFind(), equalities);
Debug("equality") << pop;
@@ -1062,7 +1103,7 @@ void EqualityEngine::propagate() {
// Depending on the merge preference (such as size, or being a constant), merge them
std::vector<TriggerId> triggers;
- if (node2.getSize() > node1.getSize() || d_isConstant[t2classId]) {
+ if ((node2.getSize() > node1.getSize() && !d_isConstant[t1classId]) || d_isConstant[t2classId]) {
Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t1Id]<< " into " << d_nodes[current.t2Id] << std::endl;
d_assertedEqualities.push_back(Equality(t2classId, t1classId));
if (!merge(node2, node1, triggers)) {
@@ -1222,7 +1263,7 @@ size_t EqualityEngine::getSize(TNode t)
void EqualityEngine::addTriggerTerm(TNode t, TheoryId tag)
{
- Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ", " << tag << ")" << std::endl;
+ Debug("equality::trigger") << "EqualityEngine::addTriggerTerm(" << t << ", " << tag << ")" << std::endl;
Assert(tag != THEORY_LAST);
if (d_done) {
@@ -1243,12 +1284,71 @@ void EqualityEngine::addTriggerTerm(TNode t, TheoryId tag)
// If the term already is in the equivalence class that a tagged representative, just notify
if (d_performNotify) {
EqualityNodeId triggerId = getTriggerTermSet(triggerSetRef).getTrigger(tag);
- if (!d_notify.eqNotifyTriggerTermEquality(tag, t, d_nodes[triggerId], true)) {
+ if (eqNodeId != triggerId && !d_notify.eqNotifyTriggerTermEquality(tag, t, d_nodes[triggerId], true)) {
d_done = true;
}
}
} else {
+ // Check for disequalities by going through the equivalence class looking for equalities in the
+ // uselists that have been asserted to false. All the representatives appearing on the other
+ // side of such disequalities, that have the tag on, are put in a set.
+ std::set<EqualityNodeId> disequalSet;
+ EqualityNodeId currentId = classId;
+ do {
+ // Current node
+ EqualityNode& currentNode = getEqualityNode(currentId);
+ // Go through the uselist and look for disequalities
+ UseListNodeId currentUseId = currentNode.getUseList();
+ while (!d_done && currentUseId != null_uselist_id) {
+ // Get the normalized equality
+ UseListNode& useNode = d_useListNodes[currentUseId];
+ EqualityNodeId funId = useNode.getApplicationId();
+ const FunctionApplication& fun = d_applications[useNode.getApplicationId()].original;
+ // Check for asserted disequalities
+ if (fun.isEquality && getEqualityNode(funId).getFind() == getEqualityNode(d_false).getFind()) {
+ // Get the other equality member
+ EqualityNodeId toCompare = fun.b;
+ if (toCompare == currentId) {
+ toCompare = fun.a;
+ }
+ // Representative of the other member
+ EqualityNodeId toCompareRep = getEqualityNode(toCompare).getFind();
+ Assert(toCompareRep != classId);
+ // Get the trigger set
+ TriggerTermSetRef toCompareTriggerSetRef = d_nodeIndividualTrigger[toCompareRep];
+ // Only notify if we're not both constants
+ if ((!d_isConstant[classId] || !d_isConstant[toCompareRep]) && toCompareTriggerSetRef != null_set_id) {
+ TriggerTermSet& toCompareTriggerSet = getTriggerTermSet(toCompareTriggerSetRef);
+ if (Theory::setContains(tag, toCompareTriggerSet.tags)) {
+ // Get the tag representative
+ EqualityNodeId tagRep = toCompareTriggerSet.getTrigger(tag);
+ // Propagate the disequality if not already done
+ if (!disequalSet.count(tagRep) && d_performNotify) {
+ // Mark as propagated
+ disequalSet.insert(tagRep);
+ // Store the propagation
+ d_deducedDisequalityReasons.push_back(EqualityPair(eqNodeId, currentId));
+ d_deducedDisequalityReasons.push_back(EqualityPair(toCompare, tagRep));
+ d_deducedDisequalityReasons.push_back(EqualityPair(funId, d_falseId));
+ storePropagatedDisequality(t, d_nodes[tagRep], 3);
+ // We don't check if it's been propagated already, as we need one per tag
+ if (d_performNotify) {
+ if (!d_notify.eqNotifyTriggerTermEquality(tag, t, d_nodes[tagRep], false)) {
+ d_done = true;
+ }
+ }
+ }
+ }
+ }
+ }
+ // Go to the next one in the use list
+ currentUseId = useNode.getNext();
+ }
+ // Next in equivalence class
+ currentId = currentNode.getNext();
+ } while (!d_done && currentId != classId);
+
// Setup the data for the new set
if (triggerSetRef != null_set_id) {
// Get the existing set
@@ -1322,7 +1422,7 @@ void EqualityEngine::getUseListTerms(TNode t, std::set<TNode>& output) {
// Get the current node
EqualityNode& currentNode = getEqualityNode(currentId);
// Go through the use-list
- UseListNodeId currentUseId = currentNode.getUseList<USE_LIST_FUNCTIONS>();
+ UseListNodeId currentUseId = currentNode.getUseList();
while (currentUseId != null_uselist_id) {
// Get the node of the use list
UseListNode& useNode = d_useListNodes[currentUseId];
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback