From 50c3dee5c8a4855023df826e1a733ea3c6076774 Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Fri, 29 Jan 2021 15:44:28 -0600 Subject: Add bag inferences for operators: intersection, duplicate_removal, and empty bags (#5832) This PR adds inferences for operators: intersection, duplicate_removal, and empty bags during post check. It also fixes a bug in SolverState::getElements --- src/theory/bags/bag_solver.cpp | 78 +++++++++++++++++++++++++-------- src/theory/bags/bag_solver.h | 11 +++++ src/theory/bags/infer_info.cpp | 3 ++ src/theory/bags/inference_generator.cpp | 22 +++++----- src/theory/bags/inference_generator.h | 11 +++-- src/theory/bags/solver_state.cpp | 33 ++++++++++---- src/theory/bags/solver_state.h | 14 +++++- src/theory/bags/theory_bags.cpp | 11 +---- 8 files changed, 129 insertions(+), 54 deletions(-) (limited to 'src/theory/bags') diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp index 495f73723..bdd4a9b30 100644 --- a/src/theory/bags/bag_solver.cpp +++ b/src/theory/bags/bag_solver.cpp @@ -27,7 +27,7 @@ namespace theory { namespace bags { BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr) - : d_state(s), d_im(im), d_termReg(tr) + : d_state(s), d_ig(&d_state), d_im(im), d_termReg(tr) { d_zero = NodeManager::currentNM()->mkConst(Rational(0)); d_one = NodeManager::currentNM()->mkConst(Rational(1)); @@ -41,6 +41,8 @@ void BagSolver::postCheck() { d_state.initialize(); + checkDisequalBagTerms(); + // At this point, all bag and count representatives should be in the solver // state. for (const Node& bag : d_state.getBags()) @@ -54,11 +56,14 @@ void BagSolver::postCheck() Kind k = n.getKind(); switch (k) { + case kind::EMPTYBAG: checkEmpty(n); break; case kind::MK_BAG: checkMkBag(n); break; case kind::UNION_DISJOINT: checkUnionDisjoint(n); break; case kind::UNION_MAX: checkUnionMax(n); break; + case kind::INTERSECTION_MIN: checkIntersectionMin(n); break; case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break; case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break; + case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break; default: break; } it++; @@ -91,16 +96,24 @@ set BagSolver::getElementsForBinaryOperator(const Node& n) return elements; } +void BagSolver::checkEmpty(const Node& n) +{ + Assert(n.getKind() == EMPTYBAG); + for (const Node& e : d_state.getElements(n)) + { + InferInfo i = d_ig.empty(n, e); + i.process(&d_im, true); + } +} + void BagSolver::checkUnionDisjoint(const Node& n) { Assert(n.getKind() == UNION_DISJOINT); std::set elements = getElementsForBinaryOperator(n); for (const Node& e : elements) { - InferenceGenerator ig(&d_state); - InferInfo i = ig.unionDisjoint(n, e); + InferInfo i = d_ig.unionDisjoint(n, e); i.process(&d_im, true); - Trace("bags::BagSolver::postCheck") << i << endl; } } @@ -110,10 +123,19 @@ void BagSolver::checkUnionMax(const Node& n) std::set elements = getElementsForBinaryOperator(n); for (const Node& e : elements) { - InferenceGenerator ig(&d_state); - InferInfo i = ig.unionMax(n, e); + InferInfo i = d_ig.unionMax(n, e); + i.process(&d_im, true); + } +} + +void BagSolver::checkIntersectionMin(const Node& n) +{ + Assert(n.getKind() == INTERSECTION_MIN); + std::set elements = getElementsForBinaryOperator(n); + for (const Node& e : elements) + { + InferInfo i = d_ig.intersection(n, e); i.process(&d_im, true); - Trace("bags::BagSolver::postCheck") << i << endl; } } @@ -123,10 +145,8 @@ void BagSolver::checkDifferenceSubtract(const Node& n) std::set elements = getElementsForBinaryOperator(n); for (const Node& e : elements) { - InferenceGenerator ig(&d_state); - InferInfo i = ig.differenceSubtract(n, e); + InferInfo i = d_ig.differenceSubtract(n, e); i.process(&d_im, true); - Trace("bags::BagSolver::postCheck") << i << endl; } } @@ -138,18 +158,14 @@ void BagSolver::checkMkBag(const Node& n) << " are: " << d_state.getElements(n) << std::endl; for (const Node& e : d_state.getElements(n)) { - InferenceGenerator ig(&d_state); - InferInfo i = ig.mkBag(n, e); + InferInfo i = d_ig.mkBag(n, e); i.process(&d_im, true); - Trace("bags::BagSolver::postCheck") << i << endl; } } void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element) { - InferenceGenerator ig(&d_state); - InferInfo i = ig.nonNegativeCount(bag, element); + InferInfo i = d_ig.nonNegativeCount(bag, element); i.process(&d_im, true); - Trace("bags::BagSolver::postCheck") << i << endl; } void BagSolver::checkDifferenceRemove(const Node& n) @@ -158,10 +174,34 @@ void BagSolver::checkDifferenceRemove(const Node& n) std::set elements = getElementsForBinaryOperator(n); for (const Node& e : elements) { - InferenceGenerator ig(&d_state); - InferInfo i = ig.differenceRemove(n, e); + InferInfo i = d_ig.differenceRemove(n, e); i.process(&d_im, true); - Trace("bags::BagSolver::postCheck") << i << endl; + } +} + +void BagSolver::checkDuplicateRemoval(Node n) +{ + Assert(n.getKind() == DUPLICATE_REMOVAL); + set elements; + const set& downwards = d_state.getElements(n); + const set& upwards = d_state.getElements(n[0]); + + elements.insert(downwards.begin(), downwards.end()); + elements.insert(upwards.begin(), upwards.end()); + + for (const Node& e : elements) + { + InferInfo i = d_ig.duplicateRemoval(n, e); + i.process(&d_im, true); + } +} + +void BagSolver::checkDisequalBagTerms() +{ + for (const Node& n : d_state.getDisequalBagTerms()) + { + InferInfo info = d_ig.bagDisequality(n); + info.process(&d_im, true); } } diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h index b4b18c00c..b19e1f11e 100644 --- a/src/theory/bags/bag_solver.h +++ b/src/theory/bags/bag_solver.h @@ -20,6 +20,7 @@ #include "context/cdhashset.h" #include "context/cdlist.h" #include "theory/bags/infer_info.h" +#include "theory/bags/inference_generator.h" #include "theory/bags/inference_manager.h" #include "theory/bags/normal_form.h" #include "theory/bags/solver_state.h" @@ -41,6 +42,8 @@ class BagSolver void postCheck(); private: + /** apply inference rules for empty bags */ + void checkEmpty(const Node& n); /** * apply inference rules for MK_BAG operator. * Example: Suppose n = (bag x c), and we have two count terms (bag.count x n) @@ -60,15 +63,23 @@ class BagSolver void checkUnionDisjoint(const Node& n); /** apply inference rules for union max */ void checkUnionMax(const Node& n); + /** apply inference rules for intersection_min operator */ + void checkIntersectionMin(const Node& n); /** apply inference rules for difference subtract */ void checkDifferenceSubtract(const Node& n); /** apply inference rules for difference remove */ void checkDifferenceRemove(const Node& n); + /** apply inference rules for duplicate removal operator */ + void checkDuplicateRemoval(Node n); /** apply non negative constraints for multiplicities */ void checkNonNegativeCountTerms(const Node& bag, const Node& element); + /** apply inference rules for disequal bag terms */ + void checkDisequalBagTerms(); /** The solver state object */ SolverState& d_state; + /** The inference generator object*/ + InferenceGenerator d_ig; /** Reference to the inference manager for the theory of bags */ InferenceManager& d_im; /** Reference to the term registry of theory of bags */ diff --git a/src/theory/bags/infer_info.cpp b/src/theory/bags/infer_info.cpp index 5b3274617..9bf546af1 100644 --- a/src/theory/bags/infer_info.cpp +++ b/src/theory/bags/infer_info.cpp @@ -76,6 +76,9 @@ bool InferInfo::process(TheoryInferenceManager* im, bool asLemma) TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr); im->trustedLemma(trustedLemma); } + + Trace("bags::InferInfo::process") << (*this) << std::endl; + return true; } diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp index 7ef126911..708c25f34 100644 --- a/src/theory/bags/inference_generator.cpp +++ b/src/theory/bags/inference_generator.cpp @@ -80,17 +80,15 @@ struct BagsDeqAttributeId }; typedef expr::Attribute BagsDeqAttribute; -InferInfo InferenceGenerator::bagDisequality(Node n, Node reason) +InferInfo InferenceGenerator::bagDisequality(Node n) { - Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL); - Assert(n[0][0].getType().isBag()); + Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag()); - Node A = n[0][0]; - Node B = n[0][1]; + Node A = n[0]; + Node B = n[1]; InferInfo inferInfo; inferInfo.d_id = Inference::BAG_DISEQUALITY; - inferInfo.d_premises.push_back(reason); TypeNode elementType = A.getType().getBagElementType(); BoundVarManager* bvm = d_nm->getBoundVarManager(); @@ -106,7 +104,7 @@ InferInfo InferenceGenerator::bagDisequality(Node n, Node reason) Node disEqual = countA.eqNode(countB).notNode(); - inferInfo.d_premises.push_back(n); + inferInfo.d_premises.push_back(n.notNode()); inferInfo.d_conclusion = disEqual; return inferInfo; } @@ -118,13 +116,15 @@ Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo) return skolem; } -InferInfo InferenceGenerator::bagEmpty(Node e) +InferInfo InferenceGenerator::empty(Node n, Node e) { - EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType())); - Node empty = d_nm->mkConst(emptyBag); + Assert(n.getKind() == kind::EMPTYBAG); + Assert(e.getType() == n.getType().getBagElementType()); + InferInfo inferInfo; + Node skolem = getSkolem(n, inferInfo); inferInfo.d_id = Inference::BAG_EMPTY; - Node count = getMultiplicityTerm(e, empty); + Node count = getMultiplicityTerm(e, skolem); Node equal = count.eqNode(d_zero); inferInfo.d_conclusion = equal; diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h index 9eee46e43..4a852530a 100644 --- a/src/theory/bags/inference_generator.h +++ b/src/theory/bags/inference_generator.h @@ -62,22 +62,25 @@ class InferenceGenerator */ InferInfo mkBag(Node n, Node e); /** - * @param n is (not (= A B)) where A, B are bags of type (Bag E) + * @param n is (= A B) where A, B are bags of type (Bag E), and + * (not (= A B)) is an assertion in the equality engine * @return an inference that represents the following implication * (=> * (not (= A B)) * (not (= (count e A) (count e B)))) * where e is a fresh skolem of type E. */ - InferInfo bagDisequality(Node n, Node reason); + InferInfo bagDisequality(Node n); /** + * @param n is (as emptybag (Bag E)) * @param e is a node of Type E * @return an inference that represents the following implication * (=> * true - * (= 0 (count e (as emptybag (Bag E))))) + * (= 0 (count e skolem))) + * where skolem = (as emptybag (Bag String)) */ - InferInfo bagEmpty(Node e); + InferInfo empty(Node n, Node e); /** * @param n is (union_disjoint A B) where A, B are bags of type (Bag E) * @param e is a node of Type E diff --git a/src/theory/bags/solver_state.cpp b/src/theory/bags/solver_state.cpp index 9bcb6ae3c..adca85068 100644 --- a/src/theory/bags/solver_state.cpp +++ b/src/theory/bags/solver_state.cpp @@ -55,31 +55,32 @@ const std::set& SolverState::getBags() { return d_bags; } const std::set& SolverState::getElements(Node B) { Node bag = getRepresentative(B); - return d_bagElements[B]; + return d_bagElements[bag]; } +const std::set& SolverState::getDisequalBagTerms() { return d_deq; } + void SolverState::reset() { d_bagElements.clear(); d_bags.clear(); + d_deq.clear(); } void SolverState::initialize() { reset(); collectBagsAndCountTerms(); + collectDisequalBagTerms(); } void SolverState::collectBagsAndCountTerms() { - Trace("SolverState::collectBagsAndCountTerms") - << "SolverState::collectBagsAndCountTerms start" << endl; eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee); while (!repIt.isFinished()) { Node eqc = (*repIt); - Trace("SolverState::collectBagsAndCountTerms") - << "[" << eqc << "]: " << endl; + Trace("bags-eqc") << "Eqc [ " << eqc << " ] = { "; if (eqc.getType().isBag()) { @@ -90,6 +91,7 @@ void SolverState::collectBagsAndCountTerms() while (!it.isFinished()) { Node n = (*it); + Trace("bags-eqc") << (*it) << " "; Kind k = n.getKind(); if (k == MK_BAG) { @@ -109,12 +111,27 @@ void SolverState::collectBagsAndCountTerms() } ++it; } - + Trace("bags-eqc") << " } " << std::endl; ++repIt; } - Trace("SolverState::collectBagsAndCountTerms") - << "SolverState::collectBagsAndCountTerms end" << endl; + Trace("bags-eqc") << "bag representatives: " << d_bags << endl; + Trace("bags-eqc") << "bag elements: " << d_bagElements << endl; +} + +void SolverState::collectDisequalBagTerms() +{ + eq::EqClassIterator it = eq::EqClassIterator(d_false, d_ee); + while (!it.isFinished()) + { + Node n = (*it); + if (n.getKind() == EQUAL && n[0].getType().isBag()) + { + Trace("bags-eqc") << "Disequal terms: " << n << std::endl; + d_deq.insert(n); + } + ++it; + } } } // namespace bags diff --git a/src/theory/bags/solver_state.h b/src/theory/bags/solver_state.h index 175317529..7670e5dec 100644 --- a/src/theory/bags/solver_state.h +++ b/src/theory/bags/solver_state.h @@ -61,13 +61,21 @@ class SolverState : public TheoryState const std::set& getElements(Node B); /** initialize bag and count terms */ void initialize(); + /** return disequal bag terms */ + const std::set& getDisequalBagTerms(); private: /** clear all bags data structures */ void reset(); - /** collect bags' representatives and all count terms. - * This function is called during postCheck */ + /** + * collect bags' representatives and all count terms. + * This function is called during postCheck + */ void collectBagsAndCountTerms(); + /** + * collect disequal bag terms. This function is called during postCheck. + */ + void collectDisequalBagTerms(); /** constants */ Node d_true; Node d_false; @@ -77,6 +85,8 @@ class SolverState : public TheoryState std::set d_bags; /** bag -> associated elements */ std::map> d_bagElements; + /** Disequal bag terms */ + std::set d_deq; }; /* class SolverState */ } // namespace bags diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 153e9017d..15e8e00e7 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -199,7 +199,6 @@ void TheoryBags::preRegisterTerm(TNode n) case BAG_FROM_SET: case BAG_TO_SET: case BAG_IS_SINGLETON: - case DUPLICATE_REMOVAL: { std::stringstream ss; ss << "Term of kind " << n.getKind() << " is not supported yet"; @@ -223,15 +222,7 @@ void TheoryBags::eqNotifyNewClass(TNode n) {} void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {} -void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason) -{ - TypeNode t1 = n1.getType(); - if (t1.isBag()) - { - InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode(), reason); - info.process(d_inferManager, true); - } -} +void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason) {} void TheoryBags::NotifyClass::eqNotifyNewClass(TNode n) { -- cgit v1.2.3