diff options
37 files changed, 1134 insertions, 807 deletions
diff --git a/src/proof/theory_proof.cpp b/src/proof/theory_proof.cpp index 3103557c8..b47fd6a1e 100644 --- a/src/proof/theory_proof.cpp +++ b/src/proof/theory_proof.cpp @@ -1093,6 +1093,12 @@ void TheoryProof::printTheoryLemmaProof(std::vector<Expr>& lemma, InternalError() << "can't generate theory-proof for " << ProofManager::currentPM()->getLogic(); } + // must perform initialization on the theory + if (th != nullptr) + { + // finish init, standalone version + th->finishInitStandalone(); + } Debug("pf::tp") << "TheoryProof::printTheoryLemmaProof - calling th->ProduceProofs()" << std::endl; th->produceProofs(); diff --git a/src/theory/arith/congruence_manager.cpp b/src/theory/arith/congruence_manager.cpp index ab3b485a8..a70339c01 100644 --- a/src/theory/arith/congruence_manager.cpp +++ b/src/theory/arith/congruence_manager.cpp @@ -42,16 +42,29 @@ ArithCongruenceManager::ArithCongruenceManager( d_constraintDatabase(cd), d_setupLiteral(setup), d_avariables(avars), - d_ee(d_notify, c, "theory::arith::ArithCongruenceManager", true) + d_ee(nullptr) { - d_ee.addFunctionKind(kind::NONLINEAR_MULT); - d_ee.addFunctionKind(kind::EXPONENTIAL); - d_ee.addFunctionKind(kind::SINE); - d_ee.addFunctionKind(kind::IAND); } ArithCongruenceManager::~ArithCongruenceManager() {} +bool ArithCongruenceManager::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::arith::ArithCongruenceManager"; + return true; +} + +void ArithCongruenceManager::finishInit(eq::EqualityEngine* ee) +{ + Assert(ee != nullptr); + d_ee = ee; + d_ee->addFunctionKind(kind::NONLINEAR_MULT); + d_ee->addFunctionKind(kind::EXPONENTIAL); + d_ee->addFunctionKind(kind::SINE); + d_ee->addFunctionKind(kind::IAND); +} + ArithCongruenceManager::Statistics::Statistics(): d_watchedVariables("theory::arith::congruence::watchedVariables", 0), d_watchedVariableIsZero("theory::arith::congruence::watchedVariableIsZero", 0), @@ -141,10 +154,6 @@ bool ArithCongruenceManager::canExplain(TNode n) const { return d_explanationMap.find(n) != d_explanationMap.end(); } -void ArithCongruenceManager::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_ee.setMasterEqualityEngine(eq); -} - Node ArithCongruenceManager::externalToInternal(TNode n) const{ Assert(canExplain(n)); ExplainMap::const_iterator iter = d_explanationMap.find(n); @@ -320,9 +329,9 @@ bool ArithCongruenceManager::propagate(TNode x){ void ArithCongruenceManager::explain(TNode literal, std::vector<TNode>& assumptions) { if (literal.getKind() != kind::NOT) { - d_ee.explainEquality(literal[0], literal[1], true, assumptions); + d_ee->explainEquality(literal[0], literal[1], true, assumptions); } else { - d_ee.explainEquality(literal[0][0], literal[0][1], false, assumptions); + d_ee->explainEquality(literal[0][0], literal[0][1], false, assumptions); } } @@ -392,9 +401,9 @@ void ArithCongruenceManager::assertionToEqualityEngine(bool isEquality, ArithVar Trace("arith-ee") << "Assert " << eq << ", pol " << isEquality << ", reason " << reason << std::endl; if(isEquality){ - d_ee.assertEquality(eq, true, reason); + d_ee->assertEquality(eq, true, reason); }else{ - d_ee.assertEquality(eq, false, reason); + d_ee->assertEquality(eq, false, reason); } } @@ -417,7 +426,7 @@ void ArithCongruenceManager::equalsConstant(ConstraintCP c){ d_keepAlive.push_back(reason); Trace("arith-ee") << "Assert equalsConstant " << eq << ", reason " << reason << std::endl; - d_ee.assertEquality(eq, true, reason); + d_ee->assertEquality(eq, true, reason); } void ArithCongruenceManager::equalsConstant(ConstraintCP lb, ConstraintCP ub){ @@ -441,11 +450,11 @@ void ArithCongruenceManager::equalsConstant(ConstraintCP lb, ConstraintCP ub){ d_keepAlive.push_back(reason); Trace("arith-ee") << "Assert equalsConstant2 " << eq << ", reason " << reason << std::endl; - d_ee.assertEquality(eq, true, reason); + d_ee->assertEquality(eq, true, reason); } void ArithCongruenceManager::addSharedTerm(Node x){ - d_ee.addTriggerTerm(x, THEORY_ARITH); + d_ee->addTriggerTerm(x, THEORY_ARITH); } }/* CVC4::theory::arith namespace */ diff --git a/src/theory/arith/congruence_manager.h b/src/theory/arith/congruence_manager.h index aeb72ec94..f3b5641b4 100644 --- a/src/theory/arith/congruence_manager.h +++ b/src/theory/arith/congruence_manager.h @@ -95,7 +95,8 @@ private: const ArithVariables& d_avariables; - eq::EqualityEngine d_ee; + /** The equality engine being used by this class */ + eq::EqualityEngine* d_ee; void raiseConflict(Node conflict); public: @@ -108,8 +109,6 @@ public: bool canExplain(TNode n) const; - void setMasterEqualityEngine(eq::EqualityEngine* eq); - private: Node externalToInternal(TNode n) const; @@ -138,6 +137,19 @@ public: ArithCongruenceManager(context::Context* satContext, ConstraintDatabase&, SetupLiteralCallBack, const ArithVariables&, RaiseEqualityEngineConflict raiseConflict); ~ArithCongruenceManager(); + //--------------------------------- initialization + /** + * Returns true if we need an equality engine, see + * Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi); + /** + * Finish initialize. This class is instructed by TheoryArithPrivate to use + * the equality engine ee. + */ + void finishInit(eq::EqualityEngine* ee); + //--------------------------------- end initialization + Node explain(TNode literal); void explain(TNode lit, NodeBuilder<>& out); @@ -166,10 +178,8 @@ public: void addSharedTerm(Node x); - - eq::EqualityEngine * getEqualityEngine() { return &d_ee; } -private: + private: class Statistics { public: IntStat d_watchedVariables; diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index bc6e18a83..b95b5e243 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -56,10 +56,10 @@ TheoryRewriter* TheoryArith::getTheoryRewriter() return d_internal->getTheoryRewriter(); } -void TheoryArith::preRegisterTerm(TNode n){ - d_internal->preRegisterTerm(n); +bool TheoryArith::needsEqualityEngine(EeSetupInfo& esi) +{ + return d_internal->needsEqualityEngine(esi); } - void TheoryArith::finishInit() { if (getLogicInfo().isTheoryEnabled(THEORY_ARITH) @@ -72,17 +72,17 @@ void TheoryArith::finishInit() d_valuation.setUnevaluatedKind(kind::SINE); d_valuation.setUnevaluatedKind(kind::PI); } + // finish initialize internally + d_internal->finishInit(); } +void TheoryArith::preRegisterTerm(TNode n) { d_internal->preRegisterTerm(n); } + TrustNode TheoryArith::expandDefinition(Node node) { return d_internal->expandDefinition(node); } -void TheoryArith::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_internal->setMasterEqualityEngine(eq); -} - void TheoryArith::addSharedTerm(TNode n){ d_internal->addSharedTerm(n); } diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index 30de7bbad..ad3b91b07 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -55,19 +55,28 @@ class TheoryArith : public Theory { ProofNodeManager* pnm = nullptr); virtual ~TheoryArith(); + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if this theory needs an equality engine, which is assigned + * to it (d_equalityEngine) by the equality engine manager during + * TheoryEngine::finishInit, prior to calling finishInit for this theory. + * If this method returns true, it stores instructions for the notifications + * this Theory wishes to receive from its equality engine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ + void finishInit() override; + //--------------------------------- end initialization /** * Does non-context dependent setup for a node connected to a theory. */ void preRegisterTerm(TNode n) override; - void finishInit() override; - TrustNode expandDefinition(Node node) override; - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - void check(Effort e) override; bool needsCheckLastEffort() override; void propagate(Effort e) override; diff --git a/src/theory/arith/theory_arith_private.cpp b/src/theory/arith/theory_arith_private.cpp index 6f47ffb0e..8ca99d369 100644 --- a/src/theory/arith/theory_arith_private.cpp +++ b/src/theory/arith/theory_arith_private.cpp @@ -134,7 +134,7 @@ TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing, d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), d_attemptSolSimplex( d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), - d_nonlinearExtension(NULL), + d_nonlinearExtension(nullptr), d_pass1SDP(NULL), d_otherSDP(NULL), d_lastContextIntegerAttempted(c, -1), @@ -159,12 +159,6 @@ TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing, d_statistics(), d_opElim(pnm, logicInfo) { - // only need to create if non-linear logic - if (logicInfo.isTheoryEnabled(THEORY_ARITH) && !logicInfo.isLinear()) - { - d_nonlinearExtension = new nl::NonlinearExtension( - containing, d_congruenceManager.getEqualityEngine()); - } } TheoryArithPrivate::~TheoryArithPrivate(){ @@ -173,6 +167,24 @@ TheoryArithPrivate::~TheoryArithPrivate(){ if(d_nonlinearExtension != NULL) { delete d_nonlinearExtension; } } +TheoryRewriter* TheoryArithPrivate::getTheoryRewriter() { return &d_rewriter; } +bool TheoryArithPrivate::needsEqualityEngine(EeSetupInfo& esi) +{ + return d_congruenceManager.needsEqualityEngine(esi); +} +void TheoryArithPrivate::finishInit() +{ + eq::EqualityEngine* ee = d_containing.getEqualityEngine(); + Assert(ee != nullptr); + d_congruenceManager.finishInit(ee); + const LogicInfo& logicInfo = getLogicInfo(); + // only need to create nonlinear extension if non-linear logic + if (logicInfo.isTheoryEnabled(THEORY_ARITH) && !logicInfo.isLinear()) + { + d_nonlinearExtension = new nl::NonlinearExtension(d_containing, ee); + } +} + static bool contains(const ConstraintCPVec& v, ConstraintP con){ for(unsigned i = 0, N = v.size(); i < N; ++i){ if(v[i] == con){ @@ -227,10 +239,6 @@ static void resolve(ConstraintCPVec& buf, ConstraintP c, const ConstraintCPVec& // return safeConstructNary(nb); } -void TheoryArithPrivate::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_congruenceManager.setMasterEqualityEngine(eq); -} - TheoryArithPrivate::ModelException::ModelException(TNode n, const char* msg) { stringstream ss; diff --git a/src/theory/arith/theory_arith_private.h b/src/theory/arith/theory_arith_private.h index 42ec7f47b..4c4aedf00 100644 --- a/src/theory/arith/theory_arith_private.h +++ b/src/theory/arith/theory_arith_private.h @@ -427,7 +427,17 @@ private: ProofNodeManager* pnm); ~TheoryArithPrivate(); - TheoryRewriter* getTheoryRewriter() { return &d_rewriter; } + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter(); + /** + * Returns true if we need an equality engine, see + * Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi); + /** finish initialize */ + void finishInit(); + //--------------------------------- end initialization /** * Does non-context dependent setup for a node connected to a theory. @@ -435,8 +445,6 @@ private: void preRegisterTerm(TNode n); TrustNode expandDefinition(Node node); - void setMasterEqualityEngine(eq::EqualityEngine* eq); - void check(Theory::Effort e); bool needsCheckLastEffort(); void propagate(Theory::Effort e); diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 245da617b..85759b75f 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -88,7 +88,6 @@ TheoryArrays::TheoryArrays(context::Context* c, d_isPreRegistered(c), d_mayEqualEqualityEngine(c, name + "theory::arrays::mayEqual", true), d_notify(*this), - d_equalityEngine(d_notify, c, name + "theory::arrays", true), d_conflict(c, false), d_backtracker(c), d_infoMap(c, &d_backtracker, name), @@ -112,7 +111,7 @@ TheoryArrays::TheoryArrays(context::Context* c, d_readTableContext(new context::Context()), d_arrayMerges(c), d_inCheckModel(false), - d_proofReconstruction(&d_equalityEngine), + d_proofReconstruction(nullptr), d_dstrat(new TheoryArraysDecisionStrategy(this)), d_dstratInit(false) { @@ -133,27 +132,6 @@ TheoryArrays::TheoryArrays(context::Context* c, // The preprocessing congruence kinds d_ppEqualityEngine.addFunctionKind(kind::SELECT); d_ppEqualityEngine.addFunctionKind(kind::STORE); - - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::SELECT); - if (d_ccStore) { - d_equalityEngine.addFunctionKind(kind::STORE); - } - if (d_useArrTable) { - d_equalityEngine.addFunctionKind(kind::ARR_TABLE_FUN); - } - - d_reasonRow = d_equalityEngine.getFreshMergeReasonType(); - d_reasonRow1 = d_equalityEngine.getFreshMergeReasonType(); - d_reasonExt = d_equalityEngine.getFreshMergeReasonType(); - - d_proofReconstruction.setRowMergeTag(d_reasonRow); - d_proofReconstruction.setRow1MergeTag(d_reasonRow1); - d_proofReconstruction.setExtMergeTag(d_reasonExt); - - d_equalityEngine.addPathReconstructionTrigger(d_reasonRow, &d_proofReconstruction); - d_equalityEngine.addPathReconstructionTrigger(d_reasonRow1, &d_proofReconstruction); - d_equalityEngine.addPathReconstructionTrigger(d_reasonExt, &d_proofReconstruction); } TheoryArrays::~TheoryArrays() { @@ -179,8 +157,45 @@ TheoryArrays::~TheoryArrays() { smtStatisticsRegistry()->unregisterStat(&d_numSetModelValConflicts); } -void TheoryArrays::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); +TheoryRewriter* TheoryArrays::getTheoryRewriter() { return &d_rewriter; } + +bool TheoryArrays::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = d_instanceName + "theory::arrays::ee"; + return true; +} + +void TheoryArrays::finishInit() +{ + Assert(d_equalityEngine != nullptr); + + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::SELECT); + if (d_ccStore) + { + d_equalityEngine->addFunctionKind(kind::STORE); + } + if (d_useArrTable) + { + d_equalityEngine->addFunctionKind(kind::ARR_TABLE_FUN); + } + + d_proofReconstruction.reset(new ArrayProofReconstruction(d_equalityEngine)); + d_reasonRow = d_equalityEngine->getFreshMergeReasonType(); + d_reasonRow1 = d_equalityEngine->getFreshMergeReasonType(); + d_reasonExt = d_equalityEngine->getFreshMergeReasonType(); + + d_proofReconstruction->setRowMergeTag(d_reasonRow); + d_proofReconstruction->setRow1MergeTag(d_reasonRow1); + d_proofReconstruction->setExtMergeTag(d_reasonExt); + + d_equalityEngine->addPathReconstructionTrigger(d_reasonRow, + d_proofReconstruction.get()); + d_equalityEngine->addPathReconstructionTrigger(d_reasonRow1, + d_proofReconstruction.get()); + d_equalityEngine->addPathReconstructionTrigger(d_reasonExt, + d_proofReconstruction.get()); } ///////////////////////////////////////////////////////////////////////////// @@ -427,9 +442,10 @@ void TheoryArrays::explain(TNode literal, std::vector<TNode>& assumptions, //eq::EqProof * eqp = new eq::EqProof; // eq::EqProof * eqp = NULL; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions, proof); + d_equalityEngine->explainEquality( + atom[0], atom[1], polarity, assumptions, proof); } else { - d_equalityEngine.explainPredicate(atom, polarity, assumptions, proof); + d_equalityEngine->explainPredicate(atom, polarity, assumptions, proof); } if (Debug.isOn("pf::array")) { @@ -469,7 +485,8 @@ TNode TheoryArrays::weakEquivGetRepIndex(TNode node, TNode index) { return node; } index2 = d_infoMap.getWeakEquivIndex(node); - if (index2.isNull() || !d_equalityEngine.areEqual(index, index2)) { + if (index2.isNull() || !d_equalityEngine->areEqual(index, index2)) + { node = pointer; } else { @@ -493,7 +510,8 @@ void TheoryArrays::visitAllLeaves(TNode reason, vector<TNode>& conjunctions) { conjunctions.push_back(reason); break; case kind::EQUAL: - d_equalityEngine.explainEquality(reason[0], reason[1], true, conjunctions); + d_equalityEngine->explainEquality( + reason[0], reason[1], true, conjunctions); break; default: Unreachable(); @@ -511,10 +529,11 @@ void TheoryArrays::weakEquivBuildCond(TNode node, TNode index, vector<TNode>& co index2 = d_infoMap.getWeakEquivIndex(node); if (index2.isNull()) { // Null index means these two nodes became equal: explain the equality. - d_equalityEngine.explainEquality(node, pointer, true, conjunctions); + d_equalityEngine->explainEquality(node, pointer, true, conjunctions); node = pointer; } - else if (!d_equalityEngine.areEqual(index, index2)) { + else if (!d_equalityEngine->areEqual(index, index2)) + { // If indices are not equal in current context, need to add that to the lemma. Node reason = index.eqNode(index2).notNode(); d_permRef.push_back(reason); @@ -556,7 +575,8 @@ void TheoryArrays::weakEquivMakeRepIndex(TNode node) { TNode index2 = d_infoMap.getWeakEquivIndex(secondary); Node reason; TNode next; - while (index2.isNull() || !d_equalityEngine.areEqual(index, index2)) { + while (index2.isNull() || !d_equalityEngine->areEqual(index, index2)) + { next = d_infoMap.getWeakEquivPointer(secondary); d_infoMap.setWeakEquivSecondary(node, next); reason = d_infoMap.getWeakEquivSecondaryReason(node); @@ -590,13 +610,13 @@ void TheoryArrays::weakEquivAddSecondary(TNode index, TNode arrayFrom, TNode arr TNode pointer, indexRep; if (!index.isNull()) { index_trail.push_back(index); - marked.insert(d_equalityEngine.getRepresentative(index)); + marked.insert(d_equalityEngine->getRepresentative(index)); } while (arrayFrom != arrayTo) { index = d_infoMap.getWeakEquivIndex(arrayFrom); pointer = d_infoMap.getWeakEquivPointer(arrayFrom); if (!index.isNull()) { - indexRep = d_equalityEngine.getRepresentative(index); + indexRep = d_equalityEngine->getRepresentative(index); if (marked.find(indexRep) == marked.end() && weakEquivGetRepIndex(arrayFrom, index) != arrayTo) { weakEquivMakeRepIndex(arrayFrom); d_infoMap.setWeakEquivSecondary(arrayFrom, arrayTo); @@ -639,7 +659,7 @@ void TheoryArrays::checkWeakEquiv(bool arraysMerged) { || !secondary.isNull()); if (!pointer.isNull()) { if (index.isNull()) { - Assert(d_equalityEngine.areEqual(n, pointer)); + Assert(d_equalityEngine->areEqual(n, pointer)); } else { Assert( @@ -677,16 +697,17 @@ void TheoryArrays::preRegisterTermInternal(TNode node) case kind::EQUAL: // Add the trigger for equality // NOTE: note that if the equality is true or false already, it might not be added - d_equalityEngine.addTriggerEquality(node); + d_equalityEngine->addTriggerEquality(node); break; case kind::SELECT: { // Invariant: array terms should be preregistered before being added to the equality engine - if (d_equalityEngine.hasTerm(node)) { + if (d_equalityEngine->hasTerm(node)) + { Assert(d_isPreRegistered.find(node) != d_isPreRegistered.end()); return; } // Reads - TNode store = d_equalityEngine.getRepresentative(node[0]); + TNode store = d_equalityEngine->getRepresentative(node[0]); // The may equal needs the store d_mayEqualEqualityEngine.addTerm(store); @@ -694,15 +715,15 @@ void TheoryArrays::preRegisterTermInternal(TNode node) if (node.getType().isArray()) { d_mayEqualEqualityEngine.addTerm(node); - d_equalityEngine.addTriggerTerm(node, THEORY_ARRAYS); + d_equalityEngine->addTriggerTerm(node, THEORY_ARRAYS); } else { - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); } Assert((d_isPreRegistered.insert(node), true)); - Assert(d_equalityEngine.getRepresentative(store) == store); + Assert(d_equalityEngine->getRepresentative(store) == store); d_infoMap.addIndex(store, node[1]); // Synchronize d_constReadsContext with SAT context @@ -712,7 +733,7 @@ void TheoryArrays::preRegisterTermInternal(TNode node) } // Record read in sharing data structure - TNode index = d_equalityEngine.getRepresentative(node[1]); + TNode index = d_equalityEngine->getRepresentative(node[1]); if (!options::arraysWeakEquivalence() && index.isConst()) { CTNodeList* temp; CNodeNListMap::iterator it = d_constReads.find(index); @@ -734,12 +755,13 @@ void TheoryArrays::preRegisterTermInternal(TNode node) break; } case kind::STORE: { - if (d_equalityEngine.hasTerm(node)) { + if (d_equalityEngine->hasTerm(node)) + { break; } - d_equalityEngine.addTriggerTerm(node, THEORY_ARRAYS); + d_equalityEngine->addTriggerTerm(node, THEORY_ARRAYS); - TNode a = d_equalityEngine.getRepresentative(node[0]); + TNode a = d_equalityEngine->getRepresentative(node[0]); if (node.isConst()) { // Can't use d_mayEqualEqualityEngine to merge node with a because they are both constants, @@ -761,12 +783,13 @@ void TheoryArrays::preRegisterTermInternal(TNode node) TNode v = node[2]; NodeManager* nm = NodeManager::currentNM(); Node ni = nm->mkNode(kind::SELECT, node, i); - if (!d_equalityEngine.hasTerm(ni)) { + if (!d_equalityEngine->hasTerm(ni)) + { preRegisterTermInternal(ni); } // Apply RIntro1 Rule - d_equalityEngine.assertEquality(ni.eqNode(v), true, d_true, d_reasonRow1); + d_equalityEngine->assertEquality(ni.eqNode(v), true, d_true, d_reasonRow1); d_infoMap.addStore(node, node); d_infoMap.addInStore(a, node); @@ -787,7 +810,8 @@ void TheoryArrays::preRegisterTermInternal(TNode node) break; } case kind::STORE_ALL: { - if (d_equalityEngine.hasTerm(node)) { + if (d_equalityEngine->hasTerm(node)) + { break; } ArrayStoreAll storeAll = node.getConst<ArrayStoreAll>(); @@ -798,7 +822,7 @@ void TheoryArrays::preRegisterTermInternal(TNode node) d_infoMap.setConstArr(node, node); d_mayEqualEqualityEngine.addTerm(node); Assert(d_mayEqualEqualityEngine.getRepresentative(node) == node); - d_equalityEngine.addTriggerTerm(node, THEORY_ARRAYS); + d_equalityEngine->addTriggerTerm(node, THEORY_ARRAYS); d_defValues[node] = defaultValue; break; } @@ -807,19 +831,19 @@ void TheoryArrays::preRegisterTermInternal(TNode node) if (node.getType().isArray()) { // The may equal needs the node d_mayEqualEqualityEngine.addTerm(node); - d_equalityEngine.addTriggerTerm(node, THEORY_ARRAYS); - Assert(d_equalityEngine.getSize(node) == 1); + d_equalityEngine->addTriggerTerm(node, THEORY_ARRAYS); + Assert(d_equalityEngine->getSize(node) == 1); } else { - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); } break; } // Invariant: preregistered terms are exactly the terms in the equality engine // Disabled, see comment above for kind::EQUAL - // Assert(d_equalityEngine.hasTerm(node) || - // !d_equalityEngine.consistent()); + // Assert(d_equalityEngine->hasTerm(node) || + // !d_equalityEngine->consistent()); } @@ -830,7 +854,7 @@ void TheoryArrays::preRegisterTerm(TNode node) // Note: do this here instead of in preRegisterTermInternal to prevent internal select // terms from being propagated out (as this results in an assertion failure). if (node.getKind() == kind::SELECT && node.getType().isBoolean()) { - d_equalityEngine.addTriggerPredicate(node); + d_equalityEngine->addTriggerPredicate(node); } } @@ -862,7 +886,7 @@ Node TheoryArrays::explain(TNode literal, eq::EqProof* proof) { void TheoryArrays::addSharedTerm(TNode t) { Debug("arrays::sharing") << spaces(getSatContext()->getLevel()) << "TheoryArrays::addSharedTerm(" << t << ")" << std::endl; - d_equalityEngine.addTriggerTerm(t, THEORY_ARRAYS); + d_equalityEngine->addTriggerTerm(t, THEORY_ARRAYS); if (t.getType().isArray()) { d_sharedArrays.insert(t); } @@ -876,12 +900,14 @@ void TheoryArrays::addSharedTerm(TNode t) { EqualityStatus TheoryArrays::getEqualityStatus(TNode a, TNode b) { - Assert(d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)); - if (d_equalityEngine.areEqual(a, b)) { + Assert(d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)); + if (d_equalityEngine->areEqual(a, b)) + { // The terms are implied to be equal return EQUALITY_TRUE; } - else if (d_equalityEngine.areDisequal(a, b, false)) { + else if (d_equalityEngine->areDisequal(a, b, false)) + { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -895,16 +921,19 @@ void TheoryArrays::checkPair(TNode r1, TNode r2) TNode x = r1[1]; TNode y = r2[1]; - Assert(d_equalityEngine.isTriggerTerm(x, THEORY_ARRAYS)); + Assert(d_equalityEngine->isTriggerTerm(x, THEORY_ARRAYS)); - if (d_equalityEngine.hasTerm(x) && d_equalityEngine.hasTerm(y) && - (d_equalityEngine.areEqual(x,y) || d_equalityEngine.areDisequal(x,y,false))) { + if (d_equalityEngine->hasTerm(x) && d_equalityEngine->hasTerm(y) + && (d_equalityEngine->areEqual(x, y) + || d_equalityEngine->areDisequal(x, y, false))) + { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): equality known, skipping" << std::endl; return; } // If the terms are already known to be equal, we are also in good shape - if (d_equalityEngine.areEqual(r1, r2)) { + if (d_equalityEngine->areEqual(r1, r2)) + { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): equal, skipping" << std::endl; return; } @@ -913,8 +942,9 @@ void TheoryArrays::checkPair(TNode r1, TNode r2) // If arrays are known to be disequal, or cannot become equal, we can continue Assert(d_mayEqualEqualityEngine.hasTerm(r1[0]) && d_mayEqualEqualityEngine.hasTerm(r2[0])); - if (r1[0].getType() != r2[0].getType() || - d_equalityEngine.areDisequal(r1[0], r2[0], false)) { + if (r1[0].getType() != r2[0].getType() + || d_equalityEngine->areDisequal(r1[0], r2[0], false)) + { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): arrays can't be equal, skipping" << std::endl; return; } @@ -923,14 +953,17 @@ void TheoryArrays::checkPair(TNode r1, TNode r2) } } - if (!d_equalityEngine.isTriggerTerm(y, THEORY_ARRAYS)) { + if (!d_equalityEngine->isTriggerTerm(y, THEORY_ARRAYS)) + { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): not connected to shared terms, skipping" << std::endl; return; } // Get representative trigger terms - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_ARRAYS); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_ARRAYS); + TNode x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_ARRAYS); + TNode y_shared = + d_equalityEngine->getTriggerTermRepresentative(y, THEORY_ARRAYS); EqualityStatus eqStatusDomain = d_valuation.getEqualityStatus(x_shared, y_shared); switch (eqStatusDomain) { case EQUALITY_TRUE_AND_PROPAGATED: @@ -999,14 +1032,16 @@ void TheoryArrays::computeCareGraph() TNode r1 = d_reads[i]; Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): checking read " << r1 << std::endl; - Assert(d_equalityEngine.hasTerm(r1)); + Assert(d_equalityEngine->hasTerm(r1)); TNode x = r1[1]; - if (!d_equalityEngine.isTriggerTerm(x, THEORY_ARRAYS)) { + if (!d_equalityEngine->isTriggerTerm(x, THEORY_ARRAYS)) + { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): not connected to shared terms, skipping" << std::endl; continue; } - Node x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_ARRAYS); + Node x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_ARRAYS); // Get the model value of index and find all reads that read from that same model value: these are the pairs we have to check // Also, insert this read in the list at the proper index @@ -1034,12 +1069,12 @@ void TheoryArrays::computeCareGraph() // We don't know the model value for x. Just do brute force examination of all pairs of reads for (unsigned j = 0; j < size; ++j) { TNode r2 = d_reads[j]; - Assert(d_equalityEngine.hasTerm(r2)); + Assert(d_equalityEngine->hasTerm(r2)); checkPair(r1,r2); } for (unsigned j = 0; j < d_constReadsList.size(); ++j) { TNode r2 = d_constReadsList[j]; - Assert(d_equalityEngine.hasTerm(r2)); + Assert(d_equalityEngine->hasTerm(r2)); checkPair(r1,r2); } } @@ -1064,7 +1099,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) NodeManager* nm = NodeManager::currentNM(); std::vector<Node> arrays; bool computeRep, isArray; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(&d_equalityEngine); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); for (; !eqcs_i.isFinished(); ++eqcs_i) { Node eqc = (*eqcs_i); isArray = eqc.getType().isArray(); @@ -1072,7 +1107,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) continue; } computeRep = false; - eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, &d_equalityEngine); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); for (; !eqc_i.isFinished(); ++eqc_i) { Node n = *eqc_i; // If this EC is an array type and it contains something other than STORE nodes, we have to compute a representative explicitly @@ -1095,30 +1130,36 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) bool changed; do { changed = false; - eqcs_i = eq::EqClassesIterator(&d_equalityEngine); + eqcs_i = eq::EqClassesIterator(d_equalityEngine); for (; !eqcs_i.isFinished(); ++eqcs_i) { Node eqc = (*eqcs_i); - eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, &d_equalityEngine); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); for (; !eqc_i.isFinished(); ++eqc_i) { Node n = *eqc_i; if (n.getKind() == kind::SELECT && termSet.find(n) != termSet.end()) { // Find all terms equivalent to n[0] and get corresponding read terms - Node array_eqc = d_equalityEngine.getRepresentative(n[0]); - eq::EqClassIterator array_eqc_i = eq::EqClassIterator(array_eqc, &d_equalityEngine); + Node array_eqc = d_equalityEngine->getRepresentative(n[0]); + eq::EqClassIterator array_eqc_i = + eq::EqClassIterator(array_eqc, d_equalityEngine); for (; !array_eqc_i.isFinished(); ++array_eqc_i) { Node arr = *array_eqc_i; - if (arr.getKind() == kind::STORE && - termSet.find(arr) != termSet.end() && - !d_equalityEngine.areEqual(arr[1],n[1])) { + if (arr.getKind() == kind::STORE + && termSet.find(arr) != termSet.end() + && !d_equalityEngine->areEqual(arr[1], n[1])) + { Node r = nm->mkNode(kind::SELECT, arr, n[1]); - if (termSet.find(r) == termSet.end() && d_equalityEngine.hasTerm(r)) { + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(a) read: " << r << endl; termSet.insert(r); changed = true; } r = nm->mkNode(kind::SELECT, arr[0], n[1]); - if (termSet.find(r) == termSet.end() && d_equalityEngine.hasTerm(r)) { + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(b) read: " << r << endl; termSet.insert(r); changed = true; @@ -1132,16 +1173,21 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) for(; it < instores->size(); ++it) { TNode instore = (*instores)[it]; Assert(instore.getKind() == kind::STORE); - if (termSet.find(instore) != termSet.end() && - !d_equalityEngine.areEqual(instore[1],n[1])) { + if (termSet.find(instore) != termSet.end() + && !d_equalityEngine->areEqual(instore[1], n[1])) + { Node r = nm->mkNode(kind::SELECT, instore, n[1]); - if (termSet.find(r) == termSet.end() && d_equalityEngine.hasTerm(r)) { + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(c) read: " << r << endl; termSet.insert(r); changed = true; } r = nm->mkNode(kind::SELECT, instore[0], n[1]); - if (termSet.find(r) == termSet.end() && d_equalityEngine.hasTerm(r)) { + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(d) read: " << r << endl; termSet.insert(r); changed = true; @@ -1154,7 +1200,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) } while (changed); // Send the equality engine information to the model - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { return false; } @@ -1166,7 +1212,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) Node n = *set_it; // If this term is a select, record that the EC rep of its store parameter is being read from using this term if (n.getKind() == kind::SELECT) { - selects[d_equalityEngine.getRepresentative(n[0])].push_back(n); + selects[d_equalityEngine->getRepresentative(n[0])].push_back(n); } } @@ -1177,7 +1223,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) // Compute all default values already in use //if (fullModel) { for (size_t i=0; i<arrays.size(); ++i) { - TNode nrep = d_equalityEngine.getRepresentative(arrays[i]); + TNode nrep = d_equalityEngine->getRepresentative(arrays[i]); d_mayEqualEqualityEngine.addTerm(nrep); // add the term in case it isn't there already TNode mayRep = d_mayEqualEqualityEngine.getRepresentative(nrep); it = d_defValues.find(mayRep); @@ -1190,7 +1236,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) // Loop through all array equivalence classes that need a representative computed for (size_t i=0; i<arrays.size(); ++i) { TNode n = arrays[i]; - TNode nrep = d_equalityEngine.getRepresentative(n); + TNode nrep = d_equalityEngine->getRepresentative(n); //if (fullModel) { // Compute default value for this array - there is one default value for every mayEqual equivalence class @@ -1280,9 +1326,9 @@ Node TheoryArrays::getSkolem(TNode ref, const string& name, const TypeNode& type } else { skolem = (*it).second; - if (d_equalityEngine.hasTerm(ref) && - d_equalityEngine.hasTerm(skolem) && - d_equalityEngine.areEqual(ref, skolem)) { + if (d_equalityEngine->hasTerm(ref) && d_equalityEngine->hasTerm(skolem) + && d_equalityEngine->areEqual(ref, skolem)) + { makeEqual = false; } } @@ -1294,7 +1340,7 @@ Node TheoryArrays::getSkolem(TNode ref, const string& name, const TypeNode& type if (makeEqual) { Node d = skolem.eqNode(ref); Debug("arrays-model-based") << "Asserting skolem equality " << d << endl; - d_equalityEngine.assertEquality(d, true, d_true); + d_equalityEngine->assertEquality(d, true, d_true); Assert(!d_conflict); d_skolemAssertions.push_back(d); d_skolemIndex = d_skolemIndex + 1; @@ -1328,13 +1374,15 @@ void TheoryArrays::check(Effort e) { if (!assertion.d_isPreregistered) { if (atom.getKind() == kind::EQUAL) { - if (!d_equalityEngine.hasTerm(atom[0])) { + if (!d_equalityEngine->hasTerm(atom[0])) + { Assert(atom[0].isConst()); - d_equalityEngine.addTerm(atom[0]); + d_equalityEngine->addTerm(atom[0]); } - if (!d_equalityEngine.hasTerm(atom[1])) { + if (!d_equalityEngine->hasTerm(atom[1])) + { Assert(atom[1].isConst()); - d_equalityEngine.addTerm(atom[1]); + d_equalityEngine->addTerm(atom[1]); } } } @@ -1342,17 +1390,19 @@ void TheoryArrays::check(Effort e) { // Do the work switch (fact.getKind()) { case kind::EQUAL: - d_equalityEngine.assertEquality(fact, true, fact); + d_equalityEngine->assertEquality(fact, true, fact); break; case kind::SELECT: - d_equalityEngine.assertPredicate(fact, true, fact); + d_equalityEngine->assertPredicate(fact, true, fact); break; case kind::NOT: if (fact[0].getKind() == kind::SELECT) { - d_equalityEngine.assertPredicate(fact[0], false, fact); - } else if (!d_equalityEngine.areDisequal(fact[0][0], fact[0][1], false)) { + d_equalityEngine->assertPredicate(fact[0], false, fact); + } + else if (!d_equalityEngine->areDisequal(fact[0][0], fact[0][1], false)) + { // Assert the dis-equality - d_equalityEngine.assertEquality(fact[0], false, fact); + d_equalityEngine->assertEquality(fact[0], false, fact); // Apply ArrDiseq Rule if diseq is between arrays if(fact[0][0].getType().isArray() && !d_conflict) { @@ -1396,18 +1446,26 @@ void TheoryArrays::check(Effort e) { // when we output the lemma. However, in replay need the lemma to be propagated, and so we // preregister manually. if (d_proofsEnabled) { - if (!d_equalityEngine.hasTerm(ak)) { preRegisterTermInternal(ak); } - if (!d_equalityEngine.hasTerm(bk)) { preRegisterTermInternal(bk); } + if (!d_equalityEngine->hasTerm(ak)) + { + preRegisterTermInternal(ak); + } + if (!d_equalityEngine->hasTerm(bk)) + { + preRegisterTermInternal(bk); + } } - if (options::arraysPropagate() > 0 && d_equalityEngine.hasTerm(ak) && d_equalityEngine.hasTerm(bk)) { + if (options::arraysPropagate() > 0 && d_equalityEngine->hasTerm(ak) + && d_equalityEngine->hasTerm(bk)) + { // Propagate witness disequality - might produce a conflict d_permRef.push_back(lemma); Debug("pf::array") << "Asserting to the equality engine:" << std::endl << "\teq = " << eq << std::endl << "\treason = " << fact << std::endl; - d_equalityEngine.assertEquality(eq, false, fact, d_reasonExt); + d_equalityEngine->assertEquality(eq, false, fact, d_reasonExt); ++d_numProp; } @@ -1465,7 +1523,7 @@ void TheoryArrays::check(Effort e) { // Find the bucket for this read. mayRep = d_mayEqualEqualityEngine.getRepresentative(r[0]); - iRep = d_equalityEngine.getRepresentative(r[1]); + iRep = d_equalityEngine->getRepresentative(r[1]); std::pair<TNode, TNode> key(mayRep, iRep); ReadBucketMap::iterator rbm_it = d_readBucketTable.find(key); if (rbm_it == d_readBucketTable.end()) @@ -1484,20 +1542,21 @@ void TheoryArrays::check(Effort e) { const TNode& r2 = *ctnl_it; Assert(r2.getKind() == kind::SELECT); Assert(mayRep == d_mayEqualEqualityEngine.getRepresentative(r2[0])); - Assert(iRep == d_equalityEngine.getRepresentative(r2[1])); - if (d_equalityEngine.areEqual(r, r2)) { + Assert(iRep == d_equalityEngine->getRepresentative(r2[1])); + if (d_equalityEngine->areEqual(r, r2)) + { continue; } if (weakEquivGetRepIndex(r[0], r[1]) == weakEquivGetRepIndex(r2[0], r[1])) { // add lemma: r[1] = r2[1] /\ cond(r[0],r2[0]) => r = r2 vector<TNode> conjunctions; - Assert(d_equalityEngine.areEqual(r, Rewriter::rewrite(r))); - Assert(d_equalityEngine.areEqual(r2, Rewriter::rewrite(r2))); + Assert(d_equalityEngine->areEqual(r, Rewriter::rewrite(r))); + Assert(d_equalityEngine->areEqual(r2, Rewriter::rewrite(r2))); Node lemma = Rewriter::rewrite(r).eqNode(Rewriter::rewrite(r2)).negate(); d_permRef.push_back(lemma); conjunctions.push_back(lemma); if (r[1] != r2[1]) { - d_equalityEngine.explainEquality(r[1], r2[1], true, conjunctions); + d_equalityEngine->explainEquality(r[1], r2[1], true, conjunctions); } // TODO: get smaller lemmas by eliminating shared parts of path weakEquivBuildCond(r[0], r[1], conjunctions); @@ -1648,8 +1707,8 @@ void TheoryArrays::mergeArrays(TNode a, TNode b) // Normally, a is its own representative, but it's possible for a to have // been merged with another array after it got queued up by the equality engine, // so we take its representative to be safe. - a = d_equalityEngine.getRepresentative(a); - Assert(d_equalityEngine.getRepresentative(b) == a); + a = d_equalityEngine->getRepresentative(a); + Assert(d_equalityEngine->getRepresentative(b) == a); Trace("arrays-merge") << spaces(getSatContext()->getLevel()) << "Arrays::merge: (" << a << ", " << b << ")\n"; if (options::arraysOptimizeLinear() && !options::arraysWeakEquivalence()) { @@ -1759,7 +1818,7 @@ void TheoryArrays::checkStore(TNode a) { TNode b = a[0]; TNode i = a[1]; - TNode brep = d_equalityEngine.getRepresentative(b); + TNode brep = d_equalityEngine->getRepresentative(b); if (!options::arraysOptimizeLinear() || d_infoMap.isNonLinear(brep)) { const CTNodeList* js = d_infoMap.getIndices(brep); @@ -1786,17 +1845,18 @@ void TheoryArrays::checkRowForIndex(TNode i, TNode a) d_infoMap.getInfo(a)->print(); } Assert(a.getType().isArray()); - Assert(d_equalityEngine.getRepresentative(a) == a); + Assert(d_equalityEngine->getRepresentative(a) == a); TNode constArr = d_infoMap.getConstArr(a); if (!constArr.isNull()) { ArrayStoreAll storeAll = constArr.getConst<ArrayStoreAll>(); Node defValue = storeAll.getValue(); Node selConst = NodeManager::currentNM()->mkNode(kind::SELECT, constArr, i); - if (!d_equalityEngine.hasTerm(selConst)) { + if (!d_equalityEngine->hasTerm(selConst)) + { preRegisterTermInternal(selConst); } - d_equalityEngine.assertEquality(selConst.eqNode(defValue), true, d_true); + d_equalityEngine->assertEquality(selConst.eqNode(defValue), true, d_true); } const CTNodeList* stores = d_infoMap.getStores(a); @@ -1848,7 +1908,8 @@ void TheoryArrays::checkRowLemmas(TNode a, TNode b) for( ; it < i_a->size(); ++it) { TNode i = (*i_a)[it]; Node selConst = NodeManager::currentNM()->mkNode(kind::SELECT, constArr, i); - if (!d_equalityEngine.hasTerm(selConst)) { + if (!d_equalityEngine->hasTerm(selConst)) + { preRegisterTermInternal(selConst); } } @@ -1901,8 +1962,8 @@ void TheoryArrays::propagate(RowLemmaType lem) std::tie(a, b, i, j) = lem; Assert(a.getType().isArray() && b.getType().isArray()); - if (d_equalityEngine.areEqual(a,b) || - d_equalityEngine.areEqual(i,j)) { + if (d_equalityEngine->areEqual(a, b) || d_equalityEngine->areEqual(i, j)) + { return; } @@ -1911,14 +1972,15 @@ void TheoryArrays::propagate(RowLemmaType lem) Node bj = nm->mkNode(kind::SELECT, b, j); // Try to avoid introducing new read terms: track whether these already exist - bool ajExists = d_equalityEngine.hasTerm(aj); - bool bjExists = d_equalityEngine.hasTerm(bj); + bool ajExists = d_equalityEngine->hasTerm(aj); + bool bjExists = d_equalityEngine->hasTerm(bj); bool bothExist = ajExists && bjExists; // If propagating, check propagations int prop = options::arraysPropagate(); if (prop > 0) { - if (d_equalityEngine.areDisequal(i,j,true) && (bothExist || prop > 1)) { + if (d_equalityEngine->areDisequal(i, j, true) && (bothExist || prop > 1)) + { Trace("arrays-lem") << spaces(getSatContext()->getLevel()) <<"Arrays::queueRowLemma: propagating aj = bj ("<<aj<<", "<<bj<<")\n"; Node aj_eq_bj = aj.eqNode(bj); Node reason = @@ -1930,17 +1992,18 @@ void TheoryArrays::propagate(RowLemmaType lem) if (!bjExists) { preRegisterTermInternal(bj); } - d_equalityEngine.assertEquality(aj_eq_bj, true, reason, d_reasonRow); + d_equalityEngine->assertEquality(aj_eq_bj, true, reason, d_reasonRow); ++d_numProp; return; } - if (bothExist && d_equalityEngine.areDisequal(aj,bj,true)) { + if (bothExist && d_equalityEngine->areDisequal(aj, bj, true)) + { Trace("arrays-lem") << spaces(getSatContext()->getLevel()) <<"Arrays::queueRowLemma: propagating i = j ("<<i<<", "<<j<<")\n"; Node reason = (aj.isConst() && bj.isConst()) ? d_true : aj.eqNode(bj).notNode(); Node i_eq_j = i.eqNode(j); d_permRef.push_back(reason); - d_equalityEngine.assertEquality(i_eq_j, true, reason, d_reasonRow); + d_equalityEngine->assertEquality(i_eq_j, true, reason, d_reasonRow); ++d_numProp; return; } @@ -1958,8 +2021,8 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) std::tie(a, b, i, j) = lem; Assert(a.getType().isArray() && b.getType().isArray()); - if (d_equalityEngine.areEqual(a,b) || - d_equalityEngine.areEqual(i,j)) { + if (d_equalityEngine->areEqual(a, b) || d_equalityEngine->areEqual(i, j)) + { return; } @@ -1968,8 +2031,8 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) Node bj = nm->mkNode(kind::SELECT, b, j); // Try to avoid introducing new read terms: track whether these already exist - bool ajExists = d_equalityEngine.hasTerm(aj); - bool bjExists = d_equalityEngine.hasTerm(bj); + bool ajExists = d_equalityEngine->hasTerm(aj); + bool bjExists = d_equalityEngine->hasTerm(bj); bool bothExist = ajExists && bjExists; // If propagating, check propagations @@ -1981,13 +2044,16 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) // If equivalent lemma already exists, don't enqueue this one if (d_useArrTable) { Node tableEntry = NodeManager::currentNM()->mkNode(kind::ARR_TABLE_FUN, a, b, i, j); - if (d_equalityEngine.getSize(tableEntry) != 1) { + if (d_equalityEngine->getSize(tableEntry) != 1) + { return; } } // Prefer equality between indexes so as not to introduce new read terms - if (options::arraysEagerIndexSplitting() && !bothExist && !d_equalityEngine.areDisequal(i,j, false)) { + if (options::arraysEagerIndexSplitting() && !bothExist + && !d_equalityEngine->areDisequal(i, j, false)) + { Node i_eq_j; if (!d_proofsEnabled) { i_eq_j = d_valuation.ensureLiteral(i.eqNode(j)); // TODO: think about this @@ -2008,20 +2074,22 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) if (!ajExists) { preRegisterTermInternal(aj); } - if (!d_equalityEngine.hasTerm(aj2)) { + if (!d_equalityEngine->hasTerm(aj2)) + { preRegisterTermInternal(aj2); } - d_equalityEngine.assertEquality(aj.eqNode(aj2), true, d_true); + d_equalityEngine->assertEquality(aj.eqNode(aj2), true, d_true); } Node bj2 = Rewriter::rewrite(bj); if (bj != bj2) { if (!bjExists) { preRegisterTermInternal(bj); } - if (!d_equalityEngine.hasTerm(bj2)) { + if (!d_equalityEngine->hasTerm(bj2)) + { preRegisterTermInternal(bj2); } - d_equalityEngine.assertEquality(bj.eqNode(bj2), true, d_true); + d_equalityEngine->assertEquality(bj.eqNode(bj2), true, d_true); } if (aj2 == bj2) { return; @@ -2031,20 +2099,22 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) Node eq1 = aj2.eqNode(bj2); Node eq1_r = Rewriter::rewrite(eq1); if (eq1_r == d_true) { - if (!d_equalityEngine.hasTerm(aj2)) { + if (!d_equalityEngine->hasTerm(aj2)) + { preRegisterTermInternal(aj2); } - if (!d_equalityEngine.hasTerm(bj2)) { + if (!d_equalityEngine->hasTerm(bj2)) + { preRegisterTermInternal(bj2); } - d_equalityEngine.assertEquality(eq1, true, d_true); + d_equalityEngine->assertEquality(eq1, true, d_true); return; } Node eq2 = i.eqNode(j); Node eq2_r = Rewriter::rewrite(eq2); if (eq2_r == d_true) { - d_equalityEngine.assertEquality(eq2, true, d_true); + d_equalityEngine->assertEquality(eq2, true, d_true); return; } @@ -2089,14 +2159,16 @@ bool TheoryArrays::dischargeLemmas() NodeManager* nm = NodeManager::currentNM(); Node aj = nm->mkNode(kind::SELECT, a, j); Node bj = nm->mkNode(kind::SELECT, b, j); - bool ajExists = d_equalityEngine.hasTerm(aj); - bool bjExists = d_equalityEngine.hasTerm(bj); + bool ajExists = d_equalityEngine->hasTerm(aj); + bool bjExists = d_equalityEngine->hasTerm(bj); // Check for redundant lemma // TODO: more checks possible (i.e. check d_RowAlreadyAdded in context) - if (!d_equalityEngine.hasTerm(i) || !d_equalityEngine.hasTerm(j) || d_equalityEngine.areEqual(i,j) || - !d_equalityEngine.hasTerm(a) || !d_equalityEngine.hasTerm(b) || d_equalityEngine.areEqual(a,b) || - (ajExists && bjExists && d_equalityEngine.areEqual(aj,bj))) { + if (!d_equalityEngine->hasTerm(i) || !d_equalityEngine->hasTerm(j) + || d_equalityEngine->areEqual(i, j) || !d_equalityEngine->hasTerm(a) + || !d_equalityEngine->hasTerm(b) || d_equalityEngine->areEqual(a, b) + || (ajExists && bjExists && d_equalityEngine->areEqual(aj, bj))) + { continue; } @@ -2114,21 +2186,22 @@ bool TheoryArrays::dischargeLemmas() if (!ajExists) { preRegisterTermInternal(aj); } - if (!d_equalityEngine.hasTerm(aj2)) { + if (!d_equalityEngine->hasTerm(aj2)) + { preRegisterTermInternal(aj2); } - d_equalityEngine.assertEquality(aj.eqNode(aj2), true, d_true); + d_equalityEngine->assertEquality(aj.eqNode(aj2), true, d_true); } Node bj2 = Rewriter::rewrite(bj); if (bj != bj2) { if (!bjExists) { preRegisterTermInternal(bj); } - if (!d_equalityEngine.hasTerm(bj2)) { + if (!d_equalityEngine->hasTerm(bj2)) + { preRegisterTermInternal(bj2); } - d_equalityEngine.assertEquality(bj.eqNode(bj2), true, d_true); - + d_equalityEngine->assertEquality(bj.eqNode(bj2), true, d_true); } if (aj2 == bj2) { continue; @@ -2138,20 +2211,22 @@ bool TheoryArrays::dischargeLemmas() Node eq1 = aj2.eqNode(bj2); Node eq1_r = Rewriter::rewrite(eq1); if (eq1_r == d_true) { - if (!d_equalityEngine.hasTerm(aj2)) { + if (!d_equalityEngine->hasTerm(aj2)) + { preRegisterTermInternal(aj2); } - if (!d_equalityEngine.hasTerm(bj2)) { + if (!d_equalityEngine->hasTerm(bj2)) + { preRegisterTermInternal(bj2); } - d_equalityEngine.assertEquality(eq1, true, d_true); + d_equalityEngine->assertEquality(eq1, true, d_true); continue; } Node eq2 = i.eqNode(j); Node eq2_r = Rewriter::rewrite(eq2); if (eq2_r == d_true) { - d_equalityEngine.assertEquality(eq2, true, d_true); + d_equalityEngine->assertEquality(eq2, true, d_true); continue; } diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h index 116b0f43b..f1cd2ea14 100644 --- a/src/theory/arrays/theory_arrays.h +++ b/src/theory/arrays/theory_arrays.h @@ -148,9 +148,18 @@ class TheoryArrays : public Theory { std::string name = ""); ~TheoryArrays(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ + void finishInit() override; + //--------------------------------- end initialization std::string identify() const override { return std::string("TheoryArrays"); } @@ -353,9 +362,6 @@ class TheoryArrays : public Theory { /** The notify class for d_equalityEngine */ NotifyClass d_notify; - /** Equaltity engine */ - eq::EqualityEngine d_equalityEngine; - /** Are we in conflict? */ context::CDO<bool> d_conflict; @@ -460,7 +466,7 @@ class TheoryArrays : public Theory { int d_topLevel; /** An equality-engine callback for proof reconstruction */ - ArrayProofReconstruction d_proofReconstruction; + std::unique_ptr<ArrayProofReconstruction> d_proofReconstruction; /** * The decision strategy for the theory of arrays, which calls the @@ -493,9 +499,6 @@ class TheoryArrays : public Theory { */ Node getNextDecisionRequest(); - public: - eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } - };/* class TheoryArrays */ }/* CVC4::theory::arrays namespace */ diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp index c49909fe6..48ec81a1e 100644 --- a/src/theory/bv/bv_subtheory_core.cpp +++ b/src/theory/bv/bv_subtheory_core.cpp @@ -35,55 +35,65 @@ using namespace CVC4::theory::bv::utils; CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt) : SubtheorySolver(c, bv), d_notify(*this), - d_equalityEngine(d_notify, c, "theory::bv::ee", true), d_slicer(new Slicer()), d_isComplete(c, true), d_lemmaThreshold(16), d_useSlicer(false), d_preregisterCalled(false), d_checkCalled(false), + d_bv(bv), d_extTheory(extt), d_reasons(c) { - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::BITVECTOR_CONCAT, true); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_AND); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_OR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XOR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NAND); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XNOR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_COMP); - d_equalityEngine.addFunctionKind(kind::BITVECTOR_MULT, true); - d_equalityEngine.addFunctionKind(kind::BITVECTOR_PLUS, true); - d_equalityEngine.addFunctionKind(kind::BITVECTOR_EXTRACT, true); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SUB); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NEG); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UDIV); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UREM); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SDIV); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SREM); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SMOD); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SHL); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_LSHR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ASHR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULE); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGE); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLE); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGE); - d_equalityEngine.addFunctionKind(kind::BITVECTOR_TO_NAT); - d_equalityEngine.addFunctionKind(kind::INT_TO_BITVECTOR); } CoreSolver::~CoreSolver() {} -void CoreSolver::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); +bool CoreSolver::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::bv::ee"; + return true; +} + +void CoreSolver::finishInit() +{ + // use the parent's equality engine, which may be the one we allocated above + d_equalityEngine = d_bv->getEqualityEngine(); + + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::BITVECTOR_CONCAT, true); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_AND); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_OR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_XOR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NOT); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NAND); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NOR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_XNOR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_COMP); + d_equalityEngine->addFunctionKind(kind::BITVECTOR_MULT, true); + d_equalityEngine->addFunctionKind(kind::BITVECTOR_PLUS, true); + d_equalityEngine->addFunctionKind(kind::BITVECTOR_EXTRACT, true); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SUB); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NEG); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UDIV); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UREM); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SDIV); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SREM); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SMOD); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SHL); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_LSHR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_ASHR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_ULT); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_ULE); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UGT); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UGE); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SLT); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SLE); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SGT); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SGE); + d_equalityEngine->addFunctionKind(kind::BITVECTOR_TO_NAT); + d_equalityEngine->addFunctionKind(kind::INT_TO_BITVECTOR); } void CoreSolver::enableSlicer() { @@ -95,13 +105,14 @@ void CoreSolver::enableSlicer() { void CoreSolver::preRegister(TNode node) { d_preregisterCalled = true; if (node.getKind() == kind::EQUAL) { - d_equalityEngine.addTriggerEquality(node); - if (d_useSlicer) { - d_slicer->processEquality(node); - AlwaysAssert(!d_checkCalled); + d_equalityEngine->addTriggerEquality(node); + if (d_useSlicer) + { + d_slicer->processEquality(node); + AlwaysAssert(!d_checkCalled); } } else { - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); // Register with the extended theory, for context-dependent simplification. // Notice we do this for registered terms but not internally generated // equivalence classes. The two should roughly cooincide. Since ExtTheory is @@ -115,9 +126,9 @@ void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) { bool polarity = literal.getKind() != kind::NOT; TNode atom = polarity ? literal : literal[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions); } else { - d_equalityEngine.explainPredicate(atom, polarity, assumptions); + d_equalityEngine->explainPredicate(atom, polarity, assumptions); } } @@ -224,14 +235,14 @@ void CoreSolver::buildModel() TNodeSet constants; TNodeSet constants_in_eq_engine; // collect constants in equality engine - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(&d_equalityEngine); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while (!eqcs_i.isFinished()) { TNode repr = *eqcs_i; if (repr.getKind() == kind::CONST_BITVECTOR) { // must check if it's just the constant - eq::EqClassIterator it(repr, &d_equalityEngine); + eq::EqClassIterator it(repr, d_equalityEngine); if (!(++it).isFinished() || true) { constants.insert(repr); @@ -243,7 +254,7 @@ void CoreSolver::buildModel() // build repr to value map - eqcs_i = eq::EqClassesIterator(&d_equalityEngine); + eqcs_i = eq::EqClassesIterator(d_equalityEngine); while (!eqcs_i.isFinished()) { TNode repr = *eqcs_i; @@ -351,15 +362,16 @@ bool CoreSolver::assertFactToEqualityEngine(TNode fact, TNode reason) { if (predicate.getKind() == kind::EQUAL) { if (negated) { // dis-equality - d_equalityEngine.assertEquality(predicate, false, reason); + d_equalityEngine->assertEquality(predicate, false, reason); } else { // equality - d_equalityEngine.assertEquality(predicate, true, reason); + d_equalityEngine->assertEquality(predicate, true, reason); } } else { // Adding predicate if the congruence over it is turned on - if (d_equalityEngine.isFunctionKind(predicate.getKind())) { - d_equalityEngine.assertPredicate(predicate, !negated, reason); + if (d_equalityEngine->isFunctionKind(predicate.getKind())) + { + d_equalityEngine->assertPredicate(predicate, !negated, reason); } } } @@ -408,7 +420,7 @@ bool CoreSolver::storePropagation(TNode literal) { void CoreSolver::conflict(TNode a, TNode b) { std::vector<TNode> assumptions; - d_equalityEngine.explainEquality(a, b, true, assumptions); + d_equalityEngine->explainEquality(a, b, true, assumptions); Node conflict = flattenAnd(assumptions); d_bv->setConflict(conflict); } @@ -434,7 +446,7 @@ bool CoreSolver::collectModelInfo(TheoryModel* m, bool fullModel) } set<Node> termSet; d_bv->computeRelevantTerms(termSet); - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { return false; } @@ -457,7 +469,7 @@ bool CoreSolver::collectModelInfo(TheoryModel* m, bool fullModel) Node CoreSolver::getModelValue(TNode var) { Debug("bitvector-model") << "CoreSolver::getModelValue (" << var <<")"; Assert(isComplete()); - TNode repr = d_equalityEngine.getRepresentative(var); + TNode repr = d_equalityEngine->getRepresentative(var); Node result = Node(); if (repr.getKind() == kind::CONST_BITVECTOR) { result = repr; @@ -472,6 +484,35 @@ Node CoreSolver::getModelValue(TNode var) { return result; } +void CoreSolver::addSharedTerm(TNode t) +{ + d_equalityEngine->addTriggerTerm(t, THEORY_BV); +} + +EqualityStatus CoreSolver::getEqualityStatus(TNode a, TNode b) +{ + if (d_equalityEngine->areEqual(a, b)) + { + // The terms are implied to be equal + return EQUALITY_TRUE; + } + if (d_equalityEngine->areDisequal(a, b, false)) + { + // The terms are implied to be dis-equal + return EQUALITY_FALSE; + } + return EQUALITY_UNKNOWN; +} + +bool CoreSolver::hasTerm(TNode node) const +{ + return d_equalityEngine->hasTerm(node); +} +void CoreSolver::addTermToEqualityEngine(TNode node) +{ + d_equalityEngine->addTerm(node); +} + CoreSolver::Statistics::Statistics() : d_numCallstoCheck("theory::bv::CoreSolver::NumCallsToCheck", 0) , d_slicerEnabled("theory::bv::CoreSolver::SlicerEnabled", false) diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h index ea652e7cd..33f119e5f 100644 --- a/src/theory/bv/bv_subtheory_core.h +++ b/src/theory/bv/bv_subtheory_core.h @@ -70,9 +70,6 @@ class CoreSolver : public SubtheorySolver { /** The notify class for d_equalityEngine */ NotifyClass d_notify; - /** Equality engine */ - eq::EqualityEngine d_equalityEngine; - /** Store a propagation to the bv solver */ bool storePropagation(TNode literal); @@ -88,6 +85,10 @@ class CoreSolver : public SubtheorySolver { bool d_preregisterCalled; bool d_checkCalled; + /** Pointer to the parent theory solver that owns this */ + TheoryBV* d_bv; + /** Pointer to the equality engine of the parent */ + eq::EqualityEngine* d_equalityEngine; /** Pointer to the extended theory module. */ ExtTheory* d_extTheory; @@ -100,36 +101,23 @@ class CoreSolver : public SubtheorySolver { Node getBaseDecomposition(TNode a); bool isCompleteForTerm(TNode term, TNodeBoolMap& seen); Statistics d_statistics; -public: - CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt); - ~CoreSolver(); - bool isComplete() override { return d_isComplete; } - void setMasterEqualityEngine(eq::EqualityEngine* eq); - void preRegister(TNode node) override; - bool check(Theory::Effort e) override; - void explain(TNode literal, std::vector<TNode>& assumptions) override; - bool collectModelInfo(TheoryModel* m, bool fullModel) override; - Node getModelValue(TNode var) override; - void addSharedTerm(TNode t) override - { - d_equalityEngine.addTriggerTerm(t, THEORY_BV); - } - EqualityStatus getEqualityStatus(TNode a, TNode b) override - { - if (d_equalityEngine.areEqual(a, b)) { - // The terms are implied to be equal - return EQUALITY_TRUE; - } - if (d_equalityEngine.areDisequal(a, b, false)) { - // The terms are implied to be dis-equal - return EQUALITY_FALSE; - } - return EQUALITY_UNKNOWN; - } - bool hasTerm(TNode node) const { return d_equalityEngine.hasTerm(node); } - void addTermToEqualityEngine(TNode node) { d_equalityEngine.addTerm(node); } + + public: + CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt); + ~CoreSolver(); + bool needsEqualityEngine(EeSetupInfo& esi); + void finishInit(); + bool isComplete() override { return d_isComplete; } + void preRegister(TNode node) override; + bool check(Theory::Effort e) override; + void explain(TNode literal, std::vector<TNode>& assumptions) override; + bool collectModelInfo(TheoryModel* m, bool fullModel) override; + Node getModelValue(TNode var) override; + void addSharedTerm(TNode t) override; + EqualityStatus getEqualityStatus(TNode a, TNode b) override; + bool hasTerm(TNode node) const; + void addTermToEqualityEngine(TNode node); void enableSlicer(); - eq::EqualityEngine * getEqualityEngine() { return &d_equalityEngine; } }; diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 0a4499c11..ced320d92 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -113,13 +113,31 @@ TheoryBV::TheoryBV(context::Context* c, TheoryBV::~TheoryBV() {} -void TheoryBV::setMasterEqualityEngine(eq::EqualityEngine* eq) { - if (options::bitblastMode() == options::BitblastMode::EAGER) +TheoryRewriter* TheoryBV::getTheoryRewriter() { return &d_rewriter; } + +bool TheoryBV::needsEqualityEngine(EeSetupInfo& esi) +{ + CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE]; + if (core) { - return; + return core->needsEqualityEngine(esi); } - if (options::bitvectorEqualitySolver()) { - dynamic_cast<CoreSolver*>(d_subtheoryMap[SUB_CORE])->setMasterEqualityEngine(eq); + // otherwise we don't use an equality engine + return false; +} + +void TheoryBV::finishInit() +{ + // these kinds are semi-evaluated in getModelValue (applications of this + // kind are treated as variables) + d_valuation.setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UDIV); + d_valuation.setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UREM); + + CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE]; + if (core) + { + // must finish initialization in the core solver + core->finishInit(); } } @@ -185,16 +203,6 @@ Node TheoryBV::getBVDivByZero(Kind k, unsigned width) { Unreachable(); } -void TheoryBV::finishInit() -{ - // these kinds are semi-evaluated in getModelValue (applications of this - // kind are treated as variables) - TheoryModel* tm = d_valuation.getModel(); - Assert(tm != nullptr); - tm->setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UDIV); - tm->setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UREM); -} - TrustNode TheoryBV::expandDefinition(Node node) { Debug("bitvector-expandDefinition") << "TheoryBV::expandDefinition(" << node << ")" << std::endl; @@ -582,16 +590,6 @@ void TheoryBV::propagate(Effort e) { } } - -eq::EqualityEngine * TheoryBV::getEqualityEngine() { - CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE]; - if( core ){ - return core->getEqualityEngine(); - }else{ - return NULL; - } -} - bool TheoryBV::getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) { eq::EqualityEngine * ee = getEqualityEngine(); if( ee ){ diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index b0991c8b0..0e8877359 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -77,11 +77,18 @@ class TheoryBV : public Theory { ~TheoryBV(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ void finishInit() override; + //--------------------------------- end initialization TrustNode expandDefinition(Node node) override; @@ -99,8 +106,6 @@ class TheoryBV : public Theory { std::string identify() const override { return std::string("TheoryBV"); } - /** equality engine */ - eq::EqualityEngine* getEqualityEngine() override; bool getCurrentSubstitution(int effort, std::vector<Node>& vars, std::vector<Node>& subs, diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 832324d4b..4b38ad6bd 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -51,7 +51,6 @@ TheoryDatatypes::TheoryDatatypes(Context* c, d_infer_exp(c), d_term_sk(u), d_notify(*this), - d_equalityEngine(d_notify, c, "theory::datatypes", true), d_labels(c), d_selector_apps(c), d_conflict(c, false), @@ -64,13 +63,6 @@ TheoryDatatypes::TheoryDatatypes(Context* c, d_lemmas_produced_c(u), d_sygusExtension(nullptr) { - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::APPLY_CONSTRUCTOR); - d_equalityEngine.addFunctionKind(kind::APPLY_SELECTOR_TOTAL); - //d_equalityEngine.addFunctionKind(kind::DT_SIZE); - //d_equalityEngine.addFunctionKind(kind::DT_HEIGHT_BOUND); - d_equalityEngine.addFunctionKind(kind::APPLY_TESTER); - //d_equalityEngine.addFunctionKind(kind::APPLY_UF); d_true = NodeManager::currentNM()->mkConst( true ); d_zero = NodeManager::currentNM()->mkConst( Rational(0) ); @@ -86,8 +78,32 @@ TheoryDatatypes::~TheoryDatatypes() { } } -void TheoryDatatypes::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); +TheoryRewriter* TheoryDatatypes::getTheoryRewriter() { return &d_rewriter; } + +bool TheoryDatatypes::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::datatypes::ee"; + return true; +} + +void TheoryDatatypes::finishInit() +{ + Assert(d_equalityEngine != nullptr); + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::APPLY_CONSTRUCTOR); + d_equalityEngine->addFunctionKind(kind::APPLY_SELECTOR_TOTAL); + d_equalityEngine->addFunctionKind(kind::APPLY_TESTER); + // We could but don't do congruence for DT_SIZE and DT_HEIGHT_BOUND here. + // It also could make sense in practice to do congruence for APPLY_UF, but + // this is not done. + if (getQuantifiersEngine() && options::sygus()) + { + d_sygusExtension.reset( + new SygusExtension(this, getQuantifiersEngine(), getSatContext())); + // do congruence on evaluation functions + d_equalityEngine->addFunctionKind(kind::DT_SYGUS_EVAL); + } } TheoryDatatypes::EqcInfo* TheoryDatatypes::getOrMakeEqcInfo( TNode n, bool doMake ){ @@ -193,7 +209,7 @@ void TheoryDatatypes::check(Effort e) { do { d_addedFact = false; std::map< TypeNode, Node > rec_singletons; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while( !eqcs_i.isFinished() ){ Node n = (*eqcs_i); //TODO : avoid irrelevant (pre-registered but not asserted) terms here? @@ -479,9 +495,9 @@ void TheoryDatatypes::assertFact( Node fact, Node exp ){ bool polarity = fact.getKind() != kind::NOT; TNode atom = polarity ? fact : fact[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.assertEquality( atom, polarity, exp ); + d_equalityEngine->assertEquality(atom, polarity, exp); }else{ - d_equalityEngine.assertPredicate( atom, polarity, exp ); + d_equalityEngine->assertPredicate(atom, polarity, exp); } doPendingMerges(); // could be sygus-specific @@ -527,37 +543,27 @@ void TheoryDatatypes::preRegisterTerm(TNode n) { switch (n.getKind()) { case kind::EQUAL: // Add the trigger for equality - d_equalityEngine.addTriggerEquality(n); + d_equalityEngine->addTriggerEquality(n); break; case kind::APPLY_TESTER: // Get triggered for both equal and dis-equal - d_equalityEngine.addTriggerPredicate(n); + d_equalityEngine->addTriggerPredicate(n); break; default: // Function applications/predicates - d_equalityEngine.addTerm(n); + d_equalityEngine->addTerm(n); if (d_sygusExtension) { std::vector< Node > lemmas; d_sygusExtension->preRegisterTerm(n, lemmas); doSendLemmas( lemmas ); } - //d_equalityEngine.addTriggerTerm(n, THEORY_DATATYPES); + // d_equalityEngine->addTriggerTerm(n, THEORY_DATATYPES); break; } flushPendingFacts(); } -void TheoryDatatypes::finishInit() { - if (getQuantifiersEngine() && options::sygus()) - { - d_sygusExtension.reset( - new SygusExtension(this, getQuantifiersEngine(), getSatContext())); - // do congruence on evaluation functions - d_equalityEngine.addFunctionKind(kind::DT_SYGUS_EVAL); - } -} - TrustNode TheoryDatatypes::expandDefinition(Node n) { NodeManager* nm = NodeManager::currentNM(); @@ -727,7 +733,7 @@ TrustNode TheoryDatatypes::ppRewrite(TNode in) void TheoryDatatypes::addSharedTerm(TNode t) { Debug("datatypes") << "TheoryDatatypes::addSharedTerm(): " << t << " " << t.getType().isBoolean() << endl; - d_equalityEngine.addTriggerTerm(t, THEORY_DATATYPES); + d_equalityEngine->addTriggerTerm(t, THEORY_DATATYPES); Debug("datatypes") << "TheoryDatatypes::addSharedTerm() finished" << std::endl; } @@ -776,14 +782,14 @@ void TheoryDatatypes::addAssumptions( std::vector<TNode>& assumptions, std::vect void TheoryDatatypes::explainEquality( TNode a, TNode b, bool polarity, std::vector<TNode>& assumptions ) { if( a!=b ){ std::vector<TNode> tassumptions; - d_equalityEngine.explainEquality(a, b, polarity, tassumptions); + d_equalityEngine->explainEquality(a, b, polarity, tassumptions); addAssumptions( assumptions, tassumptions ); } } void TheoryDatatypes::explainPredicate( TNode p, bool polarity, std::vector<TNode>& assumptions ) { std::vector<TNode> tassumptions; - d_equalityEngine.explainPredicate(p, polarity, tassumptions); + d_equalityEngine->explainPredicate(p, polarity, tassumptions); addAssumptions( assumptions, tassumptions ); } @@ -1367,12 +1373,14 @@ void TheoryDatatypes::collapseSelector( Node s, Node c ) { } EqualityStatus TheoryDatatypes::getEqualityStatus(TNode a, TNode b){ - Assert(d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)); - if (d_equalityEngine.areEqual(a, b)) { + Assert(d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)); + if (d_equalityEngine->areEqual(a, b)) + { // The terms are implied to be equal return EQUALITY_TRUE; } - if (d_equalityEngine.areDisequal(a, b, false)) { + if (d_equalityEngine->areDisequal(a, b, false)) + { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -1395,15 +1403,20 @@ void TheoryDatatypes::addCarePairs(TNodeTrie* t1, for (unsigned k = 0; k < f1.getNumChildren(); ++ k) { TNode x = f1[k]; TNode y = f2[k]; - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); Assert(!areDisequal(x, y)); Assert(!areCareDisequal(x, y)); - if( !d_equalityEngine.areEqual( x, y ) ){ + if (!d_equalityEngine->areEqual(x, y)) + { Trace("dt-cg") << "Arg #" << k << " is " << x << " " << y << std::endl; - if( d_equalityEngine.isTriggerTerm(x, THEORY_DATATYPES) && d_equalityEngine.isTriggerTerm(y, THEORY_DATATYPES) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_DATATYPES); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_DATATYPES); + if (d_equalityEngine->isTriggerTerm(x, THEORY_DATATYPES) + && d_equalityEngine->isTriggerTerm(y, THEORY_DATATYPES)) + { + TNode x_shared = d_equalityEngine->getTriggerTermRepresentative( + x, THEORY_DATATYPES); + TNode y_shared = d_equalityEngine->getTriggerTermRepresentative( + y, THEORY_DATATYPES); currentPairs.push_back(make_pair(x_shared, y_shared)); } } @@ -1432,7 +1445,8 @@ void TheoryDatatypes::addCarePairs(TNodeTrie* t1, std::map<TNode, TNodeTrie>::iterator it2 = it; ++it2; for( ; it2 != t1->d_data.end(); ++it2 ){ - if( !d_equalityEngine.areDisequal(it->first, it2->first, false) ){ + if (!d_equalityEngine->areDisequal(it->first, it2->first, false)) + { if( !areCareDisequal(it->first, it2->first) ){ addCarePairs( &it->second, &it2->second, arity, depth+1, n_pairs ); } @@ -1445,7 +1459,7 @@ void TheoryDatatypes::addCarePairs(TNodeTrie* t1, { for (std::pair<const TNode, TNodeTrie>& tt2 : t2->d_data) { - if (!d_equalityEngine.areDisequal(tt1.first, tt2.first, false)) + if (!d_equalityEngine->areDisequal(tt1.first, tt2.first, false)) { if (!areCareDisequal(tt1.first, tt2.first)) { @@ -1468,7 +1482,7 @@ void TheoryDatatypes::computeCareGraph(){ unsigned functionTerms = d_functionTerms.size(); for( unsigned i=0; i<functionTerms; i++ ){ TNode f1 = d_functionTerms[i]; - Assert(d_equalityEngine.hasTerm(f1)); + Assert(d_equalityEngine->hasTerm(f1)); Trace("dt-cg-debug") << "...build for " << f1 << std::endl; //break into index based on operator, and type of first argument (since some operators are parametric) Node op = f1.getOperator(); @@ -1476,8 +1490,9 @@ void TheoryDatatypes::computeCareGraph(){ std::vector< TNode > reps; bool has_trigger_arg = false; for( unsigned j=0; j<f1.getNumChildren(); j++ ){ - reps.push_back( d_equalityEngine.getRepresentative( f1[j] ) ); - if( d_equalityEngine.isTriggerTerm( f1[j], THEORY_DATATYPES ) ){ + reps.push_back(d_equalityEngine->getRepresentative(f1[j])); + if (d_equalityEngine->isTriggerTerm(f1[j], THEORY_DATATYPES)) + { has_trigger_arg = true; } } @@ -1502,7 +1517,8 @@ void TheoryDatatypes::computeCareGraph(){ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) { - Trace("dt-cmi") << "Datatypes : Collect model info " << d_equalityEngine.consistent() << std::endl; + Trace("dt-cmi") << "Datatypes : Collect model info " + << d_equalityEngine->consistent() << std::endl; Trace("dt-model") << std::endl; printModelDebug( "dt-model" ); Trace("dt-model") << std::endl; @@ -1513,13 +1529,13 @@ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) getRelevantTerms(termSet); //combine the equality engine - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { return false; } //get all constructors - eq::EqClassesIterator eqccs_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqccs_i = eq::EqClassesIterator(d_equalityEngine); std::vector< Node > cons; std::vector< Node > nodes; std::map< Node, Node > eqc_cons; @@ -1558,7 +1574,8 @@ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) bool addCons = false; TypeNode tt = eqc.getType(); const DType& dt = tt.getDType(); - if( !d_equalityEngine.hasTerm( eqc ) ){ + if (!d_equalityEngine->hasTerm(eqc)) + { Assert(false); }else{ Trace("dt-cmi") << "NOTICE : Datatypes: no constructor in equivalence class " << eqc << std::endl; @@ -1578,12 +1595,6 @@ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) bool cfinite = dt[ i ].isInterpretedFinite( tt ); if( pcons[i] && (r==1)==cfinite ){ neqc = utils::getInstCons(eqc, dt, i); - //for( unsigned j=0; j<neqc.getNumChildren(); j++ ){ - // //if( sels[i].find( j )==sels[i].end() && neqc[j].getType().isDatatype() ){ - // if( !d_equalityEngine.hasTerm( neqc[j] ) && neqc[j].getType().isDatatype() ){ - // nodes.push_back( neqc[j] ); - // } - //} break; } } @@ -1773,7 +1784,7 @@ Node TheoryDatatypes::getInstantiateCons(Node n, const DType& dt, int index) //Assert( n_ic==Rewriter::rewrite( n_ic ) ); n_ic = Rewriter::rewrite( n_ic ); collectTerms( n_ic ); - d_equalityEngine.addTerm(n_ic); + d_equalityEngine->addTerm(n_ic); Debug("dt-enum") << "Made instantiate cons " << n_ic << std::endl; } d_inst_map[n][index] = n_ic; @@ -1824,7 +1835,7 @@ void TheoryDatatypes::instantiate( EqcInfo* eqc, Node n ){ void TheoryDatatypes::checkCycles() { Trace("datatypes-cycle-check") << "Check acyclicity" << std::endl; std::vector< Node > cdt_eqc; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while( !eqcs_i.isFinished() ){ Node eqc = (*eqcs_i); TypeNode tn = eqc.getType(); @@ -2115,15 +2126,13 @@ bool TheoryDatatypes::mustCommunicateFact( Node n, Node exp ){ } } -bool TheoryDatatypes::hasTerm( TNode a ){ - return d_equalityEngine.hasTerm( a ); -} +bool TheoryDatatypes::hasTerm(TNode a) { return d_equalityEngine->hasTerm(a); } bool TheoryDatatypes::areEqual( TNode a, TNode b ){ if( a==b ){ return true; }else if( hasTerm( a ) && hasTerm( b ) ){ - return d_equalityEngine.areEqual( a, b ); + return d_equalityEngine->areEqual(a, b); }else{ return false; } @@ -2133,7 +2142,7 @@ bool TheoryDatatypes::areDisequal( TNode a, TNode b ){ if( a==b ){ return false; }else if( hasTerm( a ) && hasTerm( b ) ){ - return d_equalityEngine.areDisequal( a, b, false ); + return d_equalityEngine->areDisequal(a, b, false); }else{ //TODO : constants here? return false; @@ -2141,11 +2150,15 @@ bool TheoryDatatypes::areDisequal( TNode a, TNode b ){ } bool TheoryDatatypes::areCareDisequal( TNode x, TNode y ) { - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); - if( d_equalityEngine.isTriggerTerm(x, THEORY_DATATYPES) && d_equalityEngine.isTriggerTerm(y, THEORY_DATATYPES) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_DATATYPES); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_DATATYPES); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); + if (d_equalityEngine->isTriggerTerm(x, THEORY_DATATYPES) + && d_equalityEngine->isTriggerTerm(y, THEORY_DATATYPES)) + { + TNode x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_DATATYPES); + TNode y_shared = + d_equalityEngine->getTriggerTermRepresentative(y, THEORY_DATATYPES); EqualityStatus eqStatus = d_valuation.getEqualityStatus(x_shared, y_shared); if( eqStatus==EQUALITY_FALSE_AND_PROPAGATED || eqStatus==EQUALITY_FALSE || eqStatus==EQUALITY_FALSE_IN_MODEL ){ return true; @@ -2156,7 +2169,7 @@ bool TheoryDatatypes::areCareDisequal( TNode x, TNode y ) { TNode TheoryDatatypes::getRepresentative( TNode a ){ if( hasTerm( a ) ){ - return d_equalityEngine.getRepresentative( a ); + return d_equalityEngine->getRepresentative(a); }else{ return a; } @@ -2172,7 +2185,7 @@ void TheoryDatatypes::printModelDebug( const char* c ){ } Trace( c ) << "Datatypes model : " << std::endl; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while( !eqcs_i.isFinished() ){ Node eqc = (*eqcs_i); //if( !eqc.getType().isBoolean() ){ @@ -2182,7 +2195,7 @@ void TheoryDatatypes::printModelDebug( const char* c ){ Trace( c ) << eqc << " : " << eqc.getType() << " : " << std::endl; Trace( c ) << " { "; //add terms to model - eq::EqClassIterator eqc_i = eq::EqClassIterator( eqc, &d_equalityEngine ); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); while( !eqc_i.isFinished() ){ if( (*eqc_i)!=eqc ){ Trace( c ) << (*eqc_i) << " "; @@ -2248,7 +2261,7 @@ void TheoryDatatypes::getRelevantTerms( std::set<Node>& termSet ) { << std::endl; //also include non-singleton equivalence classes TODO : revisit this - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while( !eqcs_i.isFinished() ){ TNode r = (*eqcs_i); bool addedFirst = false; @@ -2256,7 +2269,7 @@ void TheoryDatatypes::getRelevantTerms( std::set<Node>& termSet ) { TypeNode rtn = r.getType(); if (!rtn.isBoolean()) { - eq::EqClassIterator eqc_i = eq::EqClassIterator(r, &d_equalityEngine); + eq::EqClassIterator eqc_i = eq::EqClassIterator(r, d_equalityEngine); while (!eqc_i.isFinished()) { TNode n = (*eqc_i); @@ -2296,7 +2309,7 @@ std::pair<bool, Node> TheoryDatatypes::entailmentCheck(TNode lit) if( atom.getKind()==APPLY_TESTER ){ Node n = atom[0]; if( hasTerm( n ) ){ - Node r = d_equalityEngine.getRepresentative( n ); + Node r = d_equalityEngine->getRepresentative(n); EqcInfo * ei = getOrMakeEqcInfo( r, false ); int l_index = getLabelIndex( ei, r ); int t_index = static_cast<int>(utils::indexOf(atom.getOperator())); diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index 422a01f07..a68caca94 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -145,8 +145,6 @@ private: private: /** The notify class */ NotifyClass d_notify; - /** Equaltity engine */ - eq::EqualityEngine d_equalityEngine; /** information necessary for equivalence classes */ std::map< Node, EqcInfo* > d_eqc_info; /** map from nodes to their instantiated equivalent for each constructor type */ @@ -269,9 +267,18 @@ private: ProofNodeManager* pnm = nullptr); ~TheoryDatatypes(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ + void finishInit() override; + //--------------------------------- end initialization /** propagate */ void propagate(Effort effort) override; @@ -295,7 +302,6 @@ private: void check(Effort e) override; bool needsCheckLastEffort() override; void preRegisterTerm(TNode n) override; - void finishInit() override; TrustNode expandDefinition(Node n) override; TrustNode ppRewrite(TNode n) override; void presolve() override; @@ -307,8 +313,6 @@ private: { return std::string("TheoryDatatypes"); } - /** equality engine */ - eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } bool getCurrentSubstitution(int effort, std::vector<Node>& vars, std::vector<Node>& subs, diff --git a/src/theory/ee_manager_distributed.cpp b/src/theory/ee_manager_distributed.cpp index 21237816f..eb12ce893 100644 --- a/src/theory/ee_manager_distributed.cpp +++ b/src/theory/ee_manager_distributed.cpp @@ -61,6 +61,7 @@ void EqEngineManagerDistributed::finishInit() } // allocate the equality engine eet.d_allocEe.reset(allocateEqualityEngine(esi, c)); + eet.d_usedEe = eet.d_allocEe.get(); } const LogicInfo& logicInfo = d_te.getLogicInfo(); diff --git a/src/theory/ee_manager_distributed.h b/src/theory/ee_manager_distributed.h index 3de1898d7..8cac225be 100644 --- a/src/theory/ee_manager_distributed.h +++ b/src/theory/ee_manager_distributed.h @@ -41,6 +41,9 @@ namespace theory { */ struct EeTheoryInfo { + EeTheoryInfo() : d_usedEe(nullptr) {} + /** The equality engine that the theory uses (if it exists) */ + eq::EqualityEngine* d_usedEe; /** The equality engine allocated by this theory (if it exists) */ std::unique_ptr<eq::EqualityEngine> d_allocEe; }; diff --git a/src/theory/fp/theory_fp.cpp b/src/theory/fp/theory_fp.cpp index a4cff8c95..f5cc16ea9 100644 --- a/src/theory/fp/theory_fp.cpp +++ b/src/theory/fp/theory_fp.cpp @@ -107,7 +107,6 @@ TheoryFp::TheoryFp(context::Context* c, ProofNodeManager* pnm) : Theory(THEORY_FP, c, u, out, valuation, logicInfo, pnm), d_notification(*this), - d_equalityEngine(d_notification, c, "theory::fp::ee", true), d_registeredTerms(u), d_conv(u), d_expansionRequested(false), @@ -122,60 +121,74 @@ TheoryFp::TheoryFp(context::Context* c, floatToRealMap(u), abstractionMap(u) { +} /* TheoryFp::TheoryFp() */ + +TheoryRewriter* TheoryFp::getTheoryRewriter() { return &d_rewriter; } + +bool TheoryFp::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notification; + esi.d_name = "theory::fp::ee"; + return true; +} + +void TheoryFp::finishInit() +{ + Assert(d_equalityEngine != nullptr); // Kinds that are to be handled in the congruence closure - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ABS); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_NEG); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_PLUS); - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_SUB); // Removed - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_MULT); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_DIV); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_FMA); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_SQRT); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_REM); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_RTI); - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_MIN); // Removed - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_MAX); // Removed - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_MIN_TOTAL); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_MAX_TOTAL); - - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_EQ); // Removed - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_LEQ); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_LT); - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_GEQ); // Removed - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_GT); // Removed - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISN); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISSN); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISZ); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISINF); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISNAN); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISNEG); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISPOS); - - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_FP_REAL); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR); - d_equalityEngine.addFunctionKind( + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ABS); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_NEG); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_PLUS); + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_SUB); // Removed + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_MULT); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_DIV); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_FMA); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_SQRT); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_REM); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_RTI); + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_MIN); // Removed + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_MAX); // Removed + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_MIN_TOTAL); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_MAX_TOTAL); + + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_EQ); // Removed + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_LEQ); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_LT); + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_GEQ); // Removed + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_GT); // Removed + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISN); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISSN); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISZ); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISINF); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISNAN); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISNEG); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISPOS); + + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_FP_REAL); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR); + d_equalityEngine->addFunctionKind( kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR); - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_FP_GENERIC); // + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_FP_GENERIC); // // Needed in parsing, should be rewritten away - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_UBV); // Removed - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_SBV); // Removed - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_REAL); // Removed - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_UBV_TOTAL); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_SBV_TOTAL); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_REAL_TOTAL); - - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_NAN); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_INF); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_ZERO); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_SIGN); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_EXPONENT); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND); - d_equalityEngine.addFunctionKind(kind::ROUNDINGMODE_BITBLAST); -} /* TheoryFp::TheoryFp() */ + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_UBV); // Removed + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_SBV); // Removed + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_REAL); // Removed + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_UBV_TOTAL); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_SBV_TOTAL); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_REAL_TOTAL); + + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_NAN); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_INF); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_ZERO); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_SIGN); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_EXPONENT); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND); + d_equalityEngine->addFunctionKind(kind::ROUNDINGMODE_BITBLAST); +} Node TheoryFp::minUF(Node node) { Assert(node.getKind() == kind::FLOATINGPOINT_MIN); @@ -803,11 +816,11 @@ void TheoryFp::registerTerm(TNode node) { // Add to the equality engine if (k == kind::EQUAL) { - d_equalityEngine.addTriggerEquality(node); + d_equalityEngine->addTriggerEquality(node); } else { - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); } // Give the expansion of classifications in terms of equalities @@ -961,22 +974,22 @@ void TheoryFp::check(Effort level) { if (negated) { Debug("fp-eq") << "TheoryFp::check(): adding dis-equality " << fact[0] << std::endl; - d_equalityEngine.assertEquality(predicate, false, fact); - + d_equalityEngine->assertEquality(predicate, false, fact); } else { Debug("fp-eq") << "TheoryFp::check(): adding equality " << fact << std::endl; - d_equalityEngine.assertEquality(predicate, true, fact); + d_equalityEngine->assertEquality(predicate, true, fact); } } else { // A system-wide invariant; predicates are registered before they are // asserted Assert(isRegistered(predicate)); - if (d_equalityEngine.isFunctionKind(predicate.getKind())) { + if (d_equalityEngine->isFunctionKind(predicate.getKind())) + { Debug("fp-eq") << "TheoryFp::check(): adding predicate " << predicate << " is " << !negated << std::endl; - d_equalityEngine.assertPredicate(predicate, !negated, fact); + d_equalityEngine->assertPredicate(predicate, !negated, fact); } } } @@ -1007,10 +1020,6 @@ void TheoryFp::check(Effort level) { } /* TheoryFp::check() */ -void TheoryFp::setMasterEqualityEngine(eq::EqualityEngine *eq) { - d_equalityEngine.setMasterEqualityEngine(eq); -} - TrustNode TheoryFp::explain(TNode n) { Trace("fp") << "TheoryFp::explain(): explain " << n << std::endl; @@ -1022,9 +1031,9 @@ TrustNode TheoryFp::explain(TNode n) bool polarity = n.getKind() != kind::NOT; TNode atom = polarity ? n : n[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions); } else { - d_equalityEngine.explainPredicate(atom, polarity, assumptions); + d_equalityEngine->explainPredicate(atom, polarity, assumptions); } Node exp = helper::buildConjunct(assumptions); @@ -1177,7 +1186,7 @@ void TheoryFp::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) { << " = " << t2 << std::endl; std::vector<TNode> assumptions; - d_theorySolver.d_equalityEngine.explainEquality(t1, t2, true, assumptions); + d_theorySolver.d_equalityEngine->explainEquality(t1, t2, true, assumptions); Node conflict = helper::buildConjunct(assumptions); diff --git a/src/theory/fp/theory_fp.h b/src/theory/fp/theory_fp.h index a1dd8a731..02e7e4232 100644 --- a/src/theory/fp/theory_fp.h +++ b/src/theory/fp/theory_fp.h @@ -42,8 +42,18 @@ class TheoryFp : public Theory { Valuation valuation, const LogicInfo& logicInfo, ProofNodeManager* pnm = nullptr); - - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ + void finishInit() override; + //--------------------------------- end initialization TrustNode expandDefinition(Node node) override; @@ -60,8 +70,6 @@ class TheoryFp : public Theory { std::string identify() const override { return "THEORY_FP"; } - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - TrustNode explain(TNode n) override; protected: @@ -86,7 +94,6 @@ class TheoryFp : public Theory { friend NotifyClass; NotifyClass d_notification; - eq::EqualityEngine d_equalityEngine; /** General utility **/ void registerTerm(TNode node); diff --git a/src/theory/quantifiers/theory_quantifiers.cpp b/src/theory/quantifiers/theory_quantifiers.cpp index 1475446fe..04e83032b 100644 --- a/src/theory/quantifiers/theory_quantifiers.cpp +++ b/src/theory/quantifiers/theory_quantifiers.cpp @@ -64,6 +64,7 @@ TheoryQuantifiers::TheoryQuantifiers(Context* c, TheoryQuantifiers::~TheoryQuantifiers() { } +TheoryRewriter* TheoryQuantifiers::getTheoryRewriter() { return &d_rewriter; } void TheoryQuantifiers::finishInit() { // quantifiers are not evaluated in getModelValue diff --git a/src/theory/quantifiers/theory_quantifiers.h b/src/theory/quantifiers/theory_quantifiers.h index 3168af195..c378f3537 100644 --- a/src/theory/quantifiers/theory_quantifiers.h +++ b/src/theory/quantifiers/theory_quantifiers.h @@ -42,10 +42,13 @@ class TheoryQuantifiers : public Theory { ProofNodeManager* pnm = nullptr); ~TheoryQuantifiers(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; /** finish initialization */ void finishInit() override; + //--------------------------------- end initialization + void preRegisterTerm(TNode n) override; void presolve() override; void ppNotifyAssertions(const std::vector<Node>& assertions) override; diff --git a/src/theory/sep/theory_sep.cpp b/src/theory/sep/theory_sep.cpp index 4dfdb9fa5..edb5dd0ae 100644 --- a/src/theory/sep/theory_sep.cpp +++ b/src/theory/sep/theory_sep.cpp @@ -48,7 +48,6 @@ TheorySep::TheorySep(context::Context* c, : Theory(THEORY_SEP, c, u, out, valuation, logicInfo, pnm), d_lemmas_produced_c(u), d_notify(*this), - d_equalityEngine(d_notify, c, "theory::sep::ee", true), d_conflict(c, false), d_reduce(u), d_infer(c), @@ -58,10 +57,6 @@ TheorySep::TheorySep(context::Context* c, d_true = NodeManager::currentNM()->mkConst<bool>(true); d_false = NodeManager::currentNM()->mkConst<bool>(false); d_bounds_init = false; - - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::SEP_PTO); - //d_equalityEngine.addFunctionKind(kind::SEP_STAR); } TheorySep::~TheorySep() { @@ -70,8 +65,21 @@ TheorySep::~TheorySep() { } } -void TheorySep::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); +TheoryRewriter* TheorySep::getTheoryRewriter() { return &d_rewriter; } + +bool TheorySep::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::sep::ee"; + return true; +} + +void TheorySep::finishInit() +{ + Assert(d_equalityEngine != nullptr); + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::SEP_PTO); + // we could but don't do congruence on SEP_STAR here. } Node TheorySep::mkAnd( std::vector< TNode >& assumptions ) { @@ -126,9 +134,10 @@ void TheorySep::explain(TNode literal, std::vector<TNode>& assumptions) { bool polarity = literal.getKind() != kind::NOT; TNode atom = polarity ? literal : literal[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality( atom[0], atom[1], polarity, assumptions, NULL ); + d_equalityEngine->explainEquality( + atom[0], atom[1], polarity, assumptions, NULL); } else { - d_equalityEngine.explainPredicate( atom, polarity, assumptions ); + d_equalityEngine->explainPredicate(atom, polarity, assumptions); } } } @@ -155,17 +164,19 @@ TrustNode TheorySep::explain(TNode literal) void TheorySep::addSharedTerm(TNode t) { Debug("sep") << "TheorySep::addSharedTerm(" << t << ")" << std::endl; - d_equalityEngine.addTriggerTerm(t, THEORY_SEP); + d_equalityEngine->addTriggerTerm(t, THEORY_SEP); } EqualityStatus TheorySep::getEqualityStatus(TNode a, TNode b) { - Assert(d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)); - if (d_equalityEngine.areEqual(a, b)) { + Assert(d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)); + if (d_equalityEngine->areEqual(a, b)) + { // The terms are implied to be equal return EQUALITY_TRUE; } - else if (d_equalityEngine.areDisequal(a, b, false)) { + else if (d_equalityEngine->areDisequal(a, b, false)) + { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -211,7 +222,7 @@ bool TheorySep::collectModelInfo(TheoryModel* m) computeRelevantTerms(termSet); // Send the equality engine information to the model - return m->assertEqualityEngine(&d_equalityEngine, &termSet); + return m->assertEqualityEngine(d_equalityEngine, &termSet); } void TheorySep::postProcessModel( TheoryModel* m ){ @@ -490,16 +501,16 @@ void TheorySep::check(Effort e) { if( !is_spatial ){ Trace("sep-assert") << "Asserting " << atom << ", pol = " << polarity << " to EE..." << std::endl; if( s_atom.getKind()==kind::EQUAL ){ - d_equalityEngine.assertEquality(atom, polarity, fact); + d_equalityEngine->assertEquality(atom, polarity, fact); }else{ - d_equalityEngine.assertPredicate(atom, polarity, fact); + d_equalityEngine->assertPredicate(atom, polarity, fact); } Trace("sep-assert") << "Done asserting " << atom << " to EE." << std::endl; }else if( s_atom.getKind()==kind::SEP_PTO ){ Node pto_lbl = NodeManager::currentNM()->mkNode( kind::SINGLETON, s_atom[0] ); Assert(s_lbl == pto_lbl); Trace("sep-assert") << "Asserting " << s_atom << std::endl; - d_equalityEngine.assertPredicate(s_atom, polarity, fact); + d_equalityEngine->assertPredicate(s_atom, polarity, fact); //associate the equivalence class of the lhs with this pto Node r = getRepresentative( s_lbl ); HeapAssertInfo * ei = getOrMakeEqcInfo( r, true ); @@ -619,11 +630,11 @@ void TheorySep::check(Effort e) { Trace("sep-process") << "---" << std::endl; } if(Trace.isOn("sep-eqc")) { - eq::EqClassesIterator eqcs2_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs2_i = eq::EqClassesIterator(d_equalityEngine); Trace("sep-eqc") << "EQC:" << std::endl; while( !eqcs2_i.isFinished() ){ Node eqc = (*eqcs2_i); - eq::EqClassIterator eqc2_i = eq::EqClassIterator( eqc, &d_equalityEngine ); + eq::EqClassIterator eqc2_i = eq::EqClassIterator(eqc, d_equalityEngine); Trace("sep-eqc") << "Eqc( " << eqc << " ) : { "; while( !eqc2_i.isFinished() ) { if( (*eqc2_i)!=eqc ){ @@ -1552,22 +1563,21 @@ void TheorySep::computeLabelModel( Node lbl ) { } Node TheorySep::getRepresentative( Node t ) { - if( d_equalityEngine.hasTerm( t ) ){ - return d_equalityEngine.getRepresentative( t ); + if (d_equalityEngine->hasTerm(t)) + { + return d_equalityEngine->getRepresentative(t); }else{ return t; } } -bool TheorySep::hasTerm( Node a ){ - return d_equalityEngine.hasTerm( a ); -} +bool TheorySep::hasTerm(Node a) { return d_equalityEngine->hasTerm(a); } bool TheorySep::areEqual( Node a, Node b ){ if( a==b ){ return true; }else if( hasTerm( a ) && hasTerm( b ) ){ - return d_equalityEngine.areEqual( a, b ); + return d_equalityEngine->areEqual(a, b); }else{ return false; } @@ -1577,7 +1587,8 @@ bool TheorySep::areDisequal( Node a, Node b ){ if( a==b ){ return false; }else if( hasTerm( a ) && hasTerm( b ) ){ - if( d_equalityEngine.areDisequal( a, b, false ) ){ + if (d_equalityEngine->areDisequal(a, b, false)) + { return true; } } @@ -1743,9 +1754,9 @@ void TheorySep::doPendingFacts() { bool pol = d_pending[i].getKind()!=kind::NOT; Trace("sep-pending") << "Sep : Assert to EE : " << atom << ", pol = " << pol << std::endl; if( atom.getKind()==kind::EQUAL ){ - d_equalityEngine.assertEquality(atom, pol, d_pending_exp[i]); + d_equalityEngine->assertEquality(atom, pol, d_pending_exp[i]); }else{ - d_equalityEngine.assertPredicate(atom, pol, d_pending_exp[i]); + d_equalityEngine->assertPredicate(atom, pol, d_pending_exp[i]); } } }else{ diff --git a/src/theory/sep/theory_sep.h b/src/theory/sep/theory_sep.h index 7c6ce38c4..84a7025f0 100644 --- a/src/theory/sep/theory_sep.h +++ b/src/theory/sep/theory_sep.h @@ -74,9 +74,18 @@ class TheorySep : public Theory { ProofNodeManager* pnm = nullptr); ~TheorySep(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ + void finishInit() override; + //--------------------------------- end initialization std::string identify() const override { return std::string("TheorySep"); } @@ -202,9 +211,6 @@ class TheorySep : public Theory { /** The notify class for d_equalityEngine */ NotifyClass d_notify; - /** Equaltity engine */ - eq::EqualityEngine d_equalityEngine; - /** Are we in conflict? */ context::CDO<bool> d_conflict; std::vector< Node > d_pending_exp; @@ -326,7 +332,6 @@ class TheorySep : public Theory { void doPendingFacts(); public: - eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } void initializeBounds(); };/* class TheorySep */ diff --git a/src/theory/sets/theory_sets.cpp b/src/theory/sets/theory_sets.cpp index bf81099a7..fd9af488f 100644 --- a/src/theory/sets/theory_sets.cpp +++ b/src/theory/sets/theory_sets.cpp @@ -35,8 +35,7 @@ TheorySets::TheorySets(context::Context* c, ProofNodeManager* pnm) : Theory(THEORY_SETS, c, u, out, valuation, logicInfo, pnm), d_internal(new TheorySetsPrivate(*this, c, u)), - d_notify(*d_internal.get()), - d_equalityEngine(d_notify, c, "theory::sets::ee", true) + d_notify(*d_internal.get()) { // Do not move me to the header. // The constructor + destructor are not in the header as d_internal is a @@ -54,29 +53,38 @@ TheoryRewriter* TheorySets::getTheoryRewriter() return d_internal->getTheoryRewriter(); } +bool TheorySets::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::sets::ee"; + return true; +} + void TheorySets::finishInit() { + Assert(d_equalityEngine != nullptr); + d_valuation.setUnevaluatedKind(COMPREHENSION); // choice is used to eliminate witness d_valuation.setUnevaluatedKind(WITNESS); // functions we are doing congruence over - d_equalityEngine.addFunctionKind(kind::SINGLETON); - d_equalityEngine.addFunctionKind(kind::UNION); - d_equalityEngine.addFunctionKind(kind::INTERSECTION); - d_equalityEngine.addFunctionKind(kind::SETMINUS); - d_equalityEngine.addFunctionKind(kind::MEMBER); - d_equalityEngine.addFunctionKind(kind::SUBSET); + d_equalityEngine->addFunctionKind(kind::SINGLETON); + d_equalityEngine->addFunctionKind(kind::UNION); + d_equalityEngine->addFunctionKind(kind::INTERSECTION); + d_equalityEngine->addFunctionKind(kind::SETMINUS); + d_equalityEngine->addFunctionKind(kind::MEMBER); + d_equalityEngine->addFunctionKind(kind::SUBSET); // relation operators - d_equalityEngine.addFunctionKind(PRODUCT); - d_equalityEngine.addFunctionKind(JOIN); - d_equalityEngine.addFunctionKind(TRANSPOSE); - d_equalityEngine.addFunctionKind(TCLOSURE); - d_equalityEngine.addFunctionKind(JOIN_IMAGE); - d_equalityEngine.addFunctionKind(IDEN); - d_equalityEngine.addFunctionKind(APPLY_CONSTRUCTOR); + d_equalityEngine->addFunctionKind(PRODUCT); + d_equalityEngine->addFunctionKind(JOIN); + d_equalityEngine->addFunctionKind(TRANSPOSE); + d_equalityEngine->addFunctionKind(TCLOSURE); + d_equalityEngine->addFunctionKind(JOIN_IMAGE); + d_equalityEngine->addFunctionKind(IDEN); + d_equalityEngine->addFunctionKind(APPLY_CONSTRUCTOR); // we do congruence over cardinality - d_equalityEngine.addFunctionKind(CARD); + d_equalityEngine->addFunctionKind(CARD); // finish initialization internally d_internal->finishInit(); @@ -198,16 +206,6 @@ bool TheorySets::isEntailed( Node n, bool pol ) { return d_internal->isEntailed( n, pol ); } -eq::EqualityEngine* TheorySets::getEqualityEngine() -{ - return &d_equalityEngine; -} - -void TheorySets::setMasterEqualityEngine(eq::EqualityEngine* eq) -{ - d_equalityEngine.setMasterEqualityEngine(eq); -} - /**************************** eq::NotifyClass *****************************/ bool TheorySets::NotifyClass::eqNotifyTriggerEquality(TNode equality, diff --git a/src/theory/sets/theory_sets.h b/src/theory/sets/theory_sets.h index 84291346b..cb8fdfbc3 100644 --- a/src/theory/sets/theory_sets.h +++ b/src/theory/sets/theory_sets.h @@ -48,6 +48,12 @@ class TheorySets : public Theory //--------------------------------- initialization /** get the official theory rewriter of this theory */ TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; /** finish initialization */ void finishInit() override; //--------------------------------- end initialization @@ -65,10 +71,7 @@ class TheorySets : public Theory PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; void presolve() override; void propagate(Effort) override; - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; bool isEntailed(Node n, bool pol); - /* equality engine */ - virtual eq::EqualityEngine* getEqualityEngine() override; private: /** Functions to handle callbacks from equality engine */ class NotifyClass : public eq::EqualityEngineNotify @@ -92,9 +95,7 @@ class TheorySets : public Theory /** The internal theory */ std::unique_ptr<TheorySetsPrivate> d_internal; /** Instance of the above class */ - NotifyClass d_notify; - /** Equality engine */ - eq::EqualityEngine d_equalityEngine; + NotifyClass d_notify; }; /* class TheorySets */ }/* CVC4::theory::sets namespace */ diff --git a/src/theory/strings/solver_state.cpp b/src/theory/strings/solver_state.cpp index a554ac595..8634478fd 100644 --- a/src/theory/strings/solver_state.cpp +++ b/src/theory/strings/solver_state.cpp @@ -27,11 +27,10 @@ namespace strings { SolverState::SolverState(context::Context* c, context::UserContext* u, - eq::EqualityEngine& ee, Valuation& v) : d_context(c), d_ucontext(u), - d_ee(ee), + d_ee(nullptr), d_eeDisequalities(c), d_valuation(v), d_conflict(c, false), @@ -48,19 +47,25 @@ SolverState::~SolverState() } } +void SolverState::finishInit(eq::EqualityEngine* ee) +{ + Assert(ee != nullptr); + d_ee = ee; +} + context::Context* SolverState::getSatContext() const { return d_context; } context::UserContext* SolverState::getUserContext() const { return d_ucontext; } Node SolverState::getRepresentative(Node t) const { - if (d_ee.hasTerm(t)) + if (d_ee->hasTerm(t)) { - return d_ee.getRepresentative(t); + return d_ee->getRepresentative(t); } return t; } -bool SolverState::hasTerm(Node a) const { return d_ee.hasTerm(a); } +bool SolverState::hasTerm(Node a) const { return d_ee->hasTerm(a); } bool SolverState::areEqual(Node a, Node b) const { @@ -70,7 +75,7 @@ bool SolverState::areEqual(Node a, Node b) const } else if (hasTerm(a) && hasTerm(b)) { - return d_ee.areEqual(a, b); + return d_ee->areEqual(a, b); } return false; } @@ -83,17 +88,17 @@ bool SolverState::areDisequal(Node a, Node b) const } else if (hasTerm(a) && hasTerm(b)) { - Node ar = d_ee.getRepresentative(a); - Node br = d_ee.getRepresentative(b); + Node ar = d_ee->getRepresentative(a); + Node br = d_ee->getRepresentative(b); return (ar != br && ar.isConst() && br.isConst()) - || d_ee.areDisequal(ar, br, false); + || d_ee->areDisequal(ar, br, false); } Node ar = getRepresentative(a); Node br = getRepresentative(b); return ar != br && ar.isConst() && br.isConst(); } -eq::EqualityEngine* SolverState::getEqualityEngine() const { return &d_ee; } +eq::EqualityEngine* SolverState::getEqualityEngine() const { return d_ee; } const context::CDList<Node>& SolverState::getDisequalityList() const { @@ -105,7 +110,7 @@ void SolverState::eqNotifyNewClass(TNode t) Kind k = t.getKind(); if (k == STRING_LENGTH || k == STRING_TO_CODE) { - Node r = d_ee.getRepresentative(t[0]); + Node r = d_ee->getRepresentative(t[0]); EqcInfo* ei = getOrMakeEqcInfo(r); if (k == STRING_LENGTH) { @@ -317,14 +322,14 @@ void SolverState::separateByLength( NodeManager* nm = NodeManager::currentNM(); for (const Node& eqc : n) { - Assert(d_ee.getRepresentative(eqc) == eqc); + Assert(d_ee->getRepresentative(eqc) == eqc); TypeNode tnEqc = eqc.getType(); EqcInfo* ei = getOrMakeEqcInfo(eqc, false); Node lt = ei ? ei->d_lengthTerm : Node::null(); if (!lt.isNull()) { lt = nm->mkNode(STRING_LENGTH, lt); - Node r = d_ee.getRepresentative(lt); + Node r = d_ee->getRepresentative(lt); std::pair<Node, TypeNode> lkey(r, tnEqc); if (eqc_to_leqc.find(lkey) == eqc_to_leqc.end()) { diff --git a/src/theory/strings/solver_state.h b/src/theory/strings/solver_state.h index 8d3162b38..0322abdb7 100644 --- a/src/theory/strings/solver_state.h +++ b/src/theory/strings/solver_state.h @@ -46,9 +46,13 @@ class SolverState public: SolverState(context::Context* c, context::UserContext* u, - eq::EqualityEngine& ee, Valuation& v); ~SolverState(); + /** + * Finish initialize, ee is a pointer to the official equality engine + * of theory of strings. + */ + void finishInit(eq::EqualityEngine* ee); /** Get the SAT context */ context::Context* getSatContext() const; /** Get the user context */ @@ -186,8 +190,8 @@ class SolverState context::Context* d_context; /** Pointer to the user context object used by the theory of strings. */ context::UserContext* d_ucontext; - /** Reference to equality engine of the theory of strings. */ - eq::EqualityEngine& d_ee; + /** Pointer to equality engine of the theory of strings. */ + eq::EqualityEngine* d_ee; /** * The (SAT-context-dependent) list of disequalities that have been asserted * to the equality engine above. diff --git a/src/theory/strings/term_registry.cpp b/src/theory/strings/term_registry.cpp index f28db4c35..71b45915f 100644 --- a/src/theory/strings/term_registry.cpp +++ b/src/theory/strings/term_registry.cpp @@ -37,12 +37,10 @@ typedef expr::Attribute<StringsProxyVarAttributeId, bool> StringsProxyVarAttribute; TermRegistry::TermRegistry(SolverState& s, - eq::EqualityEngine& ee, OutputChannel& out, SequencesStatistics& statistics, ProofNodeManager* pnm) : d_state(s), - d_ee(ee), d_out(out), d_statistics(statistics), d_hasStrCode(false), @@ -129,6 +127,7 @@ void TermRegistry::preRegisterTerm(TNode n) { return; } + eq::EqualityEngine* ee = d_state.getEqualityEngine(); d_preregisteredTerms.insert(n); Trace("strings-preregister") << "TheoryString::preregister : " << n << std::endl; @@ -156,15 +155,15 @@ void TermRegistry::preRegisterTerm(TNode n) ss << "Equality between regular expressions is not supported"; throw LogicException(ss.str()); } - d_ee.addTriggerEquality(n); + ee->addTriggerEquality(n); return; } else if (k == STRING_IN_REGEXP) { d_out.requirePhase(n, true); - d_ee.addTriggerPredicate(n); - d_ee.addTerm(n[0]); - d_ee.addTerm(n[1]); + ee->addTriggerPredicate(n); + ee->addTerm(n[0]); + ee->addTerm(n[1]); return; } else if (k == STRING_TO_CODE) @@ -196,17 +195,17 @@ void TermRegistry::preRegisterTerm(TNode n) } } } - d_ee.addTerm(n); + ee->addTerm(n); } else if (tn.isBoolean()) { // Get triggered for both equal and dis-equal - d_ee.addTriggerPredicate(n); + ee->addTriggerPredicate(n); } else { // Function applications/predicates - d_ee.addTerm(n); + ee->addTerm(n); } // Set d_functionsTerms stores all function applications that are // relevant to theory combination. Notice that this is a subset of @@ -216,7 +215,7 @@ void TermRegistry::preRegisterTerm(TNode n) // Concatenation terms do not need to be considered here because // their arguments have string type and do not introduce any shared // terms. - if (n.hasOperator() && d_ee.isFunctionKind(k) && k != STRING_CONCAT) + if (n.hasOperator() && ee->isFunctionKind(k) && k != STRING_CONCAT) { d_functionsTerms.push_back(n); } @@ -313,7 +312,7 @@ void TermRegistry::registerType(TypeNode tn) { // preregister the empty word for the type Node emp = Word::mkEmptyWord(tn); - if (!d_ee.hasTerm(emp)) + if (!d_state.hasTerm(emp)) { preRegisterTerm(emp); } diff --git a/src/theory/strings/term_registry.h b/src/theory/strings/term_registry.h index 2048abec1..45fb40073 100644 --- a/src/theory/strings/term_registry.h +++ b/src/theory/strings/term_registry.h @@ -50,7 +50,6 @@ class TermRegistry public: TermRegistry(SolverState& s, - eq::EqualityEngine& ee, OutputChannel& out, SequencesStatistics& statistics, ProofNodeManager* pnm); @@ -220,8 +219,6 @@ class TermRegistry uint32_t d_cardSize; /** Reference to the solver state of the theory of strings. */ SolverState& d_state; - /** Reference to equality engine of the theory of strings. */ - eq::EqualityEngine& d_ee; /** Reference to the output channel of the theory of strings. */ OutputChannel& d_out; /** Reference to the statistics for the theory of strings/sequences. */ diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index 0ad887d2f..b23765313 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -43,9 +43,8 @@ TheoryStrings::TheoryStrings(context::Context* c, : Theory(THEORY_STRINGS, c, u, out, valuation, logicInfo, pnm), d_notify(*this), d_statistics(), - d_equalityEngine(d_notify, c, "theory::strings::ee", true), - d_state(c, u, d_equalityEngine, d_valuation), - d_termReg(d_state, d_equalityEngine, out, d_statistics, nullptr), + d_state(c, u, d_valuation), + d_termReg(d_state, out, d_statistics, nullptr), d_extTheory(this), d_im(c, u, d_state, d_termReg, d_extTheory, out, d_statistics), d_rewriter(&d_statistics.d_rewrites), @@ -67,30 +66,6 @@ TheoryStrings::TheoryStrings(context::Context* c, d_statistics), d_stringsFmf(c, u, valuation, d_termReg) { - bool eagerEval = options::stringEagerEval(); - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::STRING_LENGTH, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_CONCAT, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_IN_REGEXP, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_TO_CODE, eagerEval); - d_equalityEngine.addFunctionKind(kind::SEQ_UNIT, eagerEval); - - // extended functions - d_equalityEngine.addFunctionKind(kind::STRING_STRCTN, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_LEQ, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_SUBSTR, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_UPDATE, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_ITOS, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_STOI, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_STRIDOF, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_STRREPL, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_STRREPLALL, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_REPLACE_RE, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_REPLACE_RE_ALL, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_STRREPLALL, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_TOLOWER, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_TOUPPER, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_REV, eagerEval); d_zero = NodeManager::currentNM()->mkConst( Rational( 0 ) ); d_one = NodeManager::currentNM()->mkConst( Rational( 1 ) ); @@ -113,26 +88,63 @@ TheoryStrings::~TheoryStrings() { } TheoryRewriter* TheoryStrings::getTheoryRewriter() { return &d_rewriter; } -std::string TheoryStrings::identify() const -{ - return std::string("TheoryStrings"); -} -eq::EqualityEngine* TheoryStrings::getEqualityEngine() + +bool TheoryStrings::needsEqualityEngine(EeSetupInfo& esi) { - return &d_equalityEngine; + esi.d_notify = &d_notify; + esi.d_name = "theory::strings::ee"; + return true; } + void TheoryStrings::finishInit() { + Assert(d_equalityEngine != nullptr); + // witness is used to eliminate str.from_code d_valuation.setUnevaluatedKind(WITNESS); + + bool eagerEval = options::stringEagerEval(); + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::STRING_LENGTH, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_CONCAT, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_IN_REGEXP, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_TO_CODE, eagerEval); + d_equalityEngine->addFunctionKind(kind::SEQ_UNIT, eagerEval); + // extended functions + d_equalityEngine->addFunctionKind(kind::STRING_STRCTN, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_LEQ, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_SUBSTR, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_UPDATE, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_ITOS, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_STOI, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_STRIDOF, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_STRREPL, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_STRREPLALL, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_REPLACE_RE, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_REPLACE_RE_ALL, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_STRREPLALL, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_TOLOWER, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_TOUPPER, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_REV, eagerEval); + + d_state.finishInit(d_equalityEngine); +} + +std::string TheoryStrings::identify() const +{ + return std::string("TheoryStrings"); } bool TheoryStrings::areCareDisequal( TNode x, TNode y ) { - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); - if( d_equalityEngine.isTriggerTerm(x, THEORY_STRINGS) && d_equalityEngine.isTriggerTerm(y, THEORY_STRINGS) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_STRINGS); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_STRINGS); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); + if (d_equalityEngine->isTriggerTerm(x, THEORY_STRINGS) + && d_equalityEngine->isTriggerTerm(y, THEORY_STRINGS)) + { + TNode x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_STRINGS); + TNode y_shared = + d_equalityEngine->getTriggerTermRepresentative(y, THEORY_STRINGS); EqualityStatus eqStatus = d_valuation.getEqualityStatus(x_shared, y_shared); if( eqStatus==EQUALITY_FALSE_AND_PROPAGATED || eqStatus==EQUALITY_FALSE || eqStatus==EQUALITY_FALSE_IN_MODEL ){ return true; @@ -141,14 +153,10 @@ bool TheoryStrings::areCareDisequal( TNode x, TNode y ) { return false; } -void TheoryStrings::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); -} - void TheoryStrings::addSharedTerm(TNode t) { Debug("strings") << "TheoryStrings::addSharedTerm(): " << t << " " << t.getType().isBoolean() << endl; - d_equalityEngine.addTriggerTerm(t, THEORY_STRINGS); + d_equalityEngine->addTriggerTerm(t, THEORY_STRINGS); if (options::stringExp()) { d_esolver.addSharedTerm(t); @@ -157,12 +165,15 @@ void TheoryStrings::addSharedTerm(TNode t) { } EqualityStatus TheoryStrings::getEqualityStatus(TNode a, TNode b) { - if( d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b) ){ - if (d_equalityEngine.areEqual(a, b)) { + if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)) + { + if (d_equalityEngine->areEqual(a, b)) + { // The terms are implied to be equal return EQUALITY_TRUE; } - if (d_equalityEngine.areDisequal(a, b, false)) { + if (d_equalityEngine->areDisequal(a, b, false)) + { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -251,7 +262,7 @@ bool TheoryStrings::collectModelInfo(TheoryModel* m) // Compute terms appearing in assertions and shared terms computeRelevantTerms(termSet); // assert the (relevant) portion of the equality engine to the model - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { Unreachable() << "TheoryStrings::collectModelInfo: failed to assert equality engine" @@ -670,14 +681,15 @@ void TheoryStrings::check(Effort e) { << "Theory of strings " << e << " effort check " << std::endl; if(Trace.isOn("strings-eqc")) { for( unsigned t=0; t<2; t++ ) { - eq::EqClassesIterator eqcs2_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs2_i = eq::EqClassesIterator(d_equalityEngine); Trace("strings-eqc") << (t==0 ? "STRINGS:" : "OTHER:") << std::endl; while( !eqcs2_i.isFinished() ){ Node eqc = (*eqcs2_i); bool print = (t == 0 && eqc.getType().isStringLike()) || (t == 1 && !eqc.getType().isStringLike()); if (print) { - eq::EqClassIterator eqc2_i = eq::EqClassIterator( eqc, &d_equalityEngine ); + eq::EqClassIterator eqc2_i = + eq::EqClassIterator(eqc, d_equalityEngine); Trace("strings-eqc") << "Eqc( " << eqc << " ) : { "; while( !eqc2_i.isFinished() ) { if( (*eqc2_i)!=eqc && (*eqc2_i).getKind()!=kind::EQUAL ){ @@ -779,20 +791,26 @@ void TheoryStrings::addCarePairs(TNodeTrie* t1, if( t2!=NULL ){ Node f1 = t1->getData(); Node f2 = t2->getData(); - if( !d_equalityEngine.areEqual( f1, f2 ) ){ + if (!d_equalityEngine->areEqual(f1, f2)) + { Trace("strings-cg-debug") << "TheoryStrings::computeCareGraph(): checking function " << f1 << " and " << f2 << std::endl; vector< pair<TNode, TNode> > currentPairs; for (unsigned k = 0; k < f1.getNumChildren(); ++ k) { TNode x = f1[k]; TNode y = f2[k]; - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); - Assert(!d_equalityEngine.areDisequal(x, y, false)); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); + Assert(!d_equalityEngine->areDisequal(x, y, false)); Assert(!areCareDisequal(x, y)); - if( !d_equalityEngine.areEqual( x, y ) ){ - if( d_equalityEngine.isTriggerTerm(x, THEORY_STRINGS) && d_equalityEngine.isTriggerTerm(y, THEORY_STRINGS) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_STRINGS); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_STRINGS); + if (!d_equalityEngine->areEqual(x, y)) + { + if (d_equalityEngine->isTriggerTerm(x, THEORY_STRINGS) + && d_equalityEngine->isTriggerTerm(y, THEORY_STRINGS)) + { + TNode x_shared = d_equalityEngine->getTriggerTermRepresentative( + x, THEORY_STRINGS); + TNode y_shared = d_equalityEngine->getTriggerTermRepresentative( + y, THEORY_STRINGS); currentPairs.push_back(make_pair(x_shared, y_shared)); } } @@ -820,7 +838,8 @@ void TheoryStrings::addCarePairs(TNodeTrie* t1, std::map<TNode, TNodeTrie>::iterator it2 = it; ++it2; for( ; it2 != t1->d_data.end(); ++it2 ){ - if( !d_equalityEngine.areDisequal(it->first, it2->first, false) ){ + if (!d_equalityEngine->areDisequal(it->first, it2->first, false)) + { if( !areCareDisequal(it->first, it2->first) ){ addCarePairs( &it->second, &it2->second, arity, depth+1 ); } @@ -833,7 +852,7 @@ void TheoryStrings::addCarePairs(TNodeTrie* t1, { for (std::pair<const TNode, TNodeTrie>& tt2 : t2->d_data) { - if (!d_equalityEngine.areDisequal(tt1.first, tt2.first, false)) + if (!d_equalityEngine->areDisequal(tt1.first, tt2.first, false)) { if (!areCareDisequal(tt1.first, tt2.first)) { @@ -862,8 +881,9 @@ void TheoryStrings::computeCareGraph(){ std::vector< TNode > reps; bool has_trigger_arg = false; for( unsigned j=0; j<f1.getNumChildren(); j++ ){ - reps.push_back( d_equalityEngine.getRepresentative( f1[j] ) ); - if( d_equalityEngine.isTriggerTerm( f1[j], THEORY_STRINGS ) ){ + reps.push_back(d_equalityEngine->getRepresentative(f1[j])); + if (d_equalityEngine->isTriggerTerm(f1[j], THEORY_STRINGS)) + { has_trigger_arg = true; } } @@ -889,7 +909,7 @@ void TheoryStrings::checkRegisterTermsPreNormalForm() const std::vector<Node>& seqc = d_bsolver.getStringEqc(); for (const Node& eqc : seqc) { - eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, &d_equalityEngine); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); while (!eqc_i.isFinished()) { Node n = (*eqc_i); diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index 2fb827429..500daac1f 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -70,20 +70,24 @@ class TheoryStrings : public Theory { const LogicInfo& logicInfo, ProofNodeManager* pnm); ~TheoryStrings(); + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; /** finish initialization */ void finishInit() override; - /** Get the theory rewriter of this class */ - TheoryRewriter* getTheoryRewriter() override; - /** Set the master equality engine */ - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + //--------------------------------- end initialization /** Identify this theory */ std::string identify() const override; /** Propagate */ void propagate(Effort e) override; /** Explain */ TrustNode explain(TNode literal) override; - /** Get the equality engine */ - eq::EqualityEngine* getEqualityEngine() override; /** Get current substitution */ bool getCurrentSubstitution(int effort, std::vector<Node>& vars, @@ -268,8 +272,6 @@ class TheoryStrings : public Theory { * theories is collected in this object. */ SequencesStatistics d_statistics; - /** Equaltity engine */ - eq::EqualityEngine d_equalityEngine; /** The solver state object */ SolverState d_state; /** The term registry for this theory */ diff --git a/src/theory/theory.cpp b/src/theory/theory.cpp index f1bfd052d..4f0cbdb6a 100644 --- a/src/theory/theory.cpp +++ b/src/theory/theory.cpp @@ -63,7 +63,6 @@ Theory::Theory(TheoryId id, ProofNodeManager* pnm, std::string name) : d_id(id), - d_instanceName(name), d_satContext(satContext), d_userContext(userContext), d_logicInfo(logicInfo), @@ -74,12 +73,15 @@ Theory::Theory(TheoryId id, d_careGraph(NULL), d_quantEngine(NULL), d_decManager(nullptr), + d_instanceName(name), d_checkTime(getStatsPrefix(id) + name + "::checkTime"), d_computeCareGraphTime(getStatsPrefix(id) + name + "::computeCareGraphTime"), d_sharedTerms(satContext), d_out(&out), d_valuation(valuation), + d_equalityEngine(nullptr), + d_allocEqualityEngine(nullptr), d_proofsEnabled(false) { smtStatisticsRegistry()->registerStat(&d_checkTime); @@ -91,7 +93,43 @@ Theory::~Theory() { smtStatisticsRegistry()->unregisterStat(&d_computeCareGraphTime); } -bool Theory::needsEqualityEngine(EeSetupInfo& esi) { return false; } +bool Theory::needsEqualityEngine(EeSetupInfo& esi) +{ + // by default, this theory does not use an (official) equality engine + return false; +} + +void Theory::setEqualityEngine(eq::EqualityEngine* ee) +{ + // set the equality engine pointer + d_equalityEngine = ee; +} +void Theory::setQuantifiersEngine(QuantifiersEngine* qe) +{ + Assert(d_quantEngine == nullptr); + d_quantEngine = qe; +} + +void Theory::setDecisionManager(DecisionManager* dm) +{ + Assert(d_decManager == nullptr); + Assert(dm != nullptr); + d_decManager = dm; +} + +void Theory::finishInitStandalone() +{ + EeSetupInfo esi; + if (needsEqualityEngine(esi)) + { + // always associated with the same SAT context as the theory (d_satContext) + d_allocEqualityEngine.reset(new eq::EqualityEngine( + *esi.d_notify, d_satContext, esi.d_name, esi.d_constantsAreTriggers)); + // use it as the official equality engine + d_equalityEngine = d_allocEqualityEngine.get(); + } + finishInit(); +} TheoryId Theory::theoryOf(options::TheoryOfMode mode, TNode node) { @@ -410,17 +448,10 @@ void Theory::getCareGraph(CareGraph* careGraph) { d_careGraph = NULL; } -void Theory::setQuantifiersEngine(QuantifiersEngine* qe) { - Assert(d_quantEngine == NULL); - Assert(qe != NULL); - d_quantEngine = qe; -} - -void Theory::setDecisionManager(DecisionManager* dm) +eq::EqualityEngine* Theory::getEqualityEngine() { - Assert(d_decManager == nullptr); - Assert(dm != nullptr); - d_decManager = dm; + // get the assigned equality engine, which is a pointer stored in this class + return d_equalityEngine; } }/* CVC4::theory namespace */ diff --git a/src/theory/theory.h b/src/theory/theory.h index ef06732fb..4feeac394 100644 --- a/src/theory/theory.h +++ b/src/theory/theory.h @@ -77,11 +77,35 @@ namespace eq { * RegisteredAttr works. (If you need multiple instances of the same * theory, you'll have to write a multiplexed theory that dispatches * all calls to them.) + * + * NOTE: A Theory has a special way of being initialized. The owner of a Theory + * is either: + * + * (A) Using Theory as a standalone object, not associated with a TheoryEngine. + * In this case, simply call the public initialization method + * (Theory::finishInitStandalone). + * + * (B) TheoryEngine, which determines how the Theory acts in accordance with + * its theory combination policy. We require the following steps in order: + * (B.1) Get information about whether the theory wishes to use an equality + * eninge, and more specifically which equality engine notifications the Theory + * would like to be notified of (Theory::needsEqualityEngine). + * (B.2) Set the equality engine of the theory (Theory::setEqualityEngine), + * which we refer to as the "official equality engine" of this Theory. The + * equality engine passed to the theory must respect the contract(s) specified + * by the equality engine setup information (EeSetupInfo) returned in the + * previous step. + * (B.3) Set the other required utilities including setQuantifiersEngine and + * setDecisionManager. + * (B.4) Call the private initialization method (Theory::finishInit). + * + * Initialization of the second form happens during TheoryEngine::finishInit, + * after the quantifiers engine and model objects have been set up. */ class Theory { - private: friend class ::CVC4::TheoryEngine; + private: // Disallow default construction, copy, assignment. Theory() = delete; Theory(const Theory&) = delete; @@ -90,11 +114,6 @@ class Theory { /** An integer identifying the type of the theory. */ TheoryId d_id; - /** Name of this theory instance. Along with the TheoryId this should provide - * an unique string identifier for each instance of a Theory class. We need - * this to ensure unique statistics names over multiple theory instances. */ - std::string d_instanceName; - /** The SAT search context for the Theory. */ context::Context* d_satContext; @@ -137,6 +156,10 @@ class Theory { DecisionManager* d_decManager; protected: + /** Name of this theory instance. Along with the TheoryId this should provide + * an unique string identifier for each instance of a Theory class. We need + * this to ensure unique statistics names over multiple theory instances. */ + std::string d_instanceName; // === STATISTICS === /** time spent in check calls */ @@ -222,7 +245,15 @@ class Theory { * theory engine (and other theories). */ Valuation d_valuation; - + /** + * Pointer to the official equality engine of this theory, which is owned by + * the equality engine manager of TheoryEngine. + */ + eq::EqualityEngine* d_equalityEngine; + /** + * The official equality engine, if we allocated it. + */ + std::unique_ptr<eq::EqualityEngine> d_allocEqualityEngine; /** * Whether proofs are enabled * @@ -264,17 +295,33 @@ class Theory { * its value must be computed (approximated) by the non-linear solver. */ bool isLegalElimination(TNode x, TNode val); + //--------------------------------- private initialization + /** + * Called to set the official equality engine. This should be done by + * TheoryEngine only. + */ + void setEqualityEngine(eq::EqualityEngine* ee); + /** Called to set the quantifiers engine. */ + void setQuantifiersEngine(QuantifiersEngine* qe); + /** Called to set the decision manager. */ + void setDecisionManager(DecisionManager* dm); + /** + * Finish theory initialization. At this point, options and the logic + * setting are final, the master equality engine and quantifiers + * engine (if any) are initialized, and the official equality engine of this + * theory has been assigned. This base class implementation + * does nothing. This should be called by TheoryEngine only. + */ + virtual void finishInit() {} + //--------------------------------- end private initialization public: //--------------------------------- initialization /** - * @return The theory rewriter associated with this theory. This is primarily - * called for the purposes of initializing the rewriter. + * @return The theory rewriter associated with this theory. */ virtual TheoryRewriter* getTheoryRewriter() = 0; /** - * !!!! TODO: use this method (https://github.com/orgs/CVC4/projects/39). - * * Returns true if this theory needs an equality engine for checking * satisfiability. * @@ -288,6 +335,13 @@ class Theory { * a notifications class (eq::EqualityEngineNotify). */ virtual bool needsEqualityEngine(EeSetupInfo& esi); + /** + * Finish theory initialization, standalone version. This is used to + * initialize this class if it is not associated with a theory engine. + * This allocates the official equality engine of this Theory and then + * calls the finishInit method above. + */ + void finishInitStandalone(); //--------------------------------- end initialization /** @@ -451,14 +505,6 @@ class Theory { DecisionManager* getDecisionManager() { return d_decManager; } /** - * Finish theory initialization. At this point, options and the logic - * setting are final, and the master equality engine and quantifiers - * engine (if any) are initialized. This base class implementation - * does nothing. - */ - virtual void finishInit() { } - - /** * Expand definitions in the term node. This returns a term that is * equivalent to node. It wraps this term in a TrustNode of kind * TrustNodeKind::REWRITE. If node is unchanged by this method, the @@ -513,14 +559,9 @@ class Theory { virtual void addSharedTerm(TNode n) { } /** - * Called to set the master equality engine. + * Get the official equality engine of this theory. */ - virtual void setMasterEqualityEngine(eq::EqualityEngine* eq) { } - - /** Called to set the quantifiers engine. */ - void setQuantifiersEngine(QuantifiersEngine* qe); - /** Called to set the decision manager. */ - void setDecisionManager(DecisionManager* dm); + eq::EqualityEngine* getEqualityEngine(); /** * Return the current theory care graph. Theories should overload @@ -855,9 +896,6 @@ class Theory { */ virtual std::pair<bool, Node> entailmentCheck(TNode lit); - /* equality engine TODO: use? */ - virtual eq::EqualityEngine* getEqualityEngine() { return NULL; } - /* get current substitution at an effort * input : vars * output : subs, exp diff --git a/src/theory/theory_engine.cpp b/src/theory/theory_engine.cpp index a88db4494..07c160058 100644 --- a/src/theory/theory_engine.cpp +++ b/src/theory/theory_engine.cpp @@ -42,6 +42,7 @@ #include "theory/bv/theory_bv_utils.h" #include "theory/care_graph.h" #include "theory/decision_manager.h" +#include "theory/ee_manager_distributed.h" #include "theory/quantifiers/first_order_model.h" #include "theory/quantifiers/fmf/model_engine.h" #include "theory/quantifiers/theory_quantifiers.h" @@ -129,20 +130,21 @@ std::string getTheoryString(theory::TheoryId id) } void TheoryEngine::finishInit() { - //initialize the quantifiers engine, master equality engine, model, model builder - if( d_logicInfo.isQuantified() ) { + // initialize the quantifiers engine + if (d_logicInfo.isQuantified()) + { // initialize the quantifiers engine d_quantEngine = new QuantifiersEngine(d_context, d_userContext, this); - Assert(d_masterEqualityEngine == 0); - d_masterEqualityEngine = new eq::EqualityEngine(d_masterEENotify,getSatContext(), "theory::master", false); + } - for(TheoryId theoryId = theory::THEORY_FIRST; theoryId != theory::THEORY_LAST; ++ theoryId) { - if (d_theoryTable[theoryId]) { - d_theoryTable[theoryId]->setQuantifiersEngine(d_quantEngine); - d_theoryTable[theoryId]->setMasterEqualityEngine(d_masterEqualityEngine); - } - } + // Initialize the equality engine architecture for all theories, which + // includes the master equality engine. + d_eeDistributed.reset(new EqEngineManagerDistributed(*this)); + d_eeDistributed->finishInit(); + // Initialize the model and model builder. + if (d_logicInfo.isQuantified()) + { d_curr_model_builder = d_quantEngine->getModelBuilder(); d_curr_model = d_quantEngine->getModel(); } else { @@ -150,25 +152,32 @@ void TheoryEngine::finishInit() { d_userContext, "DefaultModel", options::assignFunctionValues()); d_aloc_curr_model = true; } + //make the default builder, e.g. in the case that the quantifiers engine does not have a model builder if( d_curr_model_builder==NULL ){ d_curr_model_builder = new theory::TheoryEngineModelBuilder(this); d_aloc_curr_model_builder = true; } + // finish initializing the theories for(TheoryId theoryId = theory::THEORY_FIRST; theoryId != theory::THEORY_LAST; ++ theoryId) { - if (d_theoryTable[theoryId]) { - // set the decision manager for the theory - d_theoryTable[theoryId]->setDecisionManager(d_decManager.get()); - // finish initializing the theory - d_theoryTable[theoryId]->finishInit(); + Theory* t = d_theoryTable[theoryId]; + if (t == nullptr) + { + continue; } - } -} - -void TheoryEngine::eqNotifyNewClass(TNode t){ - if (d_logicInfo.isQuantified()) { - d_quantEngine->eqNotifyNewClass( t ); + // setup the pointers to the utilities + const EeTheoryInfo* eeti = d_eeDistributed->getEeTheoryInfo(theoryId); + Assert(eeti != nullptr); + // the theory's official equality engine is the one specified by the + // equality engine manager + t->setEqualityEngine(eeti->d_usedEe); + // set the quantifiers engine + t->setQuantifiersEngine(d_quantEngine); + // set the decision manager for the theory + t->setDecisionManager(d_decManager.get()); + // finish initializing the theory + t->finishInit(); } } @@ -182,8 +191,7 @@ TheoryEngine::TheoryEngine(context::Context* context, d_userContext(userContext), d_logicInfo(logicInfo), d_sharedTerms(this, context), - d_masterEqualityEngine(nullptr), - d_masterEENotify(*this), + d_eeDistributed(nullptr), d_quantEngine(nullptr), d_decManager(new DecisionManager(userContext)), d_curr_model(nullptr), @@ -252,8 +260,6 @@ TheoryEngine::~TheoryEngine() { delete d_quantEngine; - delete d_masterEqualityEngine; - smtStatisticsRegistry()->unregisterStat(&d_combineTheoriesTime); smtStatisticsRegistry()->unregisterStat(&d_arithSubstitutionsAdded); } @@ -537,9 +543,12 @@ void TheoryEngine::check(Theory::Effort effort) { Debug("theory") << ", need check = " << (needCheck() ? "YES" : "NO") << endl; if( Theory::fullEffort(effort) && !d_inConflict && !needCheck()) { - // case where we are about to answer SAT - if( d_masterEqualityEngine != NULL ){ - AlwaysAssert(d_masterEqualityEngine->consistent()); + // case where we are about to answer SAT, the master equality engine, + // if it exists, must be consistent. + eq::EqualityEngine* mee = getMasterEqualityEngine(); + if (mee != NULL) + { + AlwaysAssert(mee->consistent()); } if (d_curr_model->isBuilt()) { @@ -1793,6 +1802,17 @@ void TheoryEngine::staticInitializeBVOptions( } } +SharedTermsDatabase* TheoryEngine::getSharedTermsDatabase() +{ + return &d_sharedTerms; +} + +theory::eq::EqualityEngine* TheoryEngine::getMasterEqualityEngine() +{ + Assert(d_eeDistributed != nullptr); + return d_eeDistributed->getMasterEqualityEngine(); +} + void TheoryEngine::getExplanation(std::vector<NodeTheoryPair>& explanationVector, LemmaProofRecipe* proofRecipe) { Assert(explanationVector.size() > 0); diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index 081d53098..aa23aa29b 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -90,6 +90,7 @@ struct NodeTheoryPairHashFunction { namespace theory { class TheoryModel; class TheoryEngineModelBuilder; + class EqEngineManagerDistributed; namespace eq { class EqualityEngine; @@ -148,43 +149,13 @@ class TheoryEngine { SharedTermsDatabase d_sharedTerms; /** - * Master equality engine, to share with theories. + * The distributed equality manager. This class is responsible for + * configuring the theories of this class for handling equalties + * in a "distributed" fashion, i.e. each theory maintains a unique + * instance of an equality engine. These equality engines are memory + * managed by this class. */ - theory::eq::EqualityEngine* d_masterEqualityEngine; - - /** notify class for master equality engine */ - class NotifyClass : public theory::eq::EqualityEngineNotify { - TheoryEngine& d_te; - public: - NotifyClass(TheoryEngine& te): d_te(te) {} - bool eqNotifyTriggerEquality(TNode equality, bool value) override - { - return true; - } - bool eqNotifyTriggerPredicate(TNode predicate, bool value) override - { - return true; - } - bool eqNotifyTriggerTermEquality(theory::TheoryId tag, - TNode t1, - TNode t2, - bool value) override - { - return true; - } - void eqNotifyConstantTermMerge(TNode t1, TNode t2) override {} - void eqNotifyNewClass(TNode t) override { d_te.eqNotifyNewClass(t); } - void eqNotifyMerge(TNode t1, TNode t2) override {} - void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override - { - } - };/* class TheoryEngine::NotifyClass */ - NotifyClass d_masterEENotify; - - /** - * notification methods - */ - void eqNotifyNewClass(TNode t); + std::unique_ptr<theory::EqEngineManagerDistributed> d_eeDistributed; /** * The quantifiers engine @@ -389,7 +360,13 @@ class TheoryEngine { d_propEngine = propEngine; } - /** Called when all initialization of options/logic is done */ + /** + * Called when all initialization of options/logic is done, after theory + * objects have been created. + * + * This initializes the quantifiers engine, the "official" equality engines + * of each theory as required, and the model and model builder utilities. + */ void finishInit(); /** @@ -759,13 +736,9 @@ public: public: void staticInitializeBVOptions(const std::vector<Node>& assertions); - Node ppSimpITE(TNode assertion); - /** Returns false if an assertion simplified to false. */ - bool donePPSimpITE(std::vector<Node>& assertions); - - SharedTermsDatabase* getSharedTermsDatabase() { return &d_sharedTerms; } + SharedTermsDatabase* getSharedTermsDatabase(); - theory::eq::EqualityEngine* getMasterEqualityEngine() { return d_masterEqualityEngine; } + theory::eq::EqualityEngine* getMasterEqualityEngine(); SortInference* getSortInference() { return &d_sortInfer; } diff --git a/src/theory/uf/theory_uf.cpp b/src/theory/uf/theory_uf.cpp index 862a906a0..4f9c3bed5 100644 --- a/src/theory/uf/theory_uf.cpp +++ b/src/theory/uf/theory_uf.cpp @@ -54,16 +54,12 @@ TheoryUF::TheoryUF(context::Context* c, * so make sure it's initialized first. */ d_thss(nullptr), d_ho(nullptr), - d_equalityEngine(d_notify, c, instanceName + "theory::uf::ee", true), d_conflict(c, false), d_functionsTerms(c), d_symb(u, instanceName) { d_true = NodeManager::currentNM()->mkConst( true ); - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::APPLY_UF, false, options::ufHo()); - ProofChecker* pc = pnm != nullptr ? pnm->getChecker() : nullptr; if (pc != nullptr) { @@ -74,11 +70,17 @@ TheoryUF::TheoryUF(context::Context* c, TheoryUF::~TheoryUF() { } -void TheoryUF::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); +TheoryRewriter* TheoryUF::getTheoryRewriter() { return &d_rewriter; } + +bool TheoryUF::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = d_instanceName + "theory::uf::ee"; + return true; } void TheoryUF::finishInit() { + Assert(d_equalityEngine != nullptr); // combined cardinality constraints are not evaluated in getModelValue d_valuation.setUnevaluatedKind(kind::COMBINED_CARDINALITY_CONSTRAINT); // Initialize the cardinality constraints solver if the logic includes UF, @@ -90,9 +92,11 @@ void TheoryUF::finishInit() { d_thss.reset(new CardinalityExtension( getSatContext(), getUserContext(), *d_out, this)); } + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::APPLY_UF, false, options::ufHo()); if (options::ufHo()) { - d_equalityEngine.addFunctionKind(kind::HO_APPLY); + d_equalityEngine->addFunctionKind(kind::HO_APPLY); d_ho.reset(new HoExtension(*this, getSatContext(), getUserContext())); } } @@ -148,7 +152,7 @@ void TheoryUF::check(Effort level) { bool polarity = fact.getKind() != kind::NOT; TNode atom = polarity ? fact : fact[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.assertEquality(atom, polarity, fact); + d_equalityEngine->assertEquality(atom, polarity, fact); if( options::ufHo() && options::ufHoExt() ){ if( !polarity && !d_conflict && atom[0].getType().isFunction() ){ // apply extensionality eagerly using the ho extension @@ -169,10 +173,10 @@ void TheoryUF::check(Effort level) { } //needed for models if( options::produceModels() ){ - d_equalityEngine.assertPredicate(atom, polarity, fact); + d_equalityEngine->assertPredicate(atom, polarity, fact); } } else { - d_equalityEngine.assertPredicate(atom, polarity, fact); + d_equalityEngine->assertPredicate(atom, polarity, fact); } } @@ -198,7 +202,7 @@ Node TheoryUF::getOperatorForApplyTerm( TNode node ) { if( node.getKind()==kind::APPLY_UF ){ return node.getOperator(); }else{ - return d_equalityEngine.getRepresentative( node[0] ); + return d_equalityEngine->getRepresentative(node[0]); } } @@ -242,17 +246,17 @@ void TheoryUF::preRegisterTerm(TNode node) { switch (node.getKind()) { case kind::EQUAL: // Add the trigger for equality - d_equalityEngine.addTriggerEquality(node); + d_equalityEngine->addTriggerEquality(node); break; case kind::APPLY_UF: case kind::HO_APPLY: // Maybe it's a predicate if (node.getType().isBoolean()) { // Get triggered for both equal and dis-equal - d_equalityEngine.addTriggerPredicate(node); + d_equalityEngine->addTriggerPredicate(node); } else { // Function applications/predicates - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); } // Remember the function and predicate terms d_functionsTerms.push_back(node); @@ -263,7 +267,7 @@ void TheoryUF::preRegisterTerm(TNode node) { break; default: // Variables etc - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); break; } }/* TheoryUF::preRegisterTerm() */ @@ -294,9 +298,10 @@ void TheoryUF::explain(TNode literal, std::vector<TNode>& assumptions, eq::EqPro bool polarity = literal.getKind() != kind::NOT; TNode atom = polarity ? literal : literal[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions, pf); + d_equalityEngine->explainEquality( + atom[0], atom[1], polarity, assumptions, pf); } else { - d_equalityEngine.explainPredicate(atom, polarity, assumptions, pf); + d_equalityEngine->explainPredicate(atom, polarity, assumptions, pf); } if( pf ){ Debug("pf::uf") << std::endl; @@ -331,7 +336,7 @@ bool TheoryUF::collectModelInfo(TheoryModel* m) // Compute terms appearing in assertions and shared terms computeRelevantTerms(termSet); - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { Trace("uf") << "Collect model info fail UF" << std::endl; return false; @@ -495,13 +500,15 @@ void TheoryUF::ppStaticLearn(TNode n, NodeBuilder<>& learned) { EqualityStatus TheoryUF::getEqualityStatus(TNode a, TNode b) { // Check for equality (simplest) - if (d_equalityEngine.areEqual(a, b)) { + if (d_equalityEngine->areEqual(a, b)) + { // The terms are implied to be equal return EQUALITY_TRUE; } // Check for disequality - if (d_equalityEngine.areDisequal(a, b, false)) { + if (d_equalityEngine->areDisequal(a, b, false)) + { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -512,15 +519,19 @@ EqualityStatus TheoryUF::getEqualityStatus(TNode a, TNode b) { void TheoryUF::addSharedTerm(TNode t) { Debug("uf::sharing") << "TheoryUF::addSharedTerm(" << t << ")" << std::endl; - d_equalityEngine.addTriggerTerm(t, THEORY_UF); + d_equalityEngine->addTriggerTerm(t, THEORY_UF); } bool TheoryUF::areCareDisequal(TNode x, TNode y){ - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); - if( d_equalityEngine.isTriggerTerm(x, THEORY_UF) && d_equalityEngine.isTriggerTerm(y, THEORY_UF) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_UF); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_UF); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); + if (d_equalityEngine->isTriggerTerm(x, THEORY_UF) + && d_equalityEngine->isTriggerTerm(y, THEORY_UF)) + { + TNode x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_UF); + TNode y_shared = + d_equalityEngine->getTriggerTermRepresentative(y, THEORY_UF); EqualityStatus eqStatus = d_valuation.getEqualityStatus(x_shared, y_shared); if( eqStatus==EQUALITY_FALSE_AND_PROPAGATED || eqStatus==EQUALITY_FALSE || eqStatus==EQUALITY_FALSE_IN_MODEL ){ return true; @@ -538,21 +549,27 @@ void TheoryUF::addCarePairs(TNodeTrie* t1, if( t2!=NULL ){ Node f1 = t1->getData(); Node f2 = t2->getData(); - if( !d_equalityEngine.areEqual( f1, f2 ) ){ + if (!d_equalityEngine->areEqual(f1, f2)) + { Debug("uf::sharing") << "TheoryUf::computeCareGraph(): checking function " << f1 << " and " << f2 << std::endl; vector< pair<TNode, TNode> > currentPairs; unsigned arg_start_index = getArgumentStartIndexForApplyTerm( f1 ); for (unsigned k = arg_start_index; k < f1.getNumChildren(); ++ k) { TNode x = f1[k]; TNode y = f2[k]; - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); - Assert(!d_equalityEngine.areDisequal(x, y, false)); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); + Assert(!d_equalityEngine->areDisequal(x, y, false)); Assert(!areCareDisequal(x, y)); - if( !d_equalityEngine.areEqual( x, y ) ){ - if( d_equalityEngine.isTriggerTerm(x, THEORY_UF) && d_equalityEngine.isTriggerTerm(y, THEORY_UF) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_UF); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_UF); + if (!d_equalityEngine->areEqual(x, y)) + { + if (d_equalityEngine->isTriggerTerm(x, THEORY_UF) + && d_equalityEngine->isTriggerTerm(y, THEORY_UF)) + { + TNode x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_UF); + TNode y_shared = + d_equalityEngine->getTriggerTermRepresentative(y, THEORY_UF); currentPairs.push_back(make_pair(x_shared, y_shared)); } } @@ -580,7 +597,8 @@ void TheoryUF::addCarePairs(TNodeTrie* t1, std::map<TNode, TNodeTrie>::iterator it2 = it; ++it2; for( ; it2 != t1->d_data.end(); ++it2 ){ - if( !d_equalityEngine.areDisequal(it->first, it2->first, false) ){ + if (!d_equalityEngine->areDisequal(it->first, it2->first, false)) + { if( !areCareDisequal(it->first, it2->first) ){ addCarePairs( &it->second, &it2->second, arity, depth+1 ); } @@ -593,7 +611,7 @@ void TheoryUF::addCarePairs(TNodeTrie* t1, { for (std::pair<const TNode, TNodeTrie>& tt2 : t2->d_data) { - if (!d_equalityEngine.areDisequal(tt1.first, tt2.first, false)) + if (!d_equalityEngine->areDisequal(tt1.first, tt2.first, false)) { if (!areCareDisequal(tt1.first, tt2.first)) { @@ -621,8 +639,9 @@ void TheoryUF::computeCareGraph() { std::vector< TNode > reps; bool has_trigger_arg = false; for( unsigned j=arg_start_index; j<f1.getNumChildren(); j++ ){ - reps.push_back( d_equalityEngine.getRepresentative( f1[j] ) ); - if( d_equalityEngine.isTriggerTerm( f1[j], THEORY_UF ) ){ + reps.push_back(d_equalityEngine->getRepresentative(f1[j])); + if (d_equalityEngine->isTriggerTerm(f1[j], THEORY_UF)) + { has_trigger_arg = true; } } diff --git a/src/theory/uf/theory_uf.h b/src/theory/uf/theory_uf.h index 345547301..001c947e9 100644 --- a/src/theory/uf/theory_uf.h +++ b/src/theory/uf/theory_uf.h @@ -116,9 +116,6 @@ private: /** the higher-order solver extension (or nullptr if it does not exist) */ std::unique_ptr<HoExtension> d_ho; - /** Equaltity engine */ - eq::EqualityEngine d_equalityEngine; - /** Are we in conflict */ context::CDO<bool> d_conflict; @@ -186,10 +183,18 @@ private: ~TheoryUF(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ void finishInit() override; + //--------------------------------- end initialization void check(Effort) override; TrustNode expandDefinition(Node node) override; @@ -210,8 +215,6 @@ private: std::string identify() const override { return "THEORY_UF"; } - eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } - /** get a pointer to the uf with cardinality */ CardinalityExtension* getCardinalityExtension() const { return d_thss.get(); } /** are we in conflict? */ |