diff options
Diffstat (limited to 'src/theory/shared_terms_database.cpp')
-rw-r--r-- | src/theory/shared_terms_database.cpp | 64 |
1 files changed, 44 insertions, 20 deletions
diff --git a/src/theory/shared_terms_database.cpp b/src/theory/shared_terms_database.cpp index 2f9ad74e0..92c66e83b 100644 --- a/src/theory/shared_terms_database.cpp +++ b/src/theory/shared_terms_database.cpp @@ -32,10 +32,11 @@ SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine, d_alreadyNotifiedMap(context), d_registeredEqualities(context), d_EENotify(*this), - d_equalityEngine(d_EENotify, context, "SharedTermsDatabase", true), d_theoryEngine(theoryEngine), d_inConflict(context, false), - d_conflictPolarity() { + d_conflictPolarity(), + d_equalityEngine(nullptr) +{ smtStatisticsRegistry()->registerStat(&d_statSharedTerms); } @@ -46,7 +47,7 @@ SharedTermsDatabase::~SharedTermsDatabase() void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee) { - // TODO (project #39): dynamic allocation of equality engine here + d_equalityEngine = ee; } bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi) @@ -57,8 +58,9 @@ bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi) } void SharedTermsDatabase::addEqualityToPropagate(TNode equality) { + Assert(d_equalityEngine != nullptr); d_registeredEqualities.insert(equality); - d_equalityEngine.addTriggerPredicate(equality); + d_equalityEngine->addTriggerPredicate(equality); checkForConflict(); } @@ -183,12 +185,18 @@ void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories) d_alreadyNotifiedMap[term] = TheoryIdSetUtil::setUnion(newlyNotified, alreadyNotified); + if (d_equalityEngine == nullptr) + { + // if we are not assigned an equality engine, there is nothing to do + return; + } + // Mark the shared terms in the equality engine theory::TheoryId currentTheory; while ((currentTheory = TheoryIdSetUtil::setPop(newlyNotified)) != THEORY_LAST) { - d_equalityEngine.addTriggerTerm(term, currentTheory); + d_equalityEngine->addTriggerTerm(term, currentTheory); } // Check for any conflits @@ -196,32 +204,42 @@ void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories) } bool SharedTermsDatabase::areEqual(TNode a, TNode b) const { - if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) { - return d_equalityEngine.areEqual(a,b); + Assert(d_equalityEngine != nullptr); + if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)) + { + return d_equalityEngine->areEqual(a, b); } else { - Assert(d_equalityEngine.hasTerm(a) || a.isConst()); - Assert(d_equalityEngine.hasTerm(b) || b.isConst()); + Assert(d_equalityEngine->hasTerm(a) || a.isConst()); + Assert(d_equalityEngine->hasTerm(b) || b.isConst()); // since one (or both) of them is a constant, and the other is in the equality engine, they are not same return false; } } bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const { - if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) { - return d_equalityEngine.areDisequal(a,b,false); + Assert(d_equalityEngine != nullptr); + if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)) + { + return d_equalityEngine->areDisequal(a, b, false); } else { - Assert(d_equalityEngine.hasTerm(a) || a.isConst()); - Assert(d_equalityEngine.hasTerm(b) || b.isConst()); + Assert(d_equalityEngine->hasTerm(a) || a.isConst()); + Assert(d_equalityEngine->hasTerm(b) || b.isConst()); // one (or both) are in the equality engine return false; } } +theory::eq::EqualityEngine* SharedTermsDatabase::getEqualityEngine() +{ + return d_equalityEngine; +} + void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason) { + Assert(d_equalityEngine != nullptr); Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl; // Add it to the equality engine - d_equalityEngine.assertEquality(equality, polarity, reason); + d_equalityEngine->assertEquality(equality, polarity, reason); // Check for conflict checkForConflict(); } @@ -258,10 +276,12 @@ static Node mkAnd(const std::vector<TNode>& conjunctions) { } void SharedTermsDatabase::checkForConflict() { + Assert(d_equalityEngine != nullptr); if (d_inConflict) { d_inConflict = false; std::vector<TNode> assumptions; - d_equalityEngine.explainEquality(d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions); + d_equalityEngine->explainEquality( + d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions); Node conflict = mkAnd(assumptions); TrustNode tconf = TrustNode::mkTrustConflict(conflict); d_theoryEngine->conflict(tconf, THEORY_BUILTIN); @@ -270,22 +290,26 @@ void SharedTermsDatabase::checkForConflict() { } bool SharedTermsDatabase::isKnown(TNode literal) const { + Assert(d_equalityEngine != nullptr); bool polarity = literal.getKind() != kind::NOT; TNode equality = polarity ? literal : literal[0]; if (polarity) { - return d_equalityEngine.areEqual(equality[0], equality[1]); + return d_equalityEngine->areEqual(equality[0], equality[1]); } else { - return d_equalityEngine.areDisequal(equality[0], equality[1], false); + return d_equalityEngine->areDisequal(equality[0], equality[1], false); } } -Node SharedTermsDatabase::explain(TNode literal) const { +TrustNode SharedTermsDatabase::explain(TNode literal) const +{ + Assert(d_equalityEngine != nullptr); bool polarity = literal.getKind() != kind::NOT; TNode atom = polarity ? literal : literal[0]; Assert(atom.getKind() == kind::EQUAL); std::vector<TNode> assumptions; - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); - return mkAnd(assumptions); + d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions); + Node exp = mkAnd(assumptions); + return TrustNode::mkTrustPropExp(literal, exp, nullptr); } } /* namespace CVC4 */ |