summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormudathirmahgoub <mudathirmahgoub@gmail.com>2021-01-29 15:44:28 -0600
committerGitHub <noreply@github.com>2021-01-29 15:44:28 -0600
commit50c3dee5c8a4855023df826e1a733ea3c6076774 (patch)
tree3dd574a7e6153c9225c6abcde9085d06de057f6f
parentce1b2f2fb06150599c231bf0d59b52a07e74c3f5 (diff)
Add bag inferences for operators: intersection, duplicate_removal, and empty bags (#5832)
This PR adds inferences for operators: intersection, duplicate_removal, and empty bags during post check. It also fixes a bug in SolverState::getElements
-rw-r--r--src/theory/bags/bag_solver.cpp78
-rw-r--r--src/theory/bags/bag_solver.h11
-rw-r--r--src/theory/bags/infer_info.cpp3
-rw-r--r--src/theory/bags/inference_generator.cpp22
-rw-r--r--src/theory/bags/inference_generator.h11
-rw-r--r--src/theory/bags/solver_state.cpp33
-rw-r--r--src/theory/bags/solver_state.h14
-rw-r--r--src/theory/bags/theory_bags.cpp11
-rw-r--r--test/regress/CMakeLists.txt5
-rw-r--r--test/regress/regress1/bags/duplicate_removal1.smt28
-rw-r--r--test/regress/regress1/bags/duplicate_removal2.smt28
-rw-r--r--test/regress/regress1/bags/emptybag1.smt210
-rw-r--r--test/regress/regress1/bags/intersection_min1.smt210
-rw-r--r--test/regress/regress1/bags/intersection_min2.smt29
14 files changed, 179 insertions, 54 deletions
diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp
index 495f73723..bdd4a9b30 100644
--- a/src/theory/bags/bag_solver.cpp
+++ b/src/theory/bags/bag_solver.cpp
@@ -27,7 +27,7 @@ namespace theory {
namespace bags {
BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
- : d_state(s), d_im(im), d_termReg(tr)
+ : d_state(s), d_ig(&d_state), d_im(im), d_termReg(tr)
{
d_zero = NodeManager::currentNM()->mkConst(Rational(0));
d_one = NodeManager::currentNM()->mkConst(Rational(1));
@@ -41,6 +41,8 @@ void BagSolver::postCheck()
{
d_state.initialize();
+ checkDisequalBagTerms();
+
// At this point, all bag and count representatives should be in the solver
// state.
for (const Node& bag : d_state.getBags())
@@ -54,11 +56,14 @@ void BagSolver::postCheck()
Kind k = n.getKind();
switch (k)
{
+ case kind::EMPTYBAG: checkEmpty(n); break;
case kind::MK_BAG: checkMkBag(n); break;
case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
case kind::UNION_MAX: checkUnionMax(n); break;
+ case kind::INTERSECTION_MIN: checkIntersectionMin(n); break;
case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
+ case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
default: break;
}
it++;
@@ -91,16 +96,24 @@ set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
return elements;
}
+void BagSolver::checkEmpty(const Node& n)
+{
+ Assert(n.getKind() == EMPTYBAG);
+ for (const Node& e : d_state.getElements(n))
+ {
+ InferInfo i = d_ig.empty(n, e);
+ i.process(&d_im, true);
+ }
+}
+
void BagSolver::checkUnionDisjoint(const Node& n)
{
Assert(n.getKind() == UNION_DISJOINT);
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.unionDisjoint(n, e);
+ InferInfo i = d_ig.unionDisjoint(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
@@ -110,10 +123,19 @@ void BagSolver::checkUnionMax(const Node& n)
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.unionMax(n, e);
+ InferInfo i = d_ig.unionMax(n, e);
+ i.process(&d_im, true);
+ }
+}
+
+void BagSolver::checkIntersectionMin(const Node& n)
+{
+ Assert(n.getKind() == INTERSECTION_MIN);
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
+ {
+ InferInfo i = d_ig.intersection(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
@@ -123,10 +145,8 @@ void BagSolver::checkDifferenceSubtract(const Node& n)
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.differenceSubtract(n, e);
+ InferInfo i = d_ig.differenceSubtract(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
@@ -138,18 +158,14 @@ void BagSolver::checkMkBag(const Node& n)
<< " are: " << d_state.getElements(n) << std::endl;
for (const Node& e : d_state.getElements(n))
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.mkBag(n, e);
+ InferInfo i = d_ig.mkBag(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.nonNegativeCount(bag, element);
+ InferInfo i = d_ig.nonNegativeCount(bag, element);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
void BagSolver::checkDifferenceRemove(const Node& n)
@@ -158,10 +174,34 @@ void BagSolver::checkDifferenceRemove(const Node& n)
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.differenceRemove(n, e);
+ InferInfo i = d_ig.differenceRemove(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
+ }
+}
+
+void BagSolver::checkDuplicateRemoval(Node n)
+{
+ Assert(n.getKind() == DUPLICATE_REMOVAL);
+ set<Node> elements;
+ const set<Node>& downwards = d_state.getElements(n);
+ const set<Node>& upwards = d_state.getElements(n[0]);
+
+ elements.insert(downwards.begin(), downwards.end());
+ elements.insert(upwards.begin(), upwards.end());
+
+ for (const Node& e : elements)
+ {
+ InferInfo i = d_ig.duplicateRemoval(n, e);
+ i.process(&d_im, true);
+ }
+}
+
+void BagSolver::checkDisequalBagTerms()
+{
+ for (const Node& n : d_state.getDisequalBagTerms())
+ {
+ InferInfo info = d_ig.bagDisequality(n);
+ info.process(&d_im, true);
}
}
diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h
index b4b18c00c..b19e1f11e 100644
--- a/src/theory/bags/bag_solver.h
+++ b/src/theory/bags/bag_solver.h
@@ -20,6 +20,7 @@
#include "context/cdhashset.h"
#include "context/cdlist.h"
#include "theory/bags/infer_info.h"
+#include "theory/bags/inference_generator.h"
#include "theory/bags/inference_manager.h"
#include "theory/bags/normal_form.h"
#include "theory/bags/solver_state.h"
@@ -41,6 +42,8 @@ class BagSolver
void postCheck();
private:
+ /** apply inference rules for empty bags */
+ void checkEmpty(const Node& n);
/**
* apply inference rules for MK_BAG operator.
* Example: Suppose n = (bag x c), and we have two count terms (bag.count x n)
@@ -60,15 +63,23 @@ class BagSolver
void checkUnionDisjoint(const Node& n);
/** apply inference rules for union max */
void checkUnionMax(const Node& n);
+ /** apply inference rules for intersection_min operator */
+ void checkIntersectionMin(const Node& n);
/** apply inference rules for difference subtract */
void checkDifferenceSubtract(const Node& n);
/** apply inference rules for difference remove */
void checkDifferenceRemove(const Node& n);
+ /** apply inference rules for duplicate removal operator */
+ void checkDuplicateRemoval(Node n);
/** apply non negative constraints for multiplicities */
void checkNonNegativeCountTerms(const Node& bag, const Node& element);
+ /** apply inference rules for disequal bag terms */
+ void checkDisequalBagTerms();
/** The solver state object */
SolverState& d_state;
+ /** The inference generator object*/
+ InferenceGenerator d_ig;
/** Reference to the inference manager for the theory of bags */
InferenceManager& d_im;
/** Reference to the term registry of theory of bags */
diff --git a/src/theory/bags/infer_info.cpp b/src/theory/bags/infer_info.cpp
index 5b3274617..9bf546af1 100644
--- a/src/theory/bags/infer_info.cpp
+++ b/src/theory/bags/infer_info.cpp
@@ -76,6 +76,9 @@ bool InferInfo::process(TheoryInferenceManager* im, bool asLemma)
TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr);
im->trustedLemma(trustedLemma);
}
+
+ Trace("bags::InferInfo::process") << (*this) << std::endl;
+
return true;
}
diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp
index 7ef126911..708c25f34 100644
--- a/src/theory/bags/inference_generator.cpp
+++ b/src/theory/bags/inference_generator.cpp
@@ -80,17 +80,15 @@ struct BagsDeqAttributeId
};
typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
-InferInfo InferenceGenerator::bagDisequality(Node n, Node reason)
+InferInfo InferenceGenerator::bagDisequality(Node n)
{
- Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL);
- Assert(n[0][0].getType().isBag());
+ Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag());
- Node A = n[0][0];
- Node B = n[0][1];
+ Node A = n[0];
+ Node B = n[1];
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DISEQUALITY;
- inferInfo.d_premises.push_back(reason);
TypeNode elementType = A.getType().getBagElementType();
BoundVarManager* bvm = d_nm->getBoundVarManager();
@@ -106,7 +104,7 @@ InferInfo InferenceGenerator::bagDisequality(Node n, Node reason)
Node disEqual = countA.eqNode(countB).notNode();
- inferInfo.d_premises.push_back(n);
+ inferInfo.d_premises.push_back(n.notNode());
inferInfo.d_conclusion = disEqual;
return inferInfo;
}
@@ -118,13 +116,15 @@ Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo)
return skolem;
}
-InferInfo InferenceGenerator::bagEmpty(Node e)
+InferInfo InferenceGenerator::empty(Node n, Node e)
{
- EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType()));
- Node empty = d_nm->mkConst(emptyBag);
+ Assert(n.getKind() == kind::EMPTYBAG);
+ Assert(e.getType() == n.getType().getBagElementType());
+
InferInfo inferInfo;
+ Node skolem = getSkolem(n, inferInfo);
inferInfo.d_id = Inference::BAG_EMPTY;
- Node count = getMultiplicityTerm(e, empty);
+ Node count = getMultiplicityTerm(e, skolem);
Node equal = count.eqNode(d_zero);
inferInfo.d_conclusion = equal;
diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h
index 9eee46e43..4a852530a 100644
--- a/src/theory/bags/inference_generator.h
+++ b/src/theory/bags/inference_generator.h
@@ -62,22 +62,25 @@ class InferenceGenerator
*/
InferInfo mkBag(Node n, Node e);
/**
- * @param n is (not (= A B)) where A, B are bags of type (Bag E)
+ * @param n is (= A B) where A, B are bags of type (Bag E), and
+ * (not (= A B)) is an assertion in the equality engine
* @return an inference that represents the following implication
* (=>
* (not (= A B))
* (not (= (count e A) (count e B))))
* where e is a fresh skolem of type E.
*/
- InferInfo bagDisequality(Node n, Node reason);
+ InferInfo bagDisequality(Node n);
/**
+ * @param n is (as emptybag (Bag E))
* @param e is a node of Type E
* @return an inference that represents the following implication
* (=>
* true
- * (= 0 (count e (as emptybag (Bag E)))))
+ * (= 0 (count e skolem)))
+ * where skolem = (as emptybag (Bag String))
*/
- InferInfo bagEmpty(Node e);
+ InferInfo empty(Node n, Node e);
/**
* @param n is (union_disjoint A B) where A, B are bags of type (Bag E)
* @param e is a node of Type E
diff --git a/src/theory/bags/solver_state.cpp b/src/theory/bags/solver_state.cpp
index 9bcb6ae3c..adca85068 100644
--- a/src/theory/bags/solver_state.cpp
+++ b/src/theory/bags/solver_state.cpp
@@ -55,31 +55,32 @@ const std::set<Node>& SolverState::getBags() { return d_bags; }
const std::set<Node>& SolverState::getElements(Node B)
{
Node bag = getRepresentative(B);
- return d_bagElements[B];
+ return d_bagElements[bag];
}
+const std::set<Node>& SolverState::getDisequalBagTerms() { return d_deq; }
+
void SolverState::reset()
{
d_bagElements.clear();
d_bags.clear();
+ d_deq.clear();
}
void SolverState::initialize()
{
reset();
collectBagsAndCountTerms();
+ collectDisequalBagTerms();
}
void SolverState::collectBagsAndCountTerms()
{
- Trace("SolverState::collectBagsAndCountTerms")
- << "SolverState::collectBagsAndCountTerms start" << endl;
eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee);
while (!repIt.isFinished())
{
Node eqc = (*repIt);
- Trace("SolverState::collectBagsAndCountTerms")
- << "[" << eqc << "]: " << endl;
+ Trace("bags-eqc") << "Eqc [ " << eqc << " ] = { ";
if (eqc.getType().isBag())
{
@@ -90,6 +91,7 @@ void SolverState::collectBagsAndCountTerms()
while (!it.isFinished())
{
Node n = (*it);
+ Trace("bags-eqc") << (*it) << " ";
Kind k = n.getKind();
if (k == MK_BAG)
{
@@ -109,12 +111,27 @@ void SolverState::collectBagsAndCountTerms()
}
++it;
}
-
+ Trace("bags-eqc") << " } " << std::endl;
++repIt;
}
- Trace("SolverState::collectBagsAndCountTerms")
- << "SolverState::collectBagsAndCountTerms end" << endl;
+ Trace("bags-eqc") << "bag representatives: " << d_bags << endl;
+ Trace("bags-eqc") << "bag elements: " << d_bagElements << endl;
+}
+
+void SolverState::collectDisequalBagTerms()
+{
+ eq::EqClassIterator it = eq::EqClassIterator(d_false, d_ee);
+ while (!it.isFinished())
+ {
+ Node n = (*it);
+ if (n.getKind() == EQUAL && n[0].getType().isBag())
+ {
+ Trace("bags-eqc") << "Disequal terms: " << n << std::endl;
+ d_deq.insert(n);
+ }
+ ++it;
+ }
}
} // namespace bags
diff --git a/src/theory/bags/solver_state.h b/src/theory/bags/solver_state.h
index 175317529..7670e5dec 100644
--- a/src/theory/bags/solver_state.h
+++ b/src/theory/bags/solver_state.h
@@ -61,13 +61,21 @@ class SolverState : public TheoryState
const std::set<Node>& getElements(Node B);
/** initialize bag and count terms */
void initialize();
+ /** return disequal bag terms */
+ const std::set<Node>& getDisequalBagTerms();
private:
/** clear all bags data structures */
void reset();
- /** collect bags' representatives and all count terms.
- * This function is called during postCheck */
+ /**
+ * collect bags' representatives and all count terms.
+ * This function is called during postCheck
+ */
void collectBagsAndCountTerms();
+ /**
+ * collect disequal bag terms. This function is called during postCheck.
+ */
+ void collectDisequalBagTerms();
/** constants */
Node d_true;
Node d_false;
@@ -77,6 +85,8 @@ class SolverState : public TheoryState
std::set<Node> d_bags;
/** bag -> associated elements */
std::map<Node, std::set<Node>> d_bagElements;
+ /** Disequal bag terms */
+ std::set<Node> d_deq;
}; /* class SolverState */
} // namespace bags
diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp
index 153e9017d..15e8e00e7 100644
--- a/src/theory/bags/theory_bags.cpp
+++ b/src/theory/bags/theory_bags.cpp
@@ -199,7 +199,6 @@ void TheoryBags::preRegisterTerm(TNode n)
case BAG_FROM_SET:
case BAG_TO_SET:
case BAG_IS_SINGLETON:
- case DUPLICATE_REMOVAL:
{
std::stringstream ss;
ss << "Term of kind " << n.getKind() << " is not supported yet";
@@ -223,15 +222,7 @@ void TheoryBags::eqNotifyNewClass(TNode n) {}
void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {}
-void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason)
-{
- TypeNode t1 = n1.getType();
- if (t1.isBag())
- {
- InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode(), reason);
- info.process(d_inferManager, true);
- }
-}
+void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason) {}
void TheoryBags::NotifyClass::eqNotifyNewClass(TNode n)
{
diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt
index 94be987f7..128f9c567 100644
--- a/test/regress/CMakeLists.txt
+++ b/test/regress/CMakeLists.txt
@@ -1430,6 +1430,11 @@ set(regress_1_tests
regress1/bug800.smt2
regress1/bags/difference_remove1.smt2
regress1/bags/disequality.smt2
+ regress1/bags/duplicate_removal1.smt2
+ regress1/bags/duplicate_removal2.smt2
+ regress1/bags/emptybag1.smt2
+ regress1/bags/intersection_min1.smt2
+ regress1/bags/intersection_min2.smt2
regress1/bags/issue5759.smt2
regress1/bags/subbag1.smt2
regress1/bags/subbag2.smt2
diff --git a/test/regress/regress1/bags/duplicate_removal1.smt2 b/test/regress/regress1/bags/duplicate_removal1.smt2
new file mode 100644
index 000000000..2b662c6e5
--- /dev/null
+++ b/test/regress/regress1/bags/duplicate_removal1.smt2
@@ -0,0 +1,8 @@
+(set-logic ALL)
+(set-info :status sat)
+(set-option :produce-models true)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(assert (= B (duplicate_removal A)))
+(assert (distinct (as emptybag (Bag String)) A B))
+(check-sat)
diff --git a/test/regress/regress1/bags/duplicate_removal2.smt2 b/test/regress/regress1/bags/duplicate_removal2.smt2
new file mode 100644
index 000000000..7dacaae43
--- /dev/null
+++ b/test/regress/regress1/bags/duplicate_removal2.smt2
@@ -0,0 +1,8 @@
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(assert (= B (duplicate_removal A)))
+(assert (distinct (as emptybag (Bag String)) A B))
+(assert (= B (union_max A B)))
+(check-sat) \ No newline at end of file
diff --git a/test/regress/regress1/bags/emptybag1.smt2 b/test/regress/regress1/bags/emptybag1.smt2
new file mode 100644
index 000000000..f7f92599d
--- /dev/null
+++ b/test/regress/regress1/bags/emptybag1.smt2
@@ -0,0 +1,10 @@
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag String))
+(declare-fun x () String)
+(declare-fun y () Int)
+(assert (= x "x"))
+(assert (= A (as emptybag (Bag String))))
+(assert (= (bag.count x A) y))
+(assert(> y 1))
+(check-sat)
diff --git a/test/regress/regress1/bags/intersection_min1.smt2 b/test/regress/regress1/bags/intersection_min1.smt2
new file mode 100644
index 000000000..f5a515b9c
--- /dev/null
+++ b/test/regress/regress1/bags/intersection_min1.smt2
@@ -0,0 +1,10 @@
+(set-logic ALL)
+(set-info :status sat)
+(set-option :produce-models true)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(declare-fun C () (Bag String))
+(assert (= C (intersection_min A B)))
+(assert (distinct (as emptybag (Bag String)) C))
+(assert (distinct A B C))
+(check-sat) \ No newline at end of file
diff --git a/test/regress/regress1/bags/intersection_min2.smt2 b/test/regress/regress1/bags/intersection_min2.smt2
new file mode 100644
index 000000000..66afa2f3a
--- /dev/null
+++ b/test/regress/regress1/bags/intersection_min2.smt2
@@ -0,0 +1,9 @@
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(declare-fun C () (Bag String))
+(assert (= C (intersection_min A B)))
+(assert (= C (union_disjoint A B)))
+(assert (distinct (as emptybag (Bag String)) C))
+(check-sat)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback