diff options
author | Dejan Jovanović <dejan.jovanovic@gmail.com> | 2012-06-06 06:12:40 +0000 |
---|---|---|
committer | Dejan Jovanović <dejan.jovanovic@gmail.com> | 2012-06-06 06:12:40 +0000 |
commit | fd9e22c4a2e57c3dfeda4de3842a3fb3ca4776ba (patch) | |
tree | 047e4d27f725e9157ed5bef5357d0d72560218ae /src/theory/uf/equality_engine.cpp | |
parent | 2799ae1cf57ed2b98387a1de1325bccd89bd2a30 (diff) |
Changes to the combination mechanism, lots of details. Not done yet, there are still the AUFBV wrong results, but it seems better.
http://church.cims.nyu.edu/regress-results/compare_jobs.php?job_id=4382&reference_id=4359&p=5
Diffstat (limited to 'src/theory/uf/equality_engine.cpp')
-rw-r--r-- | src/theory/uf/equality_engine.cpp | 170 |
1 files changed, 135 insertions, 35 deletions
diff --git a/src/theory/uf/equality_engine.cpp b/src/theory/uf/equality_engine.cpp index 5093e5a7b..4cd54a2bf 100644 --- a/src/theory/uf/equality_engine.cpp +++ b/src/theory/uf/equality_engine.cpp @@ -62,14 +62,14 @@ void EqualityEngine::init() { d_true = NodeManager::currentNM()->mkConst<bool>(true); d_false = NodeManager::currentNM()->mkConst<bool>(false); + d_triggerDatabaseAllocatedSize = 100000; + d_triggerDatabase = (char*) malloc(d_triggerDatabaseAllocatedSize); + addTerm(d_true); addTerm(d_false); d_trueId = getNodeId(d_true); d_falseId = getNodeId(d_false); - - d_triggerDatabaseAllocatedSize = 100000; - d_triggerDatabase = (char*) malloc(d_triggerDatabaseAllocatedSize); } EqualityEngine::~EqualityEngine() throw(AssertionException) { @@ -114,7 +114,8 @@ EqualityEngine::EqualityEngine(EqualityEngineNotify& notify, context::Context* c } void EqualityEngine::enqueue(const MergeCandidate& candidate) { - d_propagationQueue.push(candidate); + Debug("equality") << "EqualityEngine::enqueue(" << d_nodes[candidate.t1Id] << ", " << d_nodes[candidate.t2Id] << ", " << candidate.type << ")" << std::endl; + d_propagationQueue.push(candidate); } EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2, bool isEquality) { @@ -144,7 +145,15 @@ EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId if (isEquality && d_isConstant[t1ClassId] && d_isConstant[t2ClassId]) { if (t1ClassId != t2ClassId) { Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): got constants" << std::endl; + Assert(d_nodes[funId].getKind() == kind::EQUAL); enqueue(MergeCandidate(funId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null())); + // Also enqueue the symmetric one + TNode eq = d_nodes[funId]; + Node symmetricEq = eq[1].eqNode(eq[0]); + if (hasTerm(symmetricEq)) { + EqualityNodeId symmFunId = getNodeId(symmetricEq); + enqueue(MergeCandidate(symmFunId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null())); + } } } } else { @@ -154,8 +163,8 @@ EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId } // Add to the use lists - d_equalityNodes[t1ClassId].usedIn<USE_LIST_FUNCTIONS>(funId, d_useListNodes); - d_equalityNodes[t2ClassId].usedIn<USE_LIST_FUNCTIONS>(funId, d_useListNodes); + d_equalityNodes[t1].usedIn(funId, d_useListNodes); + d_equalityNodes[t2].usedIn(funId, d_useListNodes); // Return the new id Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ") => " << funId << std::endl; @@ -238,6 +247,20 @@ void EqualityEngine::addTerm(TNode t) { // We set this here as this only applies to actual terms, not the // intermediate application terms d_isBoolean[result] = true; + } else if (t.isConst()) { + // Non-Boolean constants are trigger terms for all tags + EqualityNodeId tId = getNodeId(t); + d_newSetTags = 0; + d_newSetTriggersSize = THEORY_LAST; + for (TheoryId currentTheory = THEORY_FIRST; currentTheory != THEORY_LAST; ++ currentTheory) { + d_newSetTags = Theory::setInsert(currentTheory, d_newSetTags); + d_newSetTriggers[currentTheory] = tId; + } + // Add it to the list for backtracking + d_triggerTermSetUpdates.push_back(TriggerSetUpdate(tId, null_set_id)); + d_triggerTermSetUpdatesSize = d_triggerTermSetUpdatesSize + 1; + // Mark the the new set as a trigger + d_nodeIndividualTrigger[tId] = newTriggerTermSet(); } propagate(); @@ -319,15 +342,20 @@ void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason) { assertEqualityInternal(eq, d_false, reason); propagate(); assertEqualityInternal(eq[1].eqNode(eq[0]), d_false, reason); - propagate(); + propagate(); if (d_done) { return; } - // If we are adding a disequality, notify of the shared term representatives + // If both have constant representatives, we don't notify anyone EqualityNodeId a = getNodeId(eq[0]); EqualityNodeId b = getNodeId(eq[1]); + if (isConstant(a) && isConstant(b)) { + return; + } + + // If we are adding a disequality, notify of the shared term representatives EqualityNodeId eqId = getNodeId(eq); TriggerTermSetRef aTriggerRef = d_nodeIndividualTrigger[a]; TriggerTermSetRef bTriggerRef = d_nodeIndividualTrigger[b]; @@ -356,6 +384,7 @@ void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason) { d_deducedDisequalityReasons.push_back(EqualityPair(bSharedId, b)); d_deducedDisequalityReasons.push_back(EqualityPair(eqId, d_falseId)); storePropagatedDisequality(d_nodes[aSharedId], d_nodes[bSharedId], 3); + // We notify even if the it's already been sent (they are not // disequal at assertion, and we need to notify for each tag) if (!d_notify.eqNotifyTriggerTermEquality(aTag, d_nodes[aSharedId], d_nodes[bSharedId], false)) { @@ -383,7 +412,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl; Assert(triggersFired.empty()); - + ++ d_stats.mergesCount; EqualityNodeId class1Id = class1.getFind(); @@ -391,6 +420,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect // Check for constant merges bool isConstant = d_isConstant[class1Id]; + Assert(isConstant || !d_isConstant[class2Id]); // Update class2 representative information Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating class " << class2Id << std::endl; @@ -438,13 +468,13 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of node " << currentId << std::endl; // Go through the uselist and check for congruences - UseListNodeId currentUseId = currentNode.getUseList<USE_LIST_FUNCTIONS>(); + UseListNodeId currentUseId = currentNode.getUseList(); while (currentUseId != null_uselist_id) { // Get the node of the use list UseListNode& useNode = d_useListNodes[currentUseId]; // Get the function application EqualityNodeId funId = useNode.getApplicationId(); - Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << currentId << " in " << d_nodes[funId] << std::endl; + Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << d_nodes[currentId] << " in " << d_nodes[funId] << std::endl; const FunctionApplication& fun = d_applications[useNode.getApplicationId()].normalized; // Check if there is an application with find arguments EqualityNodeId aNormalized = getEqualityNode(fun.a).getFind(); @@ -460,11 +490,16 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect // There is no representative, so we can add one, we remove this when backtracking storeApplicationLookup(funNormalized, funId); // Now, if we're constant and it's an equality, check if the other guy is also a constant - if (isConstant && funNormalized.isEquality) { - if (d_isConstant[funNormalized.a] && d_isConstant[funNormalized.b]) { - // both constants - if (funNormalized.a != funNormalized.b) { - enqueue(MergeCandidate(funId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null())); + if (funNormalized.isEquality) { + if (d_isConstant[aNormalized] && d_isConstant[bNormalized] && aNormalized != bNormalized) { + Assert(d_nodes[funId].getKind() == kind::EQUAL); + enqueue(MergeCandidate(funId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null())); + // Also enqueue the symmetric one + TNode eq = d_nodes[funId]; + Node symmetricEq = eq[1].eqNode(eq[0]); + if (hasTerm(symmetricEq)) { + EqualityNodeId symmFunId = getNodeId(symmetricEq); + enqueue(MergeCandidate(symmFunId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null())); } } } @@ -482,6 +517,17 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect // Now merge the lists class1.merge<true>(class2); + // Notify of the constants merge + bool constantMerge = false; + if (isConstant && d_isConstant[class2Id]) { + constantMerge = true; + if (d_performNotify) { + if (!d_notify.eqNotifyConstantTermMerge(d_nodes[class1Id], d_nodes[class2Id])) { + return false; + } + } + } + // Go through the triggers and store the dis-equalities unsigned i = 0, j = 0; for (; i < triggersFired.size();) { @@ -543,7 +589,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect // copy tag1 EqualityNodeId tag1id = d_newSetTriggers[d_newSetTriggersSize++] = class1triggers.triggers[i1++]; // since they are both tagged notify of merge - if (d_performNotify) { + if (d_performNotify && !constantMerge) { EqualityNodeId tag2id = class2triggers.triggers[i2++]; if (!d_notify.eqNotifyTriggerTermEquality(tag1, d_nodes[tag1id], d_nodes[tag2id], true)) { return false; @@ -566,15 +612,6 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect } } - // Notify of the constants merge - if (isConstant && d_isConstant[class2Id]) { - if (d_performNotify) { - if (!d_notify.eqNotifyConstantTermMerge(d_nodes[class1Id], d_nodes[class2Id])) { - return false; - } - } - } - // Everything fine return true; } @@ -680,12 +717,12 @@ void EqualityEngine::backtrack() { Debug("equality") << "EqualityEngine::backtrack(): removing node " << d_nodes[i] << std::endl; d_nodeIds.erase(d_nodes[i]); - const FunctionApplication& app = d_applications[i].normalized; + const FunctionApplication& app = d_applications[i].original; if (app.isApplication()) { // Remove b from use-list - getEqualityNode(app.b).removeTopFromUseList<USE_LIST_FUNCTIONS>(d_useListNodes); + getEqualityNode(app.b).removeTopFromUseList(d_useListNodes); // Remove a from use-list - getEqualityNode(app.a).removeTopFromUseList<USE_LIST_FUNCTIONS>(d_useListNodes); + getEqualityNode(app.a).removeTopFromUseList(d_useListNodes); } } @@ -818,8 +855,10 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, st // Go through the equality edges of this node EqualityEdgeId currentEdge = d_equalityGraph[currentNode]; - Debug("equality") << "EqualityEngine::getExplanation(): edgesId = " << currentEdge << std::endl; - Debug("equality") << "EqualityEngine::getExplanation(): edges = " << edgesToString(currentEdge) << std::endl; + if (Debug.isOn("equality")) { + Debug("equality") << "EqualityEngine::getExplanation(): edgesId = " << currentEdge << std::endl; + Debug("equality") << "EqualityEngine::getExplanation(): edges = " << edgesToString(currentEdge) << std::endl; + } while (currentEdge != null_edge) { // Get the edge @@ -871,8 +910,10 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, st Debug("equality") << push; // Explain why a is a constant + Assert(isConstant(eq.a)); getExplanation(eq.a, getEqualityNode(eq.a).getFind(), equalities); // Explain why b is a constant + Assert(isConstant(eq.b)); getExplanation(eq.b, getEqualityNode(eq.b).getFind(), equalities); Debug("equality") << pop; @@ -1062,7 +1103,7 @@ void EqualityEngine::propagate() { // Depending on the merge preference (such as size, or being a constant), merge them std::vector<TriggerId> triggers; - if (node2.getSize() > node1.getSize() || d_isConstant[t2classId]) { + if ((node2.getSize() > node1.getSize() && !d_isConstant[t1classId]) || d_isConstant[t2classId]) { Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t1Id]<< " into " << d_nodes[current.t2Id] << std::endl; d_assertedEqualities.push_back(Equality(t2classId, t1classId)); if (!merge(node2, node1, triggers)) { @@ -1222,7 +1263,7 @@ size_t EqualityEngine::getSize(TNode t) void EqualityEngine::addTriggerTerm(TNode t, TheoryId tag) { - Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ", " << tag << ")" << std::endl; + Debug("equality::trigger") << "EqualityEngine::addTriggerTerm(" << t << ", " << tag << ")" << std::endl; Assert(tag != THEORY_LAST); if (d_done) { @@ -1243,12 +1284,71 @@ void EqualityEngine::addTriggerTerm(TNode t, TheoryId tag) // If the term already is in the equivalence class that a tagged representative, just notify if (d_performNotify) { EqualityNodeId triggerId = getTriggerTermSet(triggerSetRef).getTrigger(tag); - if (!d_notify.eqNotifyTriggerTermEquality(tag, t, d_nodes[triggerId], true)) { + if (eqNodeId != triggerId && !d_notify.eqNotifyTriggerTermEquality(tag, t, d_nodes[triggerId], true)) { d_done = true; } } } else { + // Check for disequalities by going through the equivalence class looking for equalities in the + // uselists that have been asserted to false. All the representatives appearing on the other + // side of such disequalities, that have the tag on, are put in a set. + std::set<EqualityNodeId> disequalSet; + EqualityNodeId currentId = classId; + do { + // Current node + EqualityNode& currentNode = getEqualityNode(currentId); + // Go through the uselist and look for disequalities + UseListNodeId currentUseId = currentNode.getUseList(); + while (!d_done && currentUseId != null_uselist_id) { + // Get the normalized equality + UseListNode& useNode = d_useListNodes[currentUseId]; + EqualityNodeId funId = useNode.getApplicationId(); + const FunctionApplication& fun = d_applications[useNode.getApplicationId()].original; + // Check for asserted disequalities + if (fun.isEquality && getEqualityNode(funId).getFind() == getEqualityNode(d_false).getFind()) { + // Get the other equality member + EqualityNodeId toCompare = fun.b; + if (toCompare == currentId) { + toCompare = fun.a; + } + // Representative of the other member + EqualityNodeId toCompareRep = getEqualityNode(toCompare).getFind(); + Assert(toCompareRep != classId); + // Get the trigger set + TriggerTermSetRef toCompareTriggerSetRef = d_nodeIndividualTrigger[toCompareRep]; + // Only notify if we're not both constants + if ((!d_isConstant[classId] || !d_isConstant[toCompareRep]) && toCompareTriggerSetRef != null_set_id) { + TriggerTermSet& toCompareTriggerSet = getTriggerTermSet(toCompareTriggerSetRef); + if (Theory::setContains(tag, toCompareTriggerSet.tags)) { + // Get the tag representative + EqualityNodeId tagRep = toCompareTriggerSet.getTrigger(tag); + // Propagate the disequality if not already done + if (!disequalSet.count(tagRep) && d_performNotify) { + // Mark as propagated + disequalSet.insert(tagRep); + // Store the propagation + d_deducedDisequalityReasons.push_back(EqualityPair(eqNodeId, currentId)); + d_deducedDisequalityReasons.push_back(EqualityPair(toCompare, tagRep)); + d_deducedDisequalityReasons.push_back(EqualityPair(funId, d_falseId)); + storePropagatedDisequality(t, d_nodes[tagRep], 3); + // We don't check if it's been propagated already, as we need one per tag + if (d_performNotify) { + if (!d_notify.eqNotifyTriggerTermEquality(tag, t, d_nodes[tagRep], false)) { + d_done = true; + } + } + } + } + } + } + // Go to the next one in the use list + currentUseId = useNode.getNext(); + } + // Next in equivalence class + currentId = currentNode.getNext(); + } while (!d_done && currentId != classId); + // Setup the data for the new set if (triggerSetRef != null_set_id) { // Get the existing set @@ -1322,7 +1422,7 @@ void EqualityEngine::getUseListTerms(TNode t, std::set<TNode>& output) { // Get the current node EqualityNode& currentNode = getEqualityNode(currentId); // Go through the use-list - UseListNodeId currentUseId = currentNode.getUseList<USE_LIST_FUNCTIONS>(); + UseListNodeId currentUseId = currentNode.getUseList(); while (currentUseId != null_uselist_id) { // Get the node of the use list UseListNode& useNode = d_useListNodes[currentUseId]; |