diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2021-02-19 09:26:36 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-19 09:26:36 -0600 |
commit | c4822869beac8d4a0eac4b234e0662d3db49f995 (patch) | |
tree | d34b86c54b0ac8de6df4734e1e76afa5f43d5efb /src/theory/bags | |
parent | 00479d03cdeac3e864a1930dddb16c71c5bf2ce9 (diff) |
Refactoring theory inference process (#5920)
This PR refactors TheoryInference so that it is not responsible for calling back into InferenceManager, instead it sets data so that InferenceManagerBuffered has enough information to do this itself. It also makes the decision of whether to cache lemmas in theory inference manager a global choice per-theory instead of per-lemma.
Diffstat (limited to 'src/theory/bags')
-rw-r--r-- | src/theory/bags/bag_solver.cpp | 22 | ||||
-rw-r--r-- | src/theory/bags/infer_info.cpp | 35 | ||||
-rw-r--r-- | src/theory/bags/infer_info.h | 12 | ||||
-rw-r--r-- | src/theory/bags/inference_generator.cpp | 25 | ||||
-rw-r--r-- | src/theory/bags/inference_generator.h | 6 | ||||
-rw-r--r-- | src/theory/bags/inference_manager.h | 1 | ||||
-rw-r--r-- | src/theory/bags/theory_bags.cpp | 2 |
7 files changed, 51 insertions, 52 deletions
diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp index bdd4a9b30..76c001ba2 100644 --- a/src/theory/bags/bag_solver.cpp +++ b/src/theory/bags/bag_solver.cpp @@ -27,7 +27,7 @@ namespace theory { namespace bags { BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr) - : d_state(s), d_ig(&d_state), d_im(im), d_termReg(tr) + : d_state(s), d_ig(&s, &im), d_im(im), d_termReg(tr) { d_zero = NodeManager::currentNM()->mkConst(Rational(0)); d_one = NodeManager::currentNM()->mkConst(Rational(1)); @@ -102,7 +102,7 @@ void BagSolver::checkEmpty(const Node& n) for (const Node& e : d_state.getElements(n)) { InferInfo i = d_ig.empty(n, e); - i.process(&d_im, true); + d_im.lemmaTheoryInference(&i); } } @@ -113,7 +113,7 @@ void BagSolver::checkUnionDisjoint(const Node& n) for (const Node& e : elements) { InferInfo i = d_ig.unionDisjoint(n, e); - i.process(&d_im, true); + d_im.lemmaTheoryInference(&i); } } @@ -124,7 +124,7 @@ void BagSolver::checkUnionMax(const Node& n) for (const Node& e : elements) { InferInfo i = d_ig.unionMax(n, e); - i.process(&d_im, true); + d_im.lemmaTheoryInference(&i); } } @@ -135,7 +135,7 @@ void BagSolver::checkIntersectionMin(const Node& n) for (const Node& e : elements) { InferInfo i = d_ig.intersection(n, e); - i.process(&d_im, true); + d_im.lemmaTheoryInference(&i); } } @@ -146,7 +146,7 @@ void BagSolver::checkDifferenceSubtract(const Node& n) for (const Node& e : elements) { InferInfo i = d_ig.differenceSubtract(n, e); - i.process(&d_im, true); + d_im.lemmaTheoryInference(&i); } } @@ -159,13 +159,13 @@ void BagSolver::checkMkBag(const Node& n) for (const Node& e : d_state.getElements(n)) { InferInfo i = d_ig.mkBag(n, e); - i.process(&d_im, true); + d_im.lemmaTheoryInference(&i); } } void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element) { InferInfo i = d_ig.nonNegativeCount(bag, element); - i.process(&d_im, true); + d_im.lemmaTheoryInference(&i); } void BagSolver::checkDifferenceRemove(const Node& n) @@ -175,7 +175,7 @@ void BagSolver::checkDifferenceRemove(const Node& n) for (const Node& e : elements) { InferInfo i = d_ig.differenceRemove(n, e); - i.process(&d_im, true); + d_im.lemmaTheoryInference(&i); } } @@ -192,7 +192,7 @@ void BagSolver::checkDuplicateRemoval(Node n) for (const Node& e : elements) { InferInfo i = d_ig.duplicateRemoval(n, e); - i.process(&d_im, true); + d_im.lemmaTheoryInference(&i); } } @@ -201,7 +201,7 @@ void BagSolver::checkDisequalBagTerms() for (const Node& n : d_state.getDisequalBagTerms()) { InferInfo info = d_ig.bagDisequality(n); - info.process(&d_im, true); + d_im.lemmaTheoryInference(&info); } } diff --git a/src/theory/bags/infer_info.cpp b/src/theory/bags/infer_info.cpp index 0655b6bbf..969c6b3de 100644 --- a/src/theory/bags/infer_info.cpp +++ b/src/theory/bags/infer_info.cpp @@ -20,39 +20,28 @@ namespace CVC4 { namespace theory { namespace bags { -InferInfo::InferInfo(InferenceId id) : TheoryInference(id) {} +InferInfo::InferInfo(TheoryInferenceManager* im, InferenceId id) + : TheoryInference(id), d_im(im) +{ +} -bool InferInfo::process(TheoryInferenceManager* im, bool asLemma) +TrustNode InferInfo::processLemma(LemmaProperty& p) { - Node lemma = d_conclusion; - if (d_premises.size() >= 2) - { - Node andNode = NodeManager::currentNM()->mkNode(kind::AND, d_premises); - lemma = andNode.impNode(lemma); - } - else if (d_premises.size() == 1) - { - lemma = d_premises[0].impNode(lemma); - } - if (asLemma) - { - TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr); - im->trustedLemma(trustedLemma, getId()); - } - else - { - Unimplemented(); - } + NodeManager* nm = NodeManager::currentNM(); + Node pnode = nm->mkAnd(d_premises); + Node lemma = nm->mkNode(kind::IMPLIES, pnode, d_conclusion); + + // send lemmas corresponding to the skolems introduced for (const auto& pair : d_skolems) { Node n = pair.first.eqNode(pair.second); TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr); - im->trustedLemma(trustedLemma, getId()); + d_im->trustedLemma(trustedLemma, getId(), p); } Trace("bags::InferInfo::process") << (*this) << std::endl; - return true; + return TrustNode::mkTrustLemma(lemma, nullptr); } bool InferInfo::isTrivial() const diff --git a/src/theory/bags/infer_info.h b/src/theory/bags/infer_info.h index 66d75b5dc..a533acf58 100644 --- a/src/theory/bags/infer_info.h +++ b/src/theory/bags/infer_info.h @@ -26,9 +26,11 @@ namespace CVC4 { namespace theory { + +class TheoryInferenceManager; + namespace bags { -class InferenceManager; /** * An inference. This is a class to track an unprocessed call to either @@ -38,10 +40,12 @@ class InferenceManager; class InferInfo : public TheoryInference { public: - InferInfo(InferenceId id); + InferInfo(TheoryInferenceManager* im, InferenceId id); ~InferInfo() {} - /** Process this inference */ - bool process(TheoryInferenceManager* im, bool asLemma) override; + /** Process lemma */ + TrustNode processLemma(LemmaProperty& p) override; + /** Pointer to the class used for processing this info */ + TheoryInferenceManager* d_im; /** The conclusion */ Node d_conclusion; /** diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp index 2d4a5afed..bc1ed771c 100644 --- a/src/theory/bags/inference_generator.cpp +++ b/src/theory/bags/inference_generator.cpp @@ -23,7 +23,8 @@ namespace CVC4 { namespace theory { namespace bags { -InferenceGenerator::InferenceGenerator(SolverState* state) : d_state(state) +InferenceGenerator::InferenceGenerator(SolverState* state, InferenceManager* im) + : d_state(state), d_im(im) { d_nm = NodeManager::currentNM(); d_sm = d_nm->getSkolemManager(); @@ -37,7 +38,7 @@ InferInfo InferenceGenerator::nonNegativeCount(Node n, Node e) Assert(n.getType().isBag()); Assert(e.getType() == n.getType().getBagElementType()); - InferInfo inferInfo(InferenceId::BAG_NON_NEGATIVE_COUNT); + InferInfo inferInfo(d_im, InferenceId::BAG_NON_NEGATIVE_COUNT); Node count = d_nm->mkNode(kind::BAG_COUNT, e, n); Node gte = d_nm->mkNode(kind::GEQ, count, d_zero); @@ -54,7 +55,7 @@ InferInfo InferenceGenerator::mkBag(Node n, Node e) { // TODO issue #78: refactor this with BagRewriter // (=> true (= (bag.count e (bag e c)) c)) - InferInfo inferInfo(InferenceId::BAG_MK_BAG_SAME_ELEMENT); + InferInfo inferInfo(d_im, InferenceId::BAG_MK_BAG_SAME_ELEMENT); Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); inferInfo.d_conclusion = count.eqNode(n[1]); @@ -65,7 +66,7 @@ InferInfo InferenceGenerator::mkBag(Node n, Node e) // (=> // true // (= (bag.count e (bag x c)) (ite (= e x) c 0))) - InferInfo inferInfo(InferenceId::BAG_MK_BAG); + InferInfo inferInfo(d_im, InferenceId::BAG_MK_BAG); Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); Node same = d_nm->mkNode(kind::EQUAL, n[0], e); @@ -88,7 +89,7 @@ InferInfo InferenceGenerator::bagDisequality(Node n) Node A = n[0]; Node B = n[1]; - InferInfo inferInfo(InferenceId::BAG_DISEQUALITY); + InferInfo inferInfo(d_im, InferenceId::BAG_DISEQUALITY); TypeNode elementType = A.getType().getBagElementType(); BoundVarManager* bvm = d_nm->getBoundVarManager(); @@ -121,7 +122,7 @@ InferInfo InferenceGenerator::empty(Node n, Node e) Assert(n.getKind() == kind::EMPTYBAG); Assert(e.getType() == n.getType().getBagElementType()); - InferInfo inferInfo(InferenceId::BAG_EMPTY); + InferInfo inferInfo(d_im, InferenceId::BAG_EMPTY); Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); @@ -137,7 +138,7 @@ InferInfo InferenceGenerator::unionDisjoint(Node n, Node e) Node A = n[0]; Node B = n[1]; - InferInfo inferInfo(InferenceId::BAG_UNION_DISJOINT); + InferInfo inferInfo(d_im, InferenceId::BAG_UNION_DISJOINT); Node countA = getMultiplicityTerm(e, A); Node countB = getMultiplicityTerm(e, B); @@ -159,7 +160,7 @@ InferInfo InferenceGenerator::unionMax(Node n, Node e) Node A = n[0]; Node B = n[1]; - InferInfo inferInfo(InferenceId::BAG_UNION_MAX); + InferInfo inferInfo(d_im, InferenceId::BAG_UNION_MAX); Node countA = getMultiplicityTerm(e, A); Node countB = getMultiplicityTerm(e, B); @@ -182,7 +183,7 @@ InferInfo InferenceGenerator::intersection(Node n, Node e) Node A = n[0]; Node B = n[1]; - InferInfo inferInfo(InferenceId::BAG_INTERSECTION_MIN); + InferInfo inferInfo(d_im, InferenceId::BAG_INTERSECTION_MIN); Node countA = getMultiplicityTerm(e, A); Node countB = getMultiplicityTerm(e, B); @@ -203,7 +204,7 @@ InferInfo InferenceGenerator::differenceSubtract(Node n, Node e) Node A = n[0]; Node B = n[1]; - InferInfo inferInfo(InferenceId::BAG_DIFFERENCE_SUBTRACT); + InferInfo inferInfo(d_im, InferenceId::BAG_DIFFERENCE_SUBTRACT); Node countA = getMultiplicityTerm(e, A); Node countB = getMultiplicityTerm(e, B); @@ -225,7 +226,7 @@ InferInfo InferenceGenerator::differenceRemove(Node n, Node e) Node A = n[0]; Node B = n[1]; - InferInfo inferInfo(InferenceId::BAG_DIFFERENCE_REMOVE); + InferInfo inferInfo(d_im, InferenceId::BAG_DIFFERENCE_REMOVE); Node countA = getMultiplicityTerm(e, A); Node countB = getMultiplicityTerm(e, B); @@ -246,7 +247,7 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e) Assert(e.getType() == n[0].getType().getBagElementType()); Node A = n[0]; - InferInfo inferInfo(InferenceId::BAG_DUPLICATE_REMOVAL); + InferInfo inferInfo(d_im, InferenceId::BAG_DUPLICATE_REMOVAL); Node countA = getMultiplicityTerm(e, A); Node skolem = getSkolem(n, inferInfo); diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h index 4a852530a..f10a1051f 100644 --- a/src/theory/bags/inference_generator.h +++ b/src/theory/bags/inference_generator.h @@ -22,6 +22,7 @@ #include "expr/node.h" #include "infer_info.h" +#include "theory/bags/inference_manager.h" #include "theory/bags/solver_state.h" namespace CVC4 { @@ -35,7 +36,7 @@ namespace bags { class InferenceGenerator { public: - InferenceGenerator(SolverState* state); + InferenceGenerator(SolverState* state, InferenceManager* im); /** * @param A is a bag of type (Bag E) @@ -179,6 +180,9 @@ class InferenceGenerator NodeManager* d_nm; SkolemManager* d_sm; SolverState* d_state; + /** Pointer to the inference manager */ + InferenceManager* d_im; + /** Commonly used constants */ Node d_true; Node d_zero; Node d_one; diff --git a/src/theory/bags/inference_manager.h b/src/theory/bags/inference_manager.h index 71a014582..1b132fc37 100644 --- a/src/theory/bags/inference_manager.h +++ b/src/theory/bags/inference_manager.h @@ -17,6 +17,7 @@ #ifndef CVC4__THEORY__BAGS__INFERENCE_MANAGER_H #define CVC4__THEORY__BAGS__INFERENCE_MANAGER_H +#include "theory/bags/infer_info.h" #include "theory/bags/solver_state.h" #include "theory/inference_manager_buffered.h" diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 6df44295e..48fc38b8f 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -31,7 +31,7 @@ TheoryBags::TheoryBags(context::Context* c, : Theory(THEORY_BAGS, c, u, out, valuation, logicInfo, pnm), d_state(c, u, valuation), d_im(*this, d_state, nullptr), - d_ig(&d_state), + d_ig(&d_state, &d_im), d_notify(*this, d_im), d_statistics(), d_rewriter(&d_statistics.d_rewrites), |