From 63df35c477ee9e6c7bdeae677656dd374563de55 Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Fri, 6 Nov 2020 15:28:38 -0600 Subject: Fix issue #5342 (#5349) This PR fixes issue #5342 by adding the rewrite rule (setminus A (setminus A B)) = (intersection A B). --- src/theory/sets/cardinality_extension.cpp | 3 ++ src/theory/sets/cardinality_extension.h | 2 +- src/theory/sets/solver_state.cpp | 7 ++++ src/theory/sets/solver_state.h | 7 ++++ src/theory/sets/theory_sets_private.cpp | 1 + src/theory/sets/theory_sets_rewriter.cpp | 56 ++++++++++++++++++++++--------- 6 files changed, 59 insertions(+), 17 deletions(-) (limited to 'src/theory/sets') diff --git a/src/theory/sets/cardinality_extension.cpp b/src/theory/sets/cardinality_extension.cpp index 21344ee73..cb0540b86 100644 --- a/src/theory/sets/cardinality_extension.cpp +++ b/src/theory/sets/cardinality_extension.cpp @@ -308,6 +308,9 @@ void CardinalityExtension::checkCardCycles() return; } } + + Trace("sets") << "d_card_parent: " << d_card_parent << std::endl; + Trace("sets") << "d_oSetEqc: " << d_oSetEqc << std::endl; Trace("sets") << "Done check cardinality cycles" << std::endl; } diff --git a/src/theory/sets/cardinality_extension.h b/src/theory/sets/cardinality_extension.h index 6704ce4a7..08424d3c4 100644 --- a/src/theory/sets/cardinality_extension.h +++ b/src/theory/sets/cardinality_extension.h @@ -167,7 +167,7 @@ class CardinalityExtension TermRegistry& d_treg; /** register cardinality term * - * This method add lemmas corresponding to the definition of + * This method adds lemmas corresponding to the definition of * the cardinality of set term n. For example, if n is A^B (denoting set * intersection as ^), then we consider the lemmas card(A^B)>=0, * card(A) = card(A\B) + card(A^B) and card(B) = card(B\A) + card(A^B). diff --git a/src/theory/sets/solver_state.cpp b/src/theory/sets/solver_state.cpp index cf9f4aa7a..1d58945a5 100644 --- a/src/theory/sets/solver_state.cpp +++ b/src/theory/sets/solver_state.cpp @@ -444,6 +444,13 @@ SolverState::getBinaryOpIndex() const { return d_bop_index; } + +const std::map>& SolverState::getBinaryOpIndex( + Kind k) +{ + return d_bop_index[k]; +} + const std::map >& SolverState::getOperatorList() const { return d_op_list; diff --git a/src/theory/sets/solver_state.h b/src/theory/sets/solver_state.h index 32b4d6113..41d3ac717 100644 --- a/src/theory/sets/solver_state.h +++ b/src/theory/sets/solver_state.h @@ -146,6 +146,13 @@ class SolverState : public TheoryState */ const std::map > >& getBinaryOpIndex() const; + + /** Get binary operator index + * + * This returns the binary operator index of the given kind. + * See getBinaryOpIndex() above. + */ + const std::map >& getBinaryOpIndex(Kind k); /** get operator list * * This returns a mapping from set kinds to a list of terms of that kind diff --git a/src/theory/sets/theory_sets_private.cpp b/src/theory/sets/theory_sets_private.cpp index a382688a9..e44c3c7a6 100644 --- a/src/theory/sets/theory_sets_private.cpp +++ b/src/theory/sets/theory_sets_private.cpp @@ -389,6 +389,7 @@ void TheorySetsPrivate::fullEffortCheck() } // check reduce comprehensions checkReduceComprehensions(); + d_im.doPendingLemmas(); if (d_im.hasSent()) { diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index 847bf34eb..1e4473d6f 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -119,32 +119,52 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { break; } - case kind::SETMINUS: { - if(node[0] == node[1]) { + case kind::SETMINUS: + { + if (node[0] == node[1]) + { Node newNode = nm->mkConst(EmptySet(node[0].getType())); - Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl; + Trace("sets-postrewrite") + << "Sets::postRewrite returning " << newNode << std::endl; return RewriteResponse(REWRITE_DONE, newNode); - } else if(node[0].getKind() == kind::EMPTYSET || - node[1].getKind() == kind::EMPTYSET) { - Trace("sets-postrewrite") << "Sets::postRewrite returning " << node[0] << std::endl; + } + else if (node[0].getKind() == kind::EMPTYSET + || node[1].getKind() == kind::EMPTYSET) + { + Trace("sets-postrewrite") + << "Sets::postRewrite returning " << node[0] << std::endl; return RewriteResponse(REWRITE_AGAIN, node[0]); - }else if( node[1].getKind() == kind::UNIVERSE_SET ){ + } + else if (node[1].getKind() == kind::SETMINUS && node[1][0] == node[0]) + { + // (setminus A (setminus A B)) = (intersection A B) + Node intersection = nm->mkNode(INTERSECTION, node[0], node[1][1]); + return RewriteResponse(REWRITE_AGAIN, intersection); + } + else if (node[1].getKind() == kind::UNIVERSE_SET) + { return RewriteResponse( REWRITE_AGAIN, NodeManager::currentNM()->mkConst(EmptySet(node[1].getType()))); - } else if(node[0].isConst() && node[1].isConst()) { + } + else if (node[0].isConst() && node[1].isConst()) + { std::set left = NormalForm::getElementsFromNormalConstant(node[0]); std::set right = NormalForm::getElementsFromNormalConstant(node[1]); std::set newSet; - std::set_difference(left.begin(), left.end(), right.begin(), right.end(), - std::inserter(newSet, newSet.begin())); + std::set_difference(left.begin(), + left.end(), + right.begin(), + right.end(), + std::inserter(newSet, newSet.begin())); Node newNode = NormalForm::elementsToSet(newSet, node.getType()); Assert(newNode.isConst()); - Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl; + Trace("sets-postrewrite") + << "Sets::postRewrite returning " << newNode << std::endl; return RewriteResponse(REWRITE_DONE, newNode); } break; - }//kind::SETMINUS + } // kind::SETMINUS case kind::INTERSECTION: { if(node[0] == node[1]) { @@ -203,11 +223,14 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { // we don't merge non-constant unions break; }//kind::UNION - case kind::COMPLEMENT: { - Node univ = NodeManager::currentNM()->mkNullaryOperator( node[0].getType(), kind::UNIVERSE_SET ); - return RewriteResponse( REWRITE_AGAIN, NodeManager::currentNM()->mkNode( kind::SETMINUS, univ, node[0] ) ); + case kind::COMPLEMENT: + { + Node univ = NodeManager::currentNM()->mkNullaryOperator(node[0].getType(), + kind::UNIVERSE_SET); + return RewriteResponse( + REWRITE_AGAIN, + NodeManager::currentNM()->mkNode(kind::SETMINUS, univ, node[0])); } - break; case kind::CARD: { if(node[0].isConst()) { std::set elements = NormalForm::getElementsFromNormalConstant(node[0]); @@ -510,6 +533,7 @@ RewriteResponse TheorySetsRewriter::preRewrite(TNode node) { nm->mkNode(kind::UNION, node[0], node[1]), node[1]) ); } + // could have an efficient normalizer for union here return RewriteResponse(REWRITE_DONE, node); -- cgit v1.2.3