summaryrefslogtreecommitdiff
path: root/src/theory
diff options
context:
space:
mode:
authormudathirmahgoub <mudathirmahgoub@gmail.com>2021-01-25 14:38:45 -0600
committerGitHub <noreply@github.com>2021-01-25 14:38:45 -0600
commiteaad5bdc7a38fcc38baa0e3b73f6c39a0ec6fb05 (patch)
tree42452e177fa8a24a523ce715aa3a40a99644ab17 /src/theory
parent7f851ea2e2b40f7e5d6e0c0fbe4e9c6ea0450209 (diff)
Refactor bags::SolverState (#5783)
Couple of changes: SolverState now keep tracks of elements per bag instead of per type. bags::InferInfo now stores multiple conclusions (conjuncts). BagSolver applies upward/downward closures for bag elements
Diffstat (limited to 'src/theory')
-rw-r--r--src/theory/bags/bag_solver.cpp91
-rw-r--r--src/theory/bags/bag_solver.h19
-rw-r--r--src/theory/bags/bags_rewriter.h16
-rw-r--r--src/theory/bags/infer_info.cpp28
-rw-r--r--src/theory/bags/infer_info.h13
-rw-r--r--src/theory/bags/inference_generator.cpp122
-rw-r--r--src/theory/bags/inference_generator.h84
-rw-r--r--src/theory/bags/inference_manager.h2
-rw-r--r--src/theory/bags/solver_state.cpp105
-rw-r--r--src/theory/bags/solver_state.h50
-rw-r--r--src/theory/bags/theory_bags.cpp65
11 files changed, 369 insertions, 226 deletions
diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp
index 5621a7c1c..495f73723 100644
--- a/src/theory/bags/bag_solver.cpp
+++ b/src/theory/bags/bag_solver.cpp
@@ -39,25 +39,63 @@ BagSolver::~BagSolver() {}
void BagSolver::postCheck()
{
+ d_state.initialize();
+
+ // At this point, all bag and count representatives should be in the solver
+ // state.
+ for (const Node& bag : d_state.getBags())
+ {
+ // iterate through all bags terms in each equivalent class
+ eq::EqClassIterator it =
+ eq::EqClassIterator(bag, d_state.getEqualityEngine());
+ while (!it.isFinished())
+ {
+ Node n = (*it);
+ Kind k = n.getKind();
+ switch (k)
+ {
+ case kind::MK_BAG: checkMkBag(n); break;
+ case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
+ case kind::UNION_MAX: checkUnionMax(n); break;
+ case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
+ case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
+ default: break;
+ }
+ it++;
+ }
+ }
+
+ // add non negative constraints for all multiplicities
for (const Node& n : d_state.getBags())
{
- Kind k = n.getKind();
- switch (k)
+ for (const Node& e : d_state.getElements(n))
{
- case kind::MK_BAG: checkMkBag(n); break;
- case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
- case kind::UNION_MAX: checkUnionMax(n); break;
- case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
- default: break;
+ checkNonNegativeCountTerms(n, e);
}
}
}
+set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
+{
+ set<Node> elements;
+ const set<Node>& downwards = d_state.getElements(n);
+ const set<Node>& upwards0 = d_state.getElements(n[0]);
+ const set<Node>& upwards1 = d_state.getElements(n[1]);
+
+ set_union(downwards.begin(),
+ downwards.end(),
+ upwards0.begin(),
+ upwards0.end(),
+ inserter(elements, elements.begin()));
+ elements.insert(upwards1.begin(), upwards1.end());
+ return elements;
+}
+
void BagSolver::checkUnionDisjoint(const Node& n)
{
Assert(n.getKind() == UNION_DISJOINT);
- TypeNode elementType = n.getType().getBagElementType();
- for (const Node& e : d_state.getElements(elementType))
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
{
InferenceGenerator ig(&d_state);
InferInfo i = ig.unionDisjoint(n, e);
@@ -69,8 +107,8 @@ void BagSolver::checkUnionDisjoint(const Node& n)
void BagSolver::checkUnionMax(const Node& n)
{
Assert(n.getKind() == UNION_MAX);
- TypeNode elementType = n.getType().getBagElementType();
- for (const Node& e : d_state.getElements(elementType))
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
{
InferenceGenerator ig(&d_state);
InferInfo i = ig.unionMax(n, e);
@@ -82,8 +120,8 @@ void BagSolver::checkUnionMax(const Node& n)
void BagSolver::checkDifferenceSubtract(const Node& n)
{
Assert(n.getKind() == DIFFERENCE_SUBTRACT);
- TypeNode elementType = n.getType().getBagElementType();
- for (const Node& e : d_state.getElements(elementType))
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
{
InferenceGenerator ig(&d_state);
InferInfo i = ig.differenceSubtract(n, e);
@@ -91,11 +129,14 @@ void BagSolver::checkDifferenceSubtract(const Node& n)
Trace("bags::BagSolver::postCheck") << i << endl;
}
}
+
void BagSolver::checkMkBag(const Node& n)
{
Assert(n.getKind() == MK_BAG);
- TypeNode elementType = n.getType().getBagElementType();
- for (const Node& e : d_state.getElements(elementType))
+ Trace("bags::BagSolver::postCheck")
+ << "BagSolver::checkMkBag Elements of " << n
+ << " are: " << d_state.getElements(n) << std::endl;
+ for (const Node& e : d_state.getElements(n))
{
InferenceGenerator ig(&d_state);
InferInfo i = ig.mkBag(n, e);
@@ -103,6 +144,26 @@ void BagSolver::checkMkBag(const Node& n)
Trace("bags::BagSolver::postCheck") << i << endl;
}
}
+void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
+{
+ InferenceGenerator ig(&d_state);
+ InferInfo i = ig.nonNegativeCount(bag, element);
+ i.process(&d_im, true);
+ Trace("bags::BagSolver::postCheck") << i << endl;
+}
+
+void BagSolver::checkDifferenceRemove(const Node& n)
+{
+ Assert(n.getKind() == DIFFERENCE_REMOVE);
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
+ {
+ InferenceGenerator ig(&d_state);
+ InferInfo i = ig.differenceRemove(n, e);
+ i.process(&d_im, true);
+ Trace("bags::BagSolver::postCheck") << i << endl;
+ }
+}
} // namespace bags
} // namespace theory
diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h
index 48583d134..b4b18c00c 100644
--- a/src/theory/bags/bag_solver.h
+++ b/src/theory/bags/bag_solver.h
@@ -41,14 +41,31 @@ class BagSolver
void postCheck();
private:
- /** apply inference rules for MK_BAG operator */
+ /**
+ * apply inference rules for MK_BAG operator.
+ * Example: Suppose n = (bag x c), and we have two count terms (bag.count x n)
+ * and (bag.count y n).
+ * This function will add inferences for the count terms as documented in
+ * InferenceGenerator::mkBag.
+ * Note that element y may not be in bag n. See the documentation of
+ * SolverState::getElements.
+ */
void checkMkBag(const Node& n);
+ /**
+ * @param n is a bag that has the form (op A B)
+ * @return the set union of known elements in (op A B) , A, and B.
+ */
+ std::set<Node> getElementsForBinaryOperator(const Node& n);
/** apply inference rules for union disjoint */
void checkUnionDisjoint(const Node& n);
/** apply inference rules for union max */
void checkUnionMax(const Node& n);
/** apply inference rules for difference subtract */
void checkDifferenceSubtract(const Node& n);
+ /** apply inference rules for difference remove */
+ void checkDifferenceRemove(const Node& n);
+ /** apply non negative constraints for multiplicities */
+ void checkNonNegativeCountTerms(const Node& bag, const Node& element);
/** The solver state object */
SolverState& d_state;
diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h
index fb76fb1c2..48cd9c419 100644
--- a/src/theory/bags/bags_rewriter.h
+++ b/src/theory/bags/bags_rewriter.h
@@ -70,8 +70,8 @@ class BagsRewriter : public TheoryRewriter
/**
* rewrites for n include:
- * - (mkBag x 0) = (emptybag T) where T is the type of x
- * - (mkBag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a
+ * - (bag x 0) = (emptybag T) where T is the type of x
+ * - (bag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a
* constant
* - otherwise = n
*/
@@ -87,7 +87,7 @@ class BagsRewriter : public TheoryRewriter
/**
* rewrites for n include:
- * - (duplicate_removal (mkBag x n)) = (mkBag x 1)
+ * - (duplicate_removal (bag x n)) = (bag x 1)
* where n is a positive constant
*/
BagsRewriteResponse rewriteDuplicateRemoval(const TNode& n) const;
@@ -171,13 +171,13 @@ class BagsRewriter : public TheoryRewriter
BagsRewriteResponse rewriteDifferenceRemove(const TNode& n) const;
/**
* rewrites for n include:
- * - (bag.choose (mkBag x c)) = x where c is a constant > 0
+ * - (bag.choose (bag x c)) = x where c is a constant > 0
* - otherwise = n
*/
BagsRewriteResponse rewriteChoose(const TNode& n) const;
/**
* rewrites for n include:
- * - (bag.card (mkBag x c)) = c where c is a constant > 0
+ * - (bag.card (bag x c)) = c where c is a constant > 0
* - (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
* - otherwise = n
*/
@@ -185,19 +185,19 @@ class BagsRewriter : public TheoryRewriter
/**
* rewrites for n include:
- * - (bag.is_singleton (mkBag x c)) = (c == 1)
+ * - (bag.is_singleton (bag x c)) = (c == 1)
*/
BagsRewriteResponse rewriteIsSingleton(const TNode& n) const;
/**
* rewrites for n include:
- * - (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
+ * - (bag.from_set (singleton (singleton_op Int) x)) = (bag x 1)
*/
BagsRewriteResponse rewriteFromSet(const TNode& n) const;
/**
* rewrites for n include:
- * - (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
+ * - (bag.to_set (bag x n)) = (singleton (singleton_op T) x)
* where n is a positive constant and T is the type of the bag's elements
*/
BagsRewriteResponse rewriteToSet(const TNode& n) const;
diff --git a/src/theory/bags/infer_info.cpp b/src/theory/bags/infer_info.cpp
index 1244a43ac..5b3274617 100644
--- a/src/theory/bags/infer_info.cpp
+++ b/src/theory/bags/infer_info.cpp
@@ -25,6 +25,8 @@ const char* toString(Inference i)
switch (i)
{
case Inference::NONE: return "NONE";
+ case Inference::BAG_NON_NEGATIVE_COUNT: return "BAG_NON_NEGATIVE_COUNT";
+ case Inference::BAG_MK_BAG_SAME_ELEMENT: return "BAG_MK_BAG_SAME_ELEMENT";
case Inference::BAG_MK_BAG: return "BAG_MK_BAG";
case Inference::BAG_EQUALITY: return "BAG_EQUALITY";
case Inference::BAG_DISEQUALITY: return "BAG_DISEQUALITY";
@@ -62,9 +64,19 @@ bool InferInfo::process(TheoryInferenceManager* im, bool asLemma)
if (asLemma)
{
TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
- return im->trustedLemma(trustedLemma);
+ im->trustedLemma(trustedLemma);
}
- Unimplemented();
+ else
+ {
+ Unimplemented();
+ }
+ for (const auto& pair : d_skolems)
+ {
+ Node n = pair.first.eqNode(pair.second);
+ TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr);
+ im->trustedLemma(trustedLemma);
+ }
+ return true;
}
bool InferInfo::isTrivial() const
@@ -87,21 +99,15 @@ bool InferInfo::isFact() const
return !atom.isConst() && atom.getKind() != kind::OR;
}
-Node InferInfo::getPremises() const
-{
- // d_noExplain is a subset of d_ant
- NodeManager* nm = NodeManager::currentNM();
- return nm->mkAnd(d_premises);
-}
-
std::ostream& operator<<(std::ostream& out, const InferInfo& ii)
{
- out << "(infer " << ii.d_id << " " << ii.d_conclusion << std::endl;
+ out << "(infer :id " << ii.d_id << std::endl;
+ out << ":conclusion " << ii.d_conclusion << std::endl;
if (!ii.d_premises.empty())
{
out << " :premise (" << ii.d_premises << ")" << std::endl;
}
-
+ out << ":skolems " << ii.d_skolems << std::endl;
out << ")";
return out;
}
diff --git a/src/theory/bags/infer_info.h b/src/theory/bags/infer_info.h
index 3edbef737..ecfc354d1 100644
--- a/src/theory/bags/infer_info.h
+++ b/src/theory/bags/infer_info.h
@@ -33,6 +33,8 @@ namespace bags {
enum class Inference : uint32_t
{
NONE,
+ BAG_NON_NEGATIVE_COUNT,
+ BAG_MK_BAG_SAME_ELEMENT,
BAG_MK_BAG,
BAG_EQUALITY,
BAG_DISEQUALITY,
@@ -81,7 +83,7 @@ class InferInfo : public TheoryInference
bool process(TheoryInferenceManager* im, bool asLemma) override;
/** The inference identifier */
Inference d_id;
- /** The conclusion */
+ /** The conclusions */
Node d_conclusion;
/**
* The premise(s) of the inference, interpreted conjunctively. These are
@@ -90,11 +92,10 @@ class InferInfo : public TheoryInference
std::vector<Node> d_premises;
/**
- * A list of new skolems introduced as a result of this inference. They
- * are mapped to by a length status, indicating the length constraint that
- * can be assumed for them.
+ * A map of nodes to their skolem variables introduced as a result of this
+ * inference.
*/
- std::vector<Node> d_newSkolem;
+ std::map<Node, Node> d_skolems;
/** Is this infer info trivial? True if d_conc is true. */
bool isTrivial() const;
/**
@@ -108,8 +109,6 @@ class InferInfo : public TheoryInference
* engine with no new external premises (d_noExplain).
*/
bool isFact() const;
- /** Get premises */
- Node getPremises() const;
};
/**
diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp
index 759ea1f0c..7ef126911 100644
--- a/src/theory/bags/inference_generator.cpp
+++ b/src/theory/bags/inference_generator.cpp
@@ -32,18 +32,33 @@ InferenceGenerator::InferenceGenerator(SolverState* state) : d_state(state)
d_one = d_nm->mkConst(Rational(1));
}
+InferInfo InferenceGenerator::nonNegativeCount(Node n, Node e)
+{
+ Assert(n.getType().isBag());
+ Assert(e.getType() == n.getType().getBagElementType());
+
+ InferInfo inferInfo;
+ inferInfo.d_id = Inference::BAG_NON_NEGATIVE_COUNT;
+ Node count = d_nm->mkNode(kind::BAG_COUNT, e, n);
+
+ Node gte = d_nm->mkNode(kind::GEQ, count, d_zero);
+ inferInfo.d_conclusion = gte;
+ return inferInfo;
+}
+
InferInfo InferenceGenerator::mkBag(Node n, Node e)
{
Assert(n.getKind() == kind::MK_BAG);
Assert(e.getType() == n.getType().getBagElementType());
InferInfo inferInfo;
- inferInfo.d_id = Inference::BAG_MK_BAG;
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
if (n[0] == e)
{
- // TODO: refactor this with the rewriter
+ // TODO issue #78: refactor this with BagRewriter
// (=> true (= (bag.count e (bag e c)) c))
+ inferInfo.d_id = Inference::BAG_MK_BAG_SAME_ELEMENT;
inferInfo.d_conclusion = count.eqNode(n[1]);
}
else
@@ -51,7 +66,7 @@ InferInfo InferenceGenerator::mkBag(Node n, Node e)
// (=>
// true
// (= (bag.count e (bag x c)) (ite (= e x) c 0)))
-
+ inferInfo.d_id = Inference::BAG_MK_BAG;
Node same = d_nm->mkNode(kind::EQUAL, n[0], e);
Node ite = d_nm->mkNode(kind::ITE, same, n[1], d_zero);
Node equal = count.eqNode(ite);
@@ -60,30 +75,12 @@ InferInfo InferenceGenerator::mkBag(Node n, Node e)
return inferInfo;
}
-InferInfo InferenceGenerator::bagEquality(Node n, Node e)
-{
- Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag());
- Assert(e.getType() == n[0].getType().getBagElementType());
-
- Node A = n[0];
- Node B = n[1];
- InferInfo inferInfo;
- inferInfo.d_id = Inference::BAG_EQUALITY;
- inferInfo.d_premises.push_back(n);
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
-
- Node equal = countA.eqNode(countB);
- inferInfo.d_conclusion = equal;
- return inferInfo;
-}
-
struct BagsDeqAttributeId
{
};
typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
-InferInfo InferenceGenerator::bagDisequality(Node n)
+InferInfo InferenceGenerator::bagDisequality(Node n, Node reason)
{
Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL);
Assert(n[0][0].getType().isBag());
@@ -93,22 +90,19 @@ InferInfo InferenceGenerator::bagDisequality(Node n)
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DISEQUALITY;
+ inferInfo.d_premises.push_back(reason);
TypeNode elementType = A.getType().getBagElementType();
-
BoundVarManager* bvm = d_nm->getBoundVarManager();
Node element = bvm->mkBoundVar<BagsDeqAttribute>(n, elementType);
- SkolemManager* sm = d_nm->getSkolemManager();
Node skolem =
- sm->mkSkolem(element,
- n,
- "bag_disequal",
- "an extensional lemma for disequality of two bags");
+ d_sm->mkSkolem(element,
+ n,
+ "bag_disequal",
+ "an extensional lemma for disequality of two bags");
- inferInfo.d_newSkolem.push_back(skolem);
-
- Node countA = getMultiplicitySkolem(skolem, A, inferInfo);
- Node countB = getMultiplicitySkolem(skolem, B, inferInfo);
+ Node countA = getMultiplicityTerm(skolem, A);
+ Node countB = getMultiplicityTerm(skolem, B);
Node disEqual = countA.eqNode(countB).notNode();
@@ -117,13 +111,20 @@ InferInfo InferenceGenerator::bagDisequality(Node n)
return inferInfo;
}
+Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo)
+{
+ Node skolem = d_sm->mkPurifySkolem(n, "skolem_bag", "skolem bag");
+ inferInfo.d_skolems[n] = skolem;
+ return skolem;
+}
+
InferInfo InferenceGenerator::bagEmpty(Node e)
{
EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType()));
Node empty = d_nm->mkConst(emptyBag);
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_EMPTY;
- Node count = getMultiplicitySkolem(e, empty, inferInfo);
+ Node count = getMultiplicityTerm(e, empty);
Node equal = count.eqNode(d_zero);
inferInfo.d_conclusion = equal;
@@ -140,9 +141,11 @@ InferInfo InferenceGenerator::unionDisjoint(Node n, Node e)
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_UNION_DISJOINT;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node countB = getMultiplicityTerm(e, B);
+
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node sum = d_nm->mkNode(kind::PLUS, countA, countB);
Node equal = count.eqNode(sum);
@@ -161,9 +164,11 @@ InferInfo InferenceGenerator::unionMax(Node n, Node e)
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_UNION_MAX;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node countB = getMultiplicityTerm(e, B);
+
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node gt = d_nm->mkNode(kind::GT, countA, countB);
Node max = d_nm->mkNode(kind::ITE, gt, countA, countB);
@@ -183,9 +188,10 @@ InferInfo InferenceGenerator::intersection(Node n, Node e)
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_INTERSECTION_MIN;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node countB = getMultiplicityTerm(e, B);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node lt = d_nm->mkNode(kind::LT, countA, countB);
Node min = d_nm->mkNode(kind::ITE, lt, countA, countB);
@@ -204,9 +210,10 @@ InferInfo InferenceGenerator::differenceSubtract(Node n, Node e)
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DIFFERENCE_SUBTRACT;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node countB = getMultiplicityTerm(e, B);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node subtract = d_nm->mkNode(kind::MINUS, countA, countB);
Node gte = d_nm->mkNode(kind::GEQ, countA, countB);
@@ -226,9 +233,11 @@ InferInfo InferenceGenerator::differenceRemove(Node n, Node e)
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DIFFERENCE_REMOVE;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node countB = getMultiplicitySkolem(e, B, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node countB = getMultiplicityTerm(e, B);
+
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node notInB = d_nm->mkNode(kind::EQUAL, countB, d_zero);
Node difference = d_nm->mkNode(kind::ITE, notInB, countA, d_zero);
@@ -246,8 +255,9 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e)
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DUPLICATE_REMOVAL;
- Node countA = getMultiplicitySkolem(e, A, inferInfo);
- Node count = getMultiplicitySkolem(e, n, inferInfo);
+ Node countA = getMultiplicityTerm(e, A);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
Node gte = d_nm->mkNode(kind::GEQ, countA, d_one);
Node ite = d_nm->mkNode(kind::ITE, gte, d_one, d_zero);
@@ -256,16 +266,10 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e)
return inferInfo;
}
-Node InferenceGenerator::getMultiplicitySkolem(Node element,
- Node bag,
- InferInfo& inferInfo)
+Node InferenceGenerator::getMultiplicityTerm(Node element, Node bag)
{
Node count = d_nm->mkNode(kind::BAG_COUNT, element, bag);
- Node skolem = d_state->registerBagElement(count);
- eq::EqualityEngine* ee = d_state->getEqualityEngine();
- ee->assertEquality(skolem.eqNode(count), true, d_nm->mkConst(true));
- inferInfo.d_newSkolem.push_back(skolem);
- return skolem;
+ return count;
}
} // namespace bags
diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h
index b56997088..9eee46e43 100644
--- a/src/theory/bags/inference_generator.h
+++ b/src/theory/bags/inference_generator.h
@@ -38,33 +38,38 @@ class InferenceGenerator
InferenceGenerator(SolverState* state);
/**
- * @param n is (bag x c) of type (Bag E)
+ * @param A is a bag of type (Bag E)
* @param e is a node of type E
* @return an inference that represents the following implication
* (=>
* true
- * (= (bag.count e (bag x c)) (ite (= e x) c 0)))
+ * (>= (bag.count e A) 0)
*/
- InferInfo mkBag(Node n, Node e);
+ InferInfo nonNegativeCount(Node n, Node e);
/**
- * @param n is (= A B) where A, B are bags of type (Bag E)
- * @param e is a node of Type E
+ * @param n is (bag x c) of type (Bag E)
+ * @param e is a node of type E
* @return an inference that represents the following implication
* (=>
- * (= A B)
- * (= (count e A) (count e B)))
+ * true
+ * (= (bag.count e skolem) c))
+ * if e is exactly node x. Node skolem is a fresh variable equals (bag x c).
+ * Otherwise the following inference is returned
+ * (=>
+ * true
+ * (= (bag.count e skolem) (ite (= e x) c 0)))
*/
- InferInfo bagEquality(Node n, Node e);
+ InferInfo mkBag(Node n, Node e);
/**
* @param n is (not (= A B)) where A, B are bags of type (Bag E)
* @return an inference that represents the following implication
* (=>
* (not (= A B))
* (not (= (count e A) (count e B))))
- * where e is a fresh skolem of type E
+ * where e is a fresh skolem of type E.
*/
- InferInfo bagDisequality(Node n);
+ InferInfo bagDisequality(Node n, Node reason);
/**
* @param e is a node of Type E
* @return an inference that represents the following implication
@@ -79,10 +84,9 @@ class InferenceGenerator
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e k_{(union_disjoint A B)})
+ * (= (count e skolem)
* (+ (count e A) (count e B))))
- * where k_{(union_disjoint A B)} is a unique purification skolem
- * for (union_disjoint A B)
+ * where skolem is a fresh variable equals (union_disjoint A B)
*/
InferInfo unionDisjoint(Node n, Node e);
/**
@@ -91,11 +95,13 @@ class InferenceGenerator
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e (union_max A B))
+ * (=
+ * (count e skolem)
* (ite
- * (> (count e A) (count e B))
- * (count e A)
- * (count e B)))))
+ * (> (count e A) (count e B))
+ * (count e A)
+ * (count e B)))))
+ * where skolem is a fresh variable equals (union_max A B)
*/
InferInfo unionMax(Node n, Node e);
/**
@@ -104,11 +110,13 @@ class InferenceGenerator
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e (intersection_min A B))
+ * (=
+ * (count e skolem)
* (ite(
- * (< (count e A) (count e B))
- * (count e A)
- * (count e B)))))
+ * (< (count e A) (count e B))
+ * (count e A)
+ * (count e B)))))
+ * where skolem is a fresh variable equals (intersection_min A B)
*/
InferInfo intersection(Node n, Node e);
/**
@@ -117,11 +125,13 @@ class InferenceGenerator
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e (difference_subtract A B))
+ * (=
+ * (count e skolem)
* (ite
- * (>= (count e A) (count e B))
- * (- (count e A) (count e B))
- * 0))))
+ * (>= (count e A) (count e B))
+ * (- (count e A) (count e B))
+ * 0))))
+ * where skolem is a fresh variable equals (difference_subtract A B)
*/
InferInfo differenceSubtract(Node n, Node e);
/**
@@ -130,11 +140,13 @@ class InferenceGenerator
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e (difference_remove A B))
+ * (=
+ * (count e skolem)
* (ite
- * (= (count e B) 0)
- * (count e A)
- * 0))))
+ * (= (count e B) 0)
+ * (count e A)
+ * 0))))
+ * where skolem is a fresh variable equals (difference_remove A B)
*/
InferInfo differenceRemove(Node n, Node e);
/**
@@ -143,20 +155,24 @@ class InferenceGenerator
* @return an inference that represents the following implication
* (=>
* true
- * (= (count e (duplicate_removal A))
- * (ite (>= (count e A) 1) 1 0))))
+ * (=
+ * (count e skolem)
+ * (ite (>= (count e A) 1) 1 0))))
+ * where skolem is a fresh variable equals (duplicate_removal A)
*/
InferInfo duplicateRemoval(Node n, Node e);
/**
* @param element of type T
* @param bag of type (bag T)
- * @param inferInfo to store new skolem
- * @return a skolem for (bag.count element bag)
+ * @return a count term (bag.count element bag)
*/
- Node getMultiplicitySkolem(Node element, Node bag, InferInfo& inferInfo);
+ Node getMultiplicityTerm(Node element, Node bag);
private:
+ /** generate skolem variable for node n and add it to inferInfo */
+ Node getSkolem(Node& n, InferInfo& inferInfo);
+
NodeManager* d_nm;
SkolemManager* d_sm;
SolverState* d_state;
diff --git a/src/theory/bags/inference_manager.h b/src/theory/bags/inference_manager.h
index 67025548c..71a014582 100644
--- a/src/theory/bags/inference_manager.h
+++ b/src/theory/bags/inference_manager.h
@@ -45,7 +45,7 @@ class InferenceManager : public InferenceManagerBuffered
* process the pending lemmas and then the pending phase requirements.
* Notice that we process the pending lemmas even if there were facts.
*/
- // TODO: refactor this before merge with theory of strings
+ // TODO issue #78: refactor this with theory of strings
void doPending();
private:
diff --git a/src/theory/bags/solver_state.cpp b/src/theory/bags/solver_state.cpp
index 744f6de9f..9bcb6ae3c 100644
--- a/src/theory/bags/solver_state.cpp
+++ b/src/theory/bags/solver_state.cpp
@@ -33,52 +33,89 @@ SolverState::SolverState(context::Context* c,
{
d_true = NodeManager::currentNM()->mkConst(true);
d_false = NodeManager::currentNM()->mkConst(false);
+ d_nm = NodeManager::currentNM();
}
-struct BagsCountAttributeId
+void SolverState::registerBag(TNode n)
{
-};
-typedef expr::Attribute<BagsCountAttributeId, Node> BagsCountAttribute;
-
-void SolverState::registerClass(TNode n)
-{
- TypeNode t = n.getType();
- if (!t.isBag())
- {
- return;
- }
+ Assert(n.getType().isBag());
d_bags.insert(n);
}
-Node SolverState::registerBagElement(TNode n)
+void SolverState::registerCountTerm(TNode n)
{
Assert(n.getKind() == BAG_COUNT);
- Node element = n[0];
- TypeNode elementType = element.getType();
- Node bag = n[1];
- d_elements[elementType].insert(element);
- NodeManager* nm = NodeManager::currentNM();
- BoundVarManager* bvm = nm->getBoundVarManager();
- Node multiplicity = bvm->mkBoundVar<BagsCountAttribute>(n, nm->integerType());
- Node equal = n.eqNode(multiplicity);
- SkolemManager* sm = nm->getSkolemManager();
- Node skolem = sm->mkSkolem(
- multiplicity,
- equal,
- "bag_multiplicity",
- "an extensional lemma for multiplicity of an element in a bag");
- d_count[bag][element] = skolem;
- Trace("bags::SolverState::registerBagElement")
- << "New skolem: " << skolem << " for " << n << std::endl;
-
- return skolem;
+ Node element = getRepresentative(n[0]);
+ Node bag = getRepresentative(n[1]);
+ d_bagElements[bag].insert(element);
+}
+
+const std::set<Node>& SolverState::getBags() { return d_bags; }
+
+const std::set<Node>& SolverState::getElements(Node B)
+{
+ Node bag = getRepresentative(B);
+ return d_bagElements[B];
+}
+
+void SolverState::reset()
+{
+ d_bagElements.clear();
+ d_bags.clear();
}
-std::set<Node>& SolverState::getBags() { return d_bags; }
+void SolverState::initialize()
+{
+ reset();
+ collectBagsAndCountTerms();
+}
+
+void SolverState::collectBagsAndCountTerms()
+{
+ Trace("SolverState::collectBagsAndCountTerms")
+ << "SolverState::collectBagsAndCountTerms start" << endl;
+ eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee);
+ while (!repIt.isFinished())
+ {
+ Node eqc = (*repIt);
+ Trace("SolverState::collectBagsAndCountTerms")
+ << "[" << eqc << "]: " << endl;
+
+ if (eqc.getType().isBag())
+ {
+ registerBag(eqc);
+ }
-std::set<Node>& SolverState::getElements(TypeNode t) { return d_elements[t]; }
+ eq::EqClassIterator it = eq::EqClassIterator(eqc, d_ee);
+ while (!it.isFinished())
+ {
+ Node n = (*it);
+ Kind k = n.getKind();
+ if (k == MK_BAG)
+ {
+ // for terms (bag x c) we need to store x by registering the count term
+ // (bag.count x (bag x c))
+ Node count = d_nm->mkNode(BAG_COUNT, n[0], n);
+ registerCountTerm(count);
+ Trace("SolverState::collectBagsAndCountTerms")
+ << "registered " << count << endl;
+ }
+ if (k == BAG_COUNT)
+ {
+ // this takes care of all count terms in each equivalent class
+ registerCountTerm(n);
+ Trace("SolverState::collectBagsAndCountTerms")
+ << "registered " << n << endl;
+ }
+ ++it;
+ }
-std::map<Node, Node>& SolverState::getBagElements(Node B) { return d_count[B]; }
+ ++repIt;
+ }
+
+ Trace("SolverState::collectBagsAndCountTerms")
+ << "SolverState::collectBagsAndCountTerms end" << endl;
+}
} // namespace bags
} // namespace theory
diff --git a/src/theory/bags/solver_state.h b/src/theory/bags/solver_state.h
index 8d70ee8f7..175317529 100644
--- a/src/theory/bags/solver_state.h
+++ b/src/theory/bags/solver_state.h
@@ -31,24 +31,52 @@ class SolverState : public TheoryState
public:
SolverState(context::Context* c, context::UserContext* u, Valuation val);
- void registerClass(TNode n);
+ /**
+ * This function adds the bag representative n to the set d_bags if it is not
+ * already there. This function is called during postCheck to collect bag
+ * terms in the equality engine. See the documentation of
+ * @link SolverState::collectBagsAndCountTerms
+ */
+ void registerBag(TNode n);
- Node registerBagElement(TNode n);
-
- std::set<Node>& getBags();
-
- std::set<Node>& getElements(TypeNode t);
-
- std::map<Node, Node>& getBagElements(Node B);
+ /**
+ * @param n has the form (bag.count e A)
+ * @pre bag A needs is already registered using registerBag(A)
+ * @return a unique skolem for (bag.count e A)
+ */
+ void registerCountTerm(TNode n);
+ /** get all bag terms that are representatives in the equality engine.
+ * This function is valid after the current solver is initialized during
+ * postCheck. See SolverState::initialize and BagSolver::postCheck
+ */
+ const std::set<Node>& getBags();
+ /**
+ * @pre B is a registered bag
+ * @return all elements associated with bag B so far
+ * Note that associated elements are not necessarily elements in B
+ * Example:
+ * (assert (= 0 (bag.count x B)))
+ * element x is associated with bag B, albeit x is definitely not in B.
+ */
+ const std::set<Node>& getElements(Node B);
+ /** initialize bag and count terms */
+ void initialize();
private:
+ /** clear all bags data structures */
+ void reset();
+ /** collect bags' representatives and all count terms.
+ * This function is called during postCheck */
+ void collectBagsAndCountTerms();
/** constants */
Node d_true;
Node d_false;
+ /** node manager for this solver state */
+ NodeManager* d_nm;
+ /** collection of bag representatives */
std::set<Node> d_bags;
- std::map<TypeNode, std::set<Node>> d_elements;
- /** bag -> element -> multiplicity */
- std::map<Node, std::map<Node, Node>> d_count;
+ /** bag -> associated elements */
+ std::map<Node, std::set<Node>> d_bagElements;
}; /* class SolverState */
} // namespace bags
diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp
index 21a9d0e53..153e9017d 100644
--- a/src/theory/bags/theory_bags.cpp
+++ b/src/theory/bags/theory_bags.cpp
@@ -78,22 +78,22 @@ void TheoryBags::finishInit()
void TheoryBags::postCheck(Effort effort)
{
d_im.doPendingFacts();
- // TODO: clean this before merge Assert(d_strat.isStrategyInit());
+ // TODO issue #78: add Assert(d_strat.isStrategyInit());
if (!d_state.isInConflict() && !d_valuation.needCheck())
- // TODO: clean this before merge && d_strat.hasStrategyEffort(e))
+ // TODO issue #78: add && d_strat.hasStrategyEffort(e))
{
Trace("bags::TheoryBags::postCheck") << "effort: " << std::endl;
- // TODO: clean this before merge ++(d_statistics.d_checkRuns);
+ // TODO issue #78: add ++(d_statistics.d_checkRuns);
bool sentLemma = false;
bool hadPending = false;
Trace("bags-check") << "Full effort check..." << std::endl;
do
{
d_im.reset();
- // TODO: clean this before merge ++(d_statistics.d_strategyRuns);
+ // TODO issue #78: add ++(d_statistics.d_strategyRuns);
Trace("bags-check") << " * Run strategy..." << std::endl;
- // TODO: clean this before merge runStrategy(e);
+ // TODO issue #78: add runStrategy(e);
d_solver.postCheck();
@@ -153,14 +153,22 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
continue;
}
Node r = d_state.getRepresentative(n);
- std::map<Node, Node> elements = d_state.getBagElements(r);
+ std::set<Node> solverElements = d_state.getElements(r);
+ std::set<Node> elements;
+ // only consider terms in termSet and ignore other elements in the solver
+ std::set_intersection(termSet.begin(),
+ termSet.end(),
+ solverElements.begin(),
+ solverElements.end(),
+ std::inserter(elements, elements.begin()));
Trace("bags-model") << "Elements of bag " << n << " are: " << std::endl
<< elements << std::endl;
std::map<Node, Node> elementReps;
- for (std::pair<Node, Node> pair : elements)
+ for (const Node& e : elements)
{
- Node key = d_state.getRepresentative(pair.first);
- Node value = d_state.getRepresentative(pair.second);
+ Node key = d_state.getRepresentative(e);
+ Node countTerm = NodeManager::currentNM()->mkNode(BAG_COUNT, e, r);
+ Node value = d_state.getRepresentative(countTerm);
elementReps[key] = value;
}
Node rep = NormalForm::constructBagFromElements(tn, elementReps);
@@ -211,38 +219,7 @@ void TheoryBags::presolve() {}
/**************************** eq::NotifyClass *****************************/
-void TheoryBags::eqNotifyNewClass(TNode n)
-{
- Kind k = n.getKind();
- d_state.registerClass(n);
- if (n.getKind() == MK_BAG)
- {
- // TODO: refactor this before merge
- /*
- * (bag x m) generates the lemma (and (= s (count x (bag x m))) (= s m))
- * where s is a fresh skolem variable
- */
- NodeManager* nm = NodeManager::currentNM();
- Node count = nm->mkNode(BAG_COUNT, n[0], n);
- Node skolem = d_state.registerBagElement(count);
- Node countSkolem = count.eqNode(skolem);
- Node skolemMultiplicity = n[1].eqNode(skolem);
- Node lemma = countSkolem.andNode(skolemMultiplicity);
- TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
- d_im.trustedLemma(trustedLemma);
- }
- if (k == BAG_COUNT)
- {
- /*
- * (count x A) generates the lemma (= s (count x A))
- * where s is a fresh skolem variable
- */
- Node skolem = d_state.registerBagElement(n);
- Node lemma = n.eqNode(skolem);
- TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
- d_im.trustedLemma(trustedLemma);
- }
-}
+void TheoryBags::eqNotifyNewClass(TNode n) {}
void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {}
@@ -251,10 +228,8 @@ void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason)
TypeNode t1 = n1.getType();
if (t1.isBag())
{
- InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode());
- Node lemma = reason.impNode(info.d_conclusion);
- TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
- d_im.trustedLemma(trustedLemma);
+ InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode(), reason);
+ info.process(d_inferManager, true);
}
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback