summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/theory/bags/bag_solver.cpp78
-rw-r--r--src/theory/bags/bag_solver.h11
-rw-r--r--src/theory/bags/infer_info.cpp3
-rw-r--r--src/theory/bags/inference_generator.cpp22
-rw-r--r--src/theory/bags/inference_generator.h11
-rw-r--r--src/theory/bags/solver_state.cpp33
-rw-r--r--src/theory/bags/solver_state.h14
-rw-r--r--src/theory/bags/theory_bags.cpp11
8 files changed, 129 insertions, 54 deletions
diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp
index 495f73723..bdd4a9b30 100644
--- a/src/theory/bags/bag_solver.cpp
+++ b/src/theory/bags/bag_solver.cpp
@@ -27,7 +27,7 @@ namespace theory {
namespace bags {
BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
- : d_state(s), d_im(im), d_termReg(tr)
+ : d_state(s), d_ig(&d_state), d_im(im), d_termReg(tr)
{
d_zero = NodeManager::currentNM()->mkConst(Rational(0));
d_one = NodeManager::currentNM()->mkConst(Rational(1));
@@ -41,6 +41,8 @@ void BagSolver::postCheck()
{
d_state.initialize();
+ checkDisequalBagTerms();
+
// At this point, all bag and count representatives should be in the solver
// state.
for (const Node& bag : d_state.getBags())
@@ -54,11 +56,14 @@ void BagSolver::postCheck()
Kind k = n.getKind();
switch (k)
{
+ case kind::EMPTYBAG: checkEmpty(n); break;
case kind::MK_BAG: checkMkBag(n); break;
case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
case kind::UNION_MAX: checkUnionMax(n); break;
+ case kind::INTERSECTION_MIN: checkIntersectionMin(n); break;
case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
+ case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
default: break;
}
it++;
@@ -91,16 +96,24 @@ set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
return elements;
}
+void BagSolver::checkEmpty(const Node& n)
+{
+ Assert(n.getKind() == EMPTYBAG);
+ for (const Node& e : d_state.getElements(n))
+ {
+ InferInfo i = d_ig.empty(n, e);
+ i.process(&d_im, true);
+ }
+}
+
void BagSolver::checkUnionDisjoint(const Node& n)
{
Assert(n.getKind() == UNION_DISJOINT);
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.unionDisjoint(n, e);
+ InferInfo i = d_ig.unionDisjoint(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
@@ -110,10 +123,19 @@ void BagSolver::checkUnionMax(const Node& n)
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.unionMax(n, e);
+ InferInfo i = d_ig.unionMax(n, e);
+ i.process(&d_im, true);
+ }
+}
+
+void BagSolver::checkIntersectionMin(const Node& n)
+{
+ Assert(n.getKind() == INTERSECTION_MIN);
+ std::set<Node> elements = getElementsForBinaryOperator(n);
+ for (const Node& e : elements)
+ {
+ InferInfo i = d_ig.intersection(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
@@ -123,10 +145,8 @@ void BagSolver::checkDifferenceSubtract(const Node& n)
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.differenceSubtract(n, e);
+ InferInfo i = d_ig.differenceSubtract(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
@@ -138,18 +158,14 @@ void BagSolver::checkMkBag(const Node& n)
<< " are: " << d_state.getElements(n) << std::endl;
for (const Node& e : d_state.getElements(n))
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.mkBag(n, e);
+ InferInfo i = d_ig.mkBag(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
}
void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.nonNegativeCount(bag, element);
+ InferInfo i = d_ig.nonNegativeCount(bag, element);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
}
void BagSolver::checkDifferenceRemove(const Node& n)
@@ -158,10 +174,34 @@ void BagSolver::checkDifferenceRemove(const Node& n)
std::set<Node> elements = getElementsForBinaryOperator(n);
for (const Node& e : elements)
{
- InferenceGenerator ig(&d_state);
- InferInfo i = ig.differenceRemove(n, e);
+ InferInfo i = d_ig.differenceRemove(n, e);
i.process(&d_im, true);
- Trace("bags::BagSolver::postCheck") << i << endl;
+ }
+}
+
+void BagSolver::checkDuplicateRemoval(Node n)
+{
+ Assert(n.getKind() == DUPLICATE_REMOVAL);
+ set<Node> elements;
+ const set<Node>& downwards = d_state.getElements(n);
+ const set<Node>& upwards = d_state.getElements(n[0]);
+
+ elements.insert(downwards.begin(), downwards.end());
+ elements.insert(upwards.begin(), upwards.end());
+
+ for (const Node& e : elements)
+ {
+ InferInfo i = d_ig.duplicateRemoval(n, e);
+ i.process(&d_im, true);
+ }
+}
+
+void BagSolver::checkDisequalBagTerms()
+{
+ for (const Node& n : d_state.getDisequalBagTerms())
+ {
+ InferInfo info = d_ig.bagDisequality(n);
+ info.process(&d_im, true);
}
}
diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h
index b4b18c00c..b19e1f11e 100644
--- a/src/theory/bags/bag_solver.h
+++ b/src/theory/bags/bag_solver.h
@@ -20,6 +20,7 @@
#include "context/cdhashset.h"
#include "context/cdlist.h"
#include "theory/bags/infer_info.h"
+#include "theory/bags/inference_generator.h"
#include "theory/bags/inference_manager.h"
#include "theory/bags/normal_form.h"
#include "theory/bags/solver_state.h"
@@ -41,6 +42,8 @@ class BagSolver
void postCheck();
private:
+ /** apply inference rules for empty bags */
+ void checkEmpty(const Node& n);
/**
* apply inference rules for MK_BAG operator.
* Example: Suppose n = (bag x c), and we have two count terms (bag.count x n)
@@ -60,15 +63,23 @@ class BagSolver
void checkUnionDisjoint(const Node& n);
/** apply inference rules for union max */
void checkUnionMax(const Node& n);
+ /** apply inference rules for intersection_min operator */
+ void checkIntersectionMin(const Node& n);
/** apply inference rules for difference subtract */
void checkDifferenceSubtract(const Node& n);
/** apply inference rules for difference remove */
void checkDifferenceRemove(const Node& n);
+ /** apply inference rules for duplicate removal operator */
+ void checkDuplicateRemoval(Node n);
/** apply non negative constraints for multiplicities */
void checkNonNegativeCountTerms(const Node& bag, const Node& element);
+ /** apply inference rules for disequal bag terms */
+ void checkDisequalBagTerms();
/** The solver state object */
SolverState& d_state;
+ /** The inference generator object*/
+ InferenceGenerator d_ig;
/** Reference to the inference manager for the theory of bags */
InferenceManager& d_im;
/** Reference to the term registry of theory of bags */
diff --git a/src/theory/bags/infer_info.cpp b/src/theory/bags/infer_info.cpp
index 5b3274617..9bf546af1 100644
--- a/src/theory/bags/infer_info.cpp
+++ b/src/theory/bags/infer_info.cpp
@@ -76,6 +76,9 @@ bool InferInfo::process(TheoryInferenceManager* im, bool asLemma)
TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr);
im->trustedLemma(trustedLemma);
}
+
+ Trace("bags::InferInfo::process") << (*this) << std::endl;
+
return true;
}
diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp
index 7ef126911..708c25f34 100644
--- a/src/theory/bags/inference_generator.cpp
+++ b/src/theory/bags/inference_generator.cpp
@@ -80,17 +80,15 @@ struct BagsDeqAttributeId
};
typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
-InferInfo InferenceGenerator::bagDisequality(Node n, Node reason)
+InferInfo InferenceGenerator::bagDisequality(Node n)
{
- Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL);
- Assert(n[0][0].getType().isBag());
+ Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag());
- Node A = n[0][0];
- Node B = n[0][1];
+ Node A = n[0];
+ Node B = n[1];
InferInfo inferInfo;
inferInfo.d_id = Inference::BAG_DISEQUALITY;
- inferInfo.d_premises.push_back(reason);
TypeNode elementType = A.getType().getBagElementType();
BoundVarManager* bvm = d_nm->getBoundVarManager();
@@ -106,7 +104,7 @@ InferInfo InferenceGenerator::bagDisequality(Node n, Node reason)
Node disEqual = countA.eqNode(countB).notNode();
- inferInfo.d_premises.push_back(n);
+ inferInfo.d_premises.push_back(n.notNode());
inferInfo.d_conclusion = disEqual;
return inferInfo;
}
@@ -118,13 +116,15 @@ Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo)
return skolem;
}
-InferInfo InferenceGenerator::bagEmpty(Node e)
+InferInfo InferenceGenerator::empty(Node n, Node e)
{
- EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType()));
- Node empty = d_nm->mkConst(emptyBag);
+ Assert(n.getKind() == kind::EMPTYBAG);
+ Assert(e.getType() == n.getType().getBagElementType());
+
InferInfo inferInfo;
+ Node skolem = getSkolem(n, inferInfo);
inferInfo.d_id = Inference::BAG_EMPTY;
- Node count = getMultiplicityTerm(e, empty);
+ Node count = getMultiplicityTerm(e, skolem);
Node equal = count.eqNode(d_zero);
inferInfo.d_conclusion = equal;
diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h
index 9eee46e43..4a852530a 100644
--- a/src/theory/bags/inference_generator.h
+++ b/src/theory/bags/inference_generator.h
@@ -62,22 +62,25 @@ class InferenceGenerator
*/
InferInfo mkBag(Node n, Node e);
/**
- * @param n is (not (= A B)) where A, B are bags of type (Bag E)
+ * @param n is (= A B) where A, B are bags of type (Bag E), and
+ * (not (= A B)) is an assertion in the equality engine
* @return an inference that represents the following implication
* (=>
* (not (= A B))
* (not (= (count e A) (count e B))))
* where e is a fresh skolem of type E.
*/
- InferInfo bagDisequality(Node n, Node reason);
+ InferInfo bagDisequality(Node n);
/**
+ * @param n is (as emptybag (Bag E))
* @param e is a node of Type E
* @return an inference that represents the following implication
* (=>
* true
- * (= 0 (count e (as emptybag (Bag E)))))
+ * (= 0 (count e skolem)))
+ * where skolem = (as emptybag (Bag String))
*/
- InferInfo bagEmpty(Node e);
+ InferInfo empty(Node n, Node e);
/**
* @param n is (union_disjoint A B) where A, B are bags of type (Bag E)
* @param e is a node of Type E
diff --git a/src/theory/bags/solver_state.cpp b/src/theory/bags/solver_state.cpp
index 9bcb6ae3c..adca85068 100644
--- a/src/theory/bags/solver_state.cpp
+++ b/src/theory/bags/solver_state.cpp
@@ -55,31 +55,32 @@ const std::set<Node>& SolverState::getBags() { return d_bags; }
const std::set<Node>& SolverState::getElements(Node B)
{
Node bag = getRepresentative(B);
- return d_bagElements[B];
+ return d_bagElements[bag];
}
+const std::set<Node>& SolverState::getDisequalBagTerms() { return d_deq; }
+
void SolverState::reset()
{
d_bagElements.clear();
d_bags.clear();
+ d_deq.clear();
}
void SolverState::initialize()
{
reset();
collectBagsAndCountTerms();
+ collectDisequalBagTerms();
}
void SolverState::collectBagsAndCountTerms()
{
- Trace("SolverState::collectBagsAndCountTerms")
- << "SolverState::collectBagsAndCountTerms start" << endl;
eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee);
while (!repIt.isFinished())
{
Node eqc = (*repIt);
- Trace("SolverState::collectBagsAndCountTerms")
- << "[" << eqc << "]: " << endl;
+ Trace("bags-eqc") << "Eqc [ " << eqc << " ] = { ";
if (eqc.getType().isBag())
{
@@ -90,6 +91,7 @@ void SolverState::collectBagsAndCountTerms()
while (!it.isFinished())
{
Node n = (*it);
+ Trace("bags-eqc") << (*it) << " ";
Kind k = n.getKind();
if (k == MK_BAG)
{
@@ -109,12 +111,27 @@ void SolverState::collectBagsAndCountTerms()
}
++it;
}
-
+ Trace("bags-eqc") << " } " << std::endl;
++repIt;
}
- Trace("SolverState::collectBagsAndCountTerms")
- << "SolverState::collectBagsAndCountTerms end" << endl;
+ Trace("bags-eqc") << "bag representatives: " << d_bags << endl;
+ Trace("bags-eqc") << "bag elements: " << d_bagElements << endl;
+}
+
+void SolverState::collectDisequalBagTerms()
+{
+ eq::EqClassIterator it = eq::EqClassIterator(d_false, d_ee);
+ while (!it.isFinished())
+ {
+ Node n = (*it);
+ if (n.getKind() == EQUAL && n[0].getType().isBag())
+ {
+ Trace("bags-eqc") << "Disequal terms: " << n << std::endl;
+ d_deq.insert(n);
+ }
+ ++it;
+ }
}
} // namespace bags
diff --git a/src/theory/bags/solver_state.h b/src/theory/bags/solver_state.h
index 175317529..7670e5dec 100644
--- a/src/theory/bags/solver_state.h
+++ b/src/theory/bags/solver_state.h
@@ -61,13 +61,21 @@ class SolverState : public TheoryState
const std::set<Node>& getElements(Node B);
/** initialize bag and count terms */
void initialize();
+ /** return disequal bag terms */
+ const std::set<Node>& getDisequalBagTerms();
private:
/** clear all bags data structures */
void reset();
- /** collect bags' representatives and all count terms.
- * This function is called during postCheck */
+ /**
+ * collect bags' representatives and all count terms.
+ * This function is called during postCheck
+ */
void collectBagsAndCountTerms();
+ /**
+ * collect disequal bag terms. This function is called during postCheck.
+ */
+ void collectDisequalBagTerms();
/** constants */
Node d_true;
Node d_false;
@@ -77,6 +85,8 @@ class SolverState : public TheoryState
std::set<Node> d_bags;
/** bag -> associated elements */
std::map<Node, std::set<Node>> d_bagElements;
+ /** Disequal bag terms */
+ std::set<Node> d_deq;
}; /* class SolverState */
} // namespace bags
diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp
index 153e9017d..15e8e00e7 100644
--- a/src/theory/bags/theory_bags.cpp
+++ b/src/theory/bags/theory_bags.cpp
@@ -199,7 +199,6 @@ void TheoryBags::preRegisterTerm(TNode n)
case BAG_FROM_SET:
case BAG_TO_SET:
case BAG_IS_SINGLETON:
- case DUPLICATE_REMOVAL:
{
std::stringstream ss;
ss << "Term of kind " << n.getKind() << " is not supported yet";
@@ -223,15 +222,7 @@ void TheoryBags::eqNotifyNewClass(TNode n) {}
void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {}
-void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason)
-{
- TypeNode t1 = n1.getType();
- if (t1.isBag())
- {
- InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode(), reason);
- info.process(d_inferManager, true);
- }
-}
+void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason) {}
void TheoryBags::NotifyClass::eqNotifyNewClass(TNode n)
{
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback