summaryrefslogtreecommitdiff
path: root/src/theory/bags
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2021-02-19 09:26:36 -0600
committerGitHub <noreply@github.com>2021-02-19 09:26:36 -0600
commitc4822869beac8d4a0eac4b234e0662d3db49f995 (patch)
treed34b86c54b0ac8de6df4734e1e76afa5f43d5efb /src/theory/bags
parent00479d03cdeac3e864a1930dddb16c71c5bf2ce9 (diff)
Refactoring theory inference process (#5920)
This PR refactors TheoryInference so that it is not responsible for calling back into InferenceManager, instead it sets data so that InferenceManagerBuffered has enough information to do this itself. It also makes the decision of whether to cache lemmas in theory inference manager a global choice per-theory instead of per-lemma.
Diffstat (limited to 'src/theory/bags')
-rw-r--r--src/theory/bags/bag_solver.cpp22
-rw-r--r--src/theory/bags/infer_info.cpp35
-rw-r--r--src/theory/bags/infer_info.h12
-rw-r--r--src/theory/bags/inference_generator.cpp25
-rw-r--r--src/theory/bags/inference_generator.h6
-rw-r--r--src/theory/bags/inference_manager.h1
-rw-r--r--src/theory/bags/theory_bags.cpp2
7 files changed, 51 insertions, 52 deletions
diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp
index bdd4a9b30..76c001ba2 100644
--- a/src/theory/bags/bag_solver.cpp
+++ b/src/theory/bags/bag_solver.cpp
@@ -27,7 +27,7 @@ namespace theory {
namespace bags {
BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
- : d_state(s), d_ig(&d_state), d_im(im), d_termReg(tr)
+ : d_state(s), d_ig(&s, &im), d_im(im), d_termReg(tr)
{
d_zero = NodeManager::currentNM()->mkConst(Rational(0));
d_one = NodeManager::currentNM()->mkConst(Rational(1));
@@ -102,7 +102,7 @@ void BagSolver::checkEmpty(const Node& n)
for (const Node& e : d_state.getElements(n))
{
InferInfo i = d_ig.empty(n, e);
- i.process(&d_im, true);
+ d_im.lemmaTheoryInference(&i);
}
}
@@ -113,7 +113,7 @@ void BagSolver::checkUnionDisjoint(const Node& n)
for (const Node& e : elements)
{
InferInfo i = d_ig.unionDisjoint(n, e);
- i.process(&d_im, true);
+ d_im.lemmaTheoryInference(&i);
}
}
@@ -124,7 +124,7 @@ void BagSolver::checkUnionMax(const Node& n)
for (const Node& e : elements)
{
InferInfo i = d_ig.unionMax(n, e);
- i.process(&d_im, true);
+ d_im.lemmaTheoryInference(&i);
}
}
@@ -135,7 +135,7 @@ void BagSolver::checkIntersectionMin(const Node& n)
for (const Node& e : elements)
{
InferInfo i = d_ig.intersection(n, e);
- i.process(&d_im, true);
+ d_im.lemmaTheoryInference(&i);
}
}
@@ -146,7 +146,7 @@ void BagSolver::checkDifferenceSubtract(const Node& n)
for (const Node& e : elements)
{
InferInfo i = d_ig.differenceSubtract(n, e);
- i.process(&d_im, true);
+ d_im.lemmaTheoryInference(&i);
}
}
@@ -159,13 +159,13 @@ void BagSolver::checkMkBag(const Node& n)
for (const Node& e : d_state.getElements(n))
{
InferInfo i = d_ig.mkBag(n, e);
- i.process(&d_im, true);
+ d_im.lemmaTheoryInference(&i);
}
}
void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
{
InferInfo i = d_ig.nonNegativeCount(bag, element);
- i.process(&d_im, true);
+ d_im.lemmaTheoryInference(&i);
}
void BagSolver::checkDifferenceRemove(const Node& n)
@@ -175,7 +175,7 @@ void BagSolver::checkDifferenceRemove(const Node& n)
for (const Node& e : elements)
{
InferInfo i = d_ig.differenceRemove(n, e);
- i.process(&d_im, true);
+ d_im.lemmaTheoryInference(&i);
}
}
@@ -192,7 +192,7 @@ void BagSolver::checkDuplicateRemoval(Node n)
for (const Node& e : elements)
{
InferInfo i = d_ig.duplicateRemoval(n, e);
- i.process(&d_im, true);
+ d_im.lemmaTheoryInference(&i);
}
}
@@ -201,7 +201,7 @@ void BagSolver::checkDisequalBagTerms()
for (const Node& n : d_state.getDisequalBagTerms())
{
InferInfo info = d_ig.bagDisequality(n);
- info.process(&d_im, true);
+ d_im.lemmaTheoryInference(&info);
}
}
diff --git a/src/theory/bags/infer_info.cpp b/src/theory/bags/infer_info.cpp
index 0655b6bbf..969c6b3de 100644
--- a/src/theory/bags/infer_info.cpp
+++ b/src/theory/bags/infer_info.cpp
@@ -20,39 +20,28 @@ namespace CVC4 {
namespace theory {
namespace bags {
-InferInfo::InferInfo(InferenceId id) : TheoryInference(id) {}
+InferInfo::InferInfo(TheoryInferenceManager* im, InferenceId id)
+ : TheoryInference(id), d_im(im)
+{
+}
-bool InferInfo::process(TheoryInferenceManager* im, bool asLemma)
+TrustNode InferInfo::processLemma(LemmaProperty& p)
{
- Node lemma = d_conclusion;
- if (d_premises.size() >= 2)
- {
- Node andNode = NodeManager::currentNM()->mkNode(kind::AND, d_premises);
- lemma = andNode.impNode(lemma);
- }
- else if (d_premises.size() == 1)
- {
- lemma = d_premises[0].impNode(lemma);
- }
- if (asLemma)
- {
- TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
- im->trustedLemma(trustedLemma, getId());
- }
- else
- {
- Unimplemented();
- }
+ NodeManager* nm = NodeManager::currentNM();
+ Node pnode = nm->mkAnd(d_premises);
+ Node lemma = nm->mkNode(kind::IMPLIES, pnode, d_conclusion);
+
+ // send lemmas corresponding to the skolems introduced
for (const auto& pair : d_skolems)
{
Node n = pair.first.eqNode(pair.second);
TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr);
- im->trustedLemma(trustedLemma, getId());
+ d_im->trustedLemma(trustedLemma, getId(), p);
}
Trace("bags::InferInfo::process") << (*this) << std::endl;
- return true;
+ return TrustNode::mkTrustLemma(lemma, nullptr);
}
bool InferInfo::isTrivial() const
diff --git a/src/theory/bags/infer_info.h b/src/theory/bags/infer_info.h
index 66d75b5dc..a533acf58 100644
--- a/src/theory/bags/infer_info.h
+++ b/src/theory/bags/infer_info.h
@@ -26,9 +26,11 @@
namespace CVC4 {
namespace theory {
+
+class TheoryInferenceManager;
+
namespace bags {
-class InferenceManager;
/**
* An inference. This is a class to track an unprocessed call to either
@@ -38,10 +40,12 @@ class InferenceManager;
class InferInfo : public TheoryInference
{
public:
- InferInfo(InferenceId id);
+ InferInfo(TheoryInferenceManager* im, InferenceId id);
~InferInfo() {}
- /** Process this inference */
- bool process(TheoryInferenceManager* im, bool asLemma) override;
+ /** Process lemma */
+ TrustNode processLemma(LemmaProperty& p) override;
+ /** Pointer to the class used for processing this info */
+ TheoryInferenceManager* d_im;
/** The conclusion */
Node d_conclusion;
/**
diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp
index 2d4a5afed..bc1ed771c 100644
--- a/src/theory/bags/inference_generator.cpp
+++ b/src/theory/bags/inference_generator.cpp
@@ -23,7 +23,8 @@ namespace CVC4 {
namespace theory {
namespace bags {
-InferenceGenerator::InferenceGenerator(SolverState* state) : d_state(state)
+InferenceGenerator::InferenceGenerator(SolverState* state, InferenceManager* im)
+ : d_state(state), d_im(im)
{
d_nm = NodeManager::currentNM();
d_sm = d_nm->getSkolemManager();
@@ -37,7 +38,7 @@ InferInfo InferenceGenerator::nonNegativeCount(Node n, Node e)
Assert(n.getType().isBag());
Assert(e.getType() == n.getType().getBagElementType());
- InferInfo inferInfo(InferenceId::BAG_NON_NEGATIVE_COUNT);
+ InferInfo inferInfo(d_im, InferenceId::BAG_NON_NEGATIVE_COUNT);
Node count = d_nm->mkNode(kind::BAG_COUNT, e, n);
Node gte = d_nm->mkNode(kind::GEQ, count, d_zero);
@@ -54,7 +55,7 @@ InferInfo InferenceGenerator::mkBag(Node n, Node e)
{
// TODO issue #78: refactor this with BagRewriter
// (=> true (= (bag.count e (bag e c)) c))
- InferInfo inferInfo(InferenceId::BAG_MK_BAG_SAME_ELEMENT);
+ InferInfo inferInfo(d_im, InferenceId::BAG_MK_BAG_SAME_ELEMENT);
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
inferInfo.d_conclusion = count.eqNode(n[1]);
@@ -65,7 +66,7 @@ InferInfo InferenceGenerator::mkBag(Node n, Node e)
// (=>
// true
// (= (bag.count e (bag x c)) (ite (= e x) c 0)))
- InferInfo inferInfo(InferenceId::BAG_MK_BAG);
+ InferInfo inferInfo(d_im, InferenceId::BAG_MK_BAG);
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
Node same = d_nm->mkNode(kind::EQUAL, n[0], e);
@@ -88,7 +89,7 @@ InferInfo InferenceGenerator::bagDisequality(Node n)
Node A = n[0];
Node B = n[1];
- InferInfo inferInfo(InferenceId::BAG_DISEQUALITY);
+ InferInfo inferInfo(d_im, InferenceId::BAG_DISEQUALITY);
TypeNode elementType = A.getType().getBagElementType();
BoundVarManager* bvm = d_nm->getBoundVarManager();
@@ -121,7 +122,7 @@ InferInfo InferenceGenerator::empty(Node n, Node e)
Assert(n.getKind() == kind::EMPTYBAG);
Assert(e.getType() == n.getType().getBagElementType());
- InferInfo inferInfo(InferenceId::BAG_EMPTY);
+ InferInfo inferInfo(d_im, InferenceId::BAG_EMPTY);
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
@@ -137,7 +138,7 @@ InferInfo InferenceGenerator::unionDisjoint(Node n, Node e)
Node A = n[0];
Node B = n[1];
- InferInfo inferInfo(InferenceId::BAG_UNION_DISJOINT);
+ InferInfo inferInfo(d_im, InferenceId::BAG_UNION_DISJOINT);
Node countA = getMultiplicityTerm(e, A);
Node countB = getMultiplicityTerm(e, B);
@@ -159,7 +160,7 @@ InferInfo InferenceGenerator::unionMax(Node n, Node e)
Node A = n[0];
Node B = n[1];
- InferInfo inferInfo(InferenceId::BAG_UNION_MAX);
+ InferInfo inferInfo(d_im, InferenceId::BAG_UNION_MAX);
Node countA = getMultiplicityTerm(e, A);
Node countB = getMultiplicityTerm(e, B);
@@ -182,7 +183,7 @@ InferInfo InferenceGenerator::intersection(Node n, Node e)
Node A = n[0];
Node B = n[1];
- InferInfo inferInfo(InferenceId::BAG_INTERSECTION_MIN);
+ InferInfo inferInfo(d_im, InferenceId::BAG_INTERSECTION_MIN);
Node countA = getMultiplicityTerm(e, A);
Node countB = getMultiplicityTerm(e, B);
@@ -203,7 +204,7 @@ InferInfo InferenceGenerator::differenceSubtract(Node n, Node e)
Node A = n[0];
Node B = n[1];
- InferInfo inferInfo(InferenceId::BAG_DIFFERENCE_SUBTRACT);
+ InferInfo inferInfo(d_im, InferenceId::BAG_DIFFERENCE_SUBTRACT);
Node countA = getMultiplicityTerm(e, A);
Node countB = getMultiplicityTerm(e, B);
@@ -225,7 +226,7 @@ InferInfo InferenceGenerator::differenceRemove(Node n, Node e)
Node A = n[0];
Node B = n[1];
- InferInfo inferInfo(InferenceId::BAG_DIFFERENCE_REMOVE);
+ InferInfo inferInfo(d_im, InferenceId::BAG_DIFFERENCE_REMOVE);
Node countA = getMultiplicityTerm(e, A);
Node countB = getMultiplicityTerm(e, B);
@@ -246,7 +247,7 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e)
Assert(e.getType() == n[0].getType().getBagElementType());
Node A = n[0];
- InferInfo inferInfo(InferenceId::BAG_DUPLICATE_REMOVAL);
+ InferInfo inferInfo(d_im, InferenceId::BAG_DUPLICATE_REMOVAL);
Node countA = getMultiplicityTerm(e, A);
Node skolem = getSkolem(n, inferInfo);
diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h
index 4a852530a..f10a1051f 100644
--- a/src/theory/bags/inference_generator.h
+++ b/src/theory/bags/inference_generator.h
@@ -22,6 +22,7 @@
#include "expr/node.h"
#include "infer_info.h"
+#include "theory/bags/inference_manager.h"
#include "theory/bags/solver_state.h"
namespace CVC4 {
@@ -35,7 +36,7 @@ namespace bags {
class InferenceGenerator
{
public:
- InferenceGenerator(SolverState* state);
+ InferenceGenerator(SolverState* state, InferenceManager* im);
/**
* @param A is a bag of type (Bag E)
@@ -179,6 +180,9 @@ class InferenceGenerator
NodeManager* d_nm;
SkolemManager* d_sm;
SolverState* d_state;
+ /** Pointer to the inference manager */
+ InferenceManager* d_im;
+ /** Commonly used constants */
Node d_true;
Node d_zero;
Node d_one;
diff --git a/src/theory/bags/inference_manager.h b/src/theory/bags/inference_manager.h
index 71a014582..1b132fc37 100644
--- a/src/theory/bags/inference_manager.h
+++ b/src/theory/bags/inference_manager.h
@@ -17,6 +17,7 @@
#ifndef CVC4__THEORY__BAGS__INFERENCE_MANAGER_H
#define CVC4__THEORY__BAGS__INFERENCE_MANAGER_H
+#include "theory/bags/infer_info.h"
#include "theory/bags/solver_state.h"
#include "theory/inference_manager_buffered.h"
diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp
index 6df44295e..48fc38b8f 100644
--- a/src/theory/bags/theory_bags.cpp
+++ b/src/theory/bags/theory_bags.cpp
@@ -31,7 +31,7 @@ TheoryBags::TheoryBags(context::Context* c,
: Theory(THEORY_BAGS, c, u, out, valuation, logicInfo, pnm),
d_state(c, u, valuation),
d_im(*this, d_state, nullptr),
- d_ig(&d_state),
+ d_ig(&d_state, &d_im),
d_notify(*this, d_im),
d_statistics(),
d_rewriter(&d_statistics.d_rewrites),
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback