From 8fb135c25038c617679f96dd40dfba3d2585380e Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Mon, 26 Oct 2020 20:13:23 -0500 Subject: Add DUPICATE_REMOVAL operator to bags (#5336) This PR adds duplicate removal operator to bags (also known as delta or squash). Other changes: print MK_BAG operator as "bag" in smt2 instead of "mkBag" renamed BAG_IS_INCLUDED operator to SUBBAG. --- src/theory/bags/bags_rewriter.cpp | 26 +++++++++++++++++++++----- src/theory/bags/bags_rewriter.h | 14 +++++++++++--- src/theory/bags/kinds | 6 ++++-- src/theory/bags/make_bag_op.cpp | 3 ++- src/theory/bags/normal_form.cpp | 25 +++++++++++++++++++++++++ src/theory/bags/normal_form.h | 7 +++++++ src/theory/bags/rewrites.cpp | 1 + src/theory/bags/rewrites.h | 1 + src/theory/bags/theory_bags.cpp | 1 + src/theory/bags/theory_bags_type_rules.h | 30 ++++++++++++++++++++++++------ 10 files changed, 97 insertions(+), 17 deletions(-) (limited to 'src/theory/bags') diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index 26c54d4ec..f0540e9b7 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -63,6 +63,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) { case MK_BAG: response = rewriteMakeBag(n); break; case BAG_COUNT: response = rewriteBagCount(n); break; + case DUPLICATE_REMOVAL: response = rewriteDuplicateRemoval(n); break; case UNION_MAX: response = rewriteUnionMax(n); break; case UNION_DISJOINT: response = rewriteUnionDisjoint(n); break; case INTERSECTION_MIN: response = rewriteIntersectionMin(n); break; @@ -98,7 +99,7 @@ RewriteResponse BagsRewriter::preRewrite(TNode n) switch (k) { case EQUAL: response = rewriteEqual(n); break; - case BAG_IS_INCLUDED: response = rewriteIsIncluded(n); break; + case SUBBAG: response = rewriteSubBag(n); break; default: response = BagsRewriteResponse(n, Rewrite::NONE); } @@ -127,9 +128,9 @@ BagsRewriteResponse BagsRewriter::rewriteEqual(const TNode& n) const return BagsRewriteResponse(n, Rewrite::NONE); } -BagsRewriteResponse BagsRewriter::rewriteIsIncluded(const TNode& n) const +BagsRewriteResponse BagsRewriter::rewriteSubBag(const TNode& n) const { - Assert(n.getKind() == BAG_IS_INCLUDED); + Assert(n.getKind() == SUBBAG); // (bag.is_included A B) = ((difference_subtract A B) == emptybag) Node emptybag = d_nm->mkConst(EmptyBag(n[0].getType())); @@ -168,6 +169,21 @@ BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const return BagsRewriteResponse(n, Rewrite::NONE); } +BagsRewriteResponse BagsRewriter::rewriteDuplicateRemoval(const TNode& n) const +{ + Assert(n.getKind() == DUPLICATE_REMOVAL); + if (n[0].getKind() == MK_BAG && n[0][1].isConst() + && n[0][1].getConst().sgn() == 1) + { + // (duplicate_removal (mkBag x n)) = (mkBag x 1) + // where n is a positive constant + Node one = NodeManager::currentNM()->mkConst(Rational(1)); + Node bag = d_nm->mkBag(n[0][0].getType(), n[0][0], one); + return BagsRewriteResponse(bag, Rewrite::DUPLICATE_REMOVAL_MK_BAG); + } + return BagsRewriteResponse(n, Rewrite::NONE); +} + BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const { Assert(n.getKind() == UNION_MAX); @@ -453,8 +469,8 @@ BagsRewriteResponse BagsRewriter::rewriteToSet(const TNode& n) const { // (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x) // where n is a positive constant and T is the type of the bag's elements - Node bag = d_nm->mkSingleton(n[0][0].getType(), n[0][0]); - return BagsRewriteResponse(bag, Rewrite::TO_SINGLETON); + Node set = d_nm->mkSingleton(n[0][0].getType(), n[0][0]); + return BagsRewriteResponse(set, Rewrite::TO_SINGLETON); } return BagsRewriteResponse(n, Rewrite::NONE); } diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index d36a21ccf..8be6b948a 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -17,8 +17,8 @@ #ifndef CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H #define CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H -#include "theory/rewriter.h" #include "theory/bags/rewrites.h" +#include "theory/rewriter.h" namespace CVC4 { namespace theory { @@ -50,7 +50,7 @@ class BagsRewriter : public TheoryRewriter */ RewriteResponse postRewrite(TNode n) override; /** - * preRewrite nodes with kinds: EQUAL, BAG_IS_INCLUDED. + * preRewrite nodes with kinds: EQUAL, SUBBAG. * See the rewrite rules for these kinds below. */ RewriteResponse preRewrite(TNode n) override; @@ -66,7 +66,7 @@ class BagsRewriter : public TheoryRewriter * rewrites for n include: * - (bag.is_included A B) = ((difference_subtract A B) == emptybag) */ - BagsRewriteResponse rewriteIsIncluded(const TNode& n) const; + BagsRewriteResponse rewriteSubBag(const TNode& n) const; /** * rewrites for n include: @@ -76,6 +76,7 @@ class BagsRewriter : public TheoryRewriter * - otherwise = n */ BagsRewriteResponse rewriteMakeBag(const TNode& n) const; + /** * rewrites for n include: * - (bag.count x emptybag) = 0 @@ -84,6 +85,13 @@ class BagsRewriter : public TheoryRewriter */ BagsRewriteResponse rewriteBagCount(const TNode& n) const; + /** + * rewrites for n include: + * - (duplicate_removal (mkBag x n)) = (mkBag x 1) + * where n is a positive constant + */ + BagsRewriteResponse rewriteDuplicateRemoval(const TNode& n) const; + /** * rewrites for n include: * - (union_max A emptybag) = A diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index 86e89e0bd..f84b811e7 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -45,8 +45,9 @@ operator DIFFERENCE_SUBTRACT 2 "bag difference1 (subtracts multiplicities)" # {("a", 2), ("b", 3)} \\ {("a", 1)} = {("b", 3)} operator DIFFERENCE_REMOVE 2 "bag difference remove (removes shared elements)" -operator BAG_IS_INCLUDED 2 "inclusion predicate for bags (less than or equal multiplicities)" +operator SUBBAG 2 "inclusion predicate for bags (less than or equal multiplicities)" operator BAG_COUNT 2 "multiplicity of an element in a bag" +operator DUPLICATE_REMOVAL 1 "eliminate duplicates in a bag (also known as the delta operator,or the squash operator)" constant MK_BAG_OP \ ::CVC4::MakeBagOp \ @@ -74,8 +75,9 @@ typerule UNION_DISJOINT ::CVC4::theory::bags::BinaryOperatorTypeRule typerule INTERSECTION_MIN ::CVC4::theory::bags::BinaryOperatorTypeRule typerule DIFFERENCE_SUBTRACT ::CVC4::theory::bags::BinaryOperatorTypeRule typerule DIFFERENCE_REMOVE ::CVC4::theory::bags::BinaryOperatorTypeRule -typerule BAG_IS_INCLUDED ::CVC4::theory::bags::IsIncludedTypeRule +typerule SUBBAG ::CVC4::theory::bags::SubBagTypeRule typerule BAG_COUNT ::CVC4::theory::bags::CountTypeRule +typerule DUPLICATE_REMOVAL ::CVC4::theory::bags::DuplicateRemovalTypeRule typerule MK_BAG_OP "SimpleTypeRule" typerule MK_BAG ::CVC4::theory::bags::MkBagTypeRule typerule EMPTYBAG ::CVC4::theory::bags::EmptyBagTypeRule diff --git a/src/theory/bags/make_bag_op.cpp b/src/theory/bags/make_bag_op.cpp index 6a535afc2..b60822783 100644 --- a/src/theory/bags/make_bag_op.cpp +++ b/src/theory/bags/make_bag_op.cpp @@ -12,10 +12,11 @@ ** \brief a class for MK_BAG operator **/ +#include "make_bag_op.h" + #include #include "expr/type_node.h" -#include "make_bag_op.h" namespace CVC4 { diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp index f2dea62d3..081ed77aa 100644 --- a/src/theory/bags/normal_form.cpp +++ b/src/theory/bags/normal_form.cpp @@ -94,6 +94,7 @@ Node NormalForm::evaluate(TNode n) { case MK_BAG: return evaluateMakeBag(n); case BAG_COUNT: return evaluateBagCount(n); + case DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n); case UNION_DISJOINT: return evaluateUnionDisjoint(n); case UNION_MAX: return evaluateUnionMax(n); case INTERSECTION_MIN: return evaluateIntersectionMin(n); @@ -240,6 +241,30 @@ Node NormalForm::evaluateBagCount(TNode n) return nm->mkConst(Rational(0)); } +Node NormalForm::evaluateDuplicateRemoval(TNode n) +{ + Assert(n.getKind() == DUPLICATE_REMOVAL); + + // Examples + // -------- + // - (duplicate_removal (emptybag String)) = (emptybag String) + // - (duplicate_removal (mkBag "x" 4)) = (emptybag "x" 1) + // - (duplicate_removal (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) = + // (disjoint_union (mkBag "x" 1) (mkBag "y" 1) + + std::map oldElements = getBagElements(n[0]); + // copy elements from the old bag + std::map newElements(oldElements); + Rational one = Rational(1); + std::map::iterator it; + for (it = newElements.begin(); it != newElements.end(); it++) + { + it->second = one; + } + Node bag = constructBagFromElements(n[0].getType(), newElements); + return bag; +} + Node NormalForm::evaluateUnionDisjoint(TNode n) { Assert(n.getKind() == UNION_DISJOINT); diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/normal_form.h index ef0edefff..5a7936fa3 100644 --- a/src/theory/bags/normal_form.h +++ b/src/theory/bags/normal_form.h @@ -114,6 +114,13 @@ class NormalForm */ static Node evaluateBagCount(TNode n); + /** + * @param n has the form (duplicate_removal A) where A is a constant bag + * @return a constant bag constructed from the elements in A where each + * element has multiplicity one + */ + static Node evaluateDuplicateRemoval(TNode n); + /** * evaluates union disjoint node such that the returned node is a canonical * bag that has the form diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp index be3b3cc71..d640bcdce 100644 --- a/src/theory/bags/rewrites.cpp +++ b/src/theory/bags/rewrites.cpp @@ -31,6 +31,7 @@ const char* toString(Rewrite r) case Rewrite::CONSTANT_EVALUATION: return "CONSTANT_EVALUATION"; case Rewrite::COUNT_EMPTY: return "COUNT_EMPTY"; case Rewrite::COUNT_MK_BAG: return "COUNT_MK_BAG"; + case Rewrite::DUPLICATE_REMOVAL_MK_BAG: return "DUPLICATE_REMOVAL_MK_BAG"; case Rewrite::FROM_SINGLETON: return "FROM_SINGLETON"; case Rewrite::IDENTICAL_NODES: return "IDENTICAL_NODES"; case Rewrite::INTERSECTION_EMPTY_LEFT: return "INTERSECTION_EMPTY_LEFT"; diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h index dc1921e24..36e30ca68 100644 --- a/src/theory/bags/rewrites.h +++ b/src/theory/bags/rewrites.h @@ -36,6 +36,7 @@ enum class Rewrite : uint32_t CONSTANT_EVALUATION, COUNT_EMPTY, COUNT_MK_BAG, + DUPLICATE_REMOVAL_MK_BAG, FROM_SINGLETON, IDENTICAL_NODES, INTERSECTION_EMPTY_LEFT, diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 9dcad9bb1..9f62ea1c6 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -63,6 +63,7 @@ void TheoryBags::finishInit() d_equalityEngine->addFunctionKind(DIFFERENCE_SUBTRACT); d_equalityEngine->addFunctionKind(DIFFERENCE_REMOVE); d_equalityEngine->addFunctionKind(BAG_COUNT); + d_equalityEngine->addFunctionKind(DUPLICATE_REMOVAL); d_equalityEngine->addFunctionKind(MK_BAG); d_equalityEngine->addFunctionKind(BAG_CARD); d_equalityEngine->addFunctionKind(BAG_FROM_SET); diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index 7767938ed..cece40c9e 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -61,18 +61,17 @@ struct BinaryOperatorTypeRule } }; /* struct BinaryOperatorTypeRule */ -struct IsIncludedTypeRule +struct SubBagTypeRule { static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) { - Assert(n.getKind() == kind::BAG_IS_INCLUDED); + Assert(n.getKind() == kind::SUBBAG); TypeNode bagType = n[0].getType(check); if (check) { if (!bagType.isBag()) { - throw TypeCheckingExceptionPrivate( - n, "BAG_IS_INCLUDED operating on non-bag"); + throw TypeCheckingExceptionPrivate(n, "SUBBAG operating on non-bag"); } TypeNode secondBagType = n[1].getType(check); if (secondBagType != bagType) @@ -80,13 +79,13 @@ struct IsIncludedTypeRule if (!bagType.isComparableTo(secondBagType)) { throw TypeCheckingExceptionPrivate( - n, "BAG_IS_INCLUDED operating on bags of different types"); + n, "SUBBAG operating on bags of different types"); } } } return nodeManager->booleanType(); } -}; /* struct IsIncludedTypeRule */ +}; /* struct SubBagTypeRule */ struct CountTypeRule { @@ -118,6 +117,25 @@ struct CountTypeRule } }; /* struct CountTypeRule */ +struct DuplicateRemovalTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) + { + Assert(n.getKind() == kind::DUPLICATE_REMOVAL); + TypeNode bagType = n[0].getType(check); + if (check) + { + if (!bagType.isBag()) + { + std::stringstream ss; + ss << "Applying DUPLICATE_REMOVAL on a non-bag argument in term " << n; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + return bagType; + } +}; /* struct DuplicateRemovalTypeRule */ + struct MkBagTypeRule { static TypeNode computeType(NodeManager* nm, TNode n, bool check) -- cgit v1.2.3