summaryrefslogtreecommitdiff
path: root/src/theory/sets
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2020-09-14 12:49:58 -0500
committerGitHub <noreply@github.com>2020-09-14 12:49:58 -0500
commit92a007b4a35a925c92eafc29df5bacacac75f6f9 (patch)
treeaf69c6aa8cf3ff5470c25aca27d8d340d9d0b141 /src/theory/sets
parentc82d061abb0c011da2700051b7a0548f5d59904b (diff)
Refactoring the rewriter of sets (#4856)
Changes it so that we don't flatten unions if at least one child is non-constant, since this may lead to children that are non-constant by mixing constant/non-constant elements and is generally expensive for large unions of singleton elements. The previous rewriting policy was causing an incorrect model in a separation logic benchmark reported by Andrew Jones, due to unions of constant elements that were unsorted (and hence not considered constants). We now have the invariant that all subterms that are unions of constant elements are set constants. Note this PR changes the normal form of set constants to be (union (singleton c1) ... (union (singleton cn-1) (singleton cn) ... ) not (union ... (union (singleton c1) (singleton c2)) ... (singleton cn)). It changes a unit test which was impacted by this change which was failing due to hardcoding the enumeration order in the unit test. The test is now agnostic to the order of elements.
Diffstat (limited to 'src/theory/sets')
-rw-r--r--src/theory/sets/normal_form.h109
-rw-r--r--src/theory/sets/theory_sets_private.cpp2
-rw-r--r--src/theory/sets/theory_sets_rewriter.cpp70
-rw-r--r--src/theory/sets/theory_sets_rewriter.h6
4 files changed, 97 insertions, 90 deletions
diff --git a/src/theory/sets/normal_form.h b/src/theory/sets/normal_form.h
index 0607a0e6c..b53a1c03d 100644
--- a/src/theory/sets/normal_form.h
+++ b/src/theory/sets/normal_form.h
@@ -25,6 +25,12 @@ namespace sets {
class NormalForm {
public:
+ /**
+ * Constructs a set of the form:
+ * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n))))
+ * from the set { c1 ... cn }, also handles empty set case, which is why
+ * setType is passed to this method.
+ */
template <bool ref_count>
static Node elementsToSet(const std::set<NodeTemplate<ref_count> >& elements,
TypeNode setType)
@@ -42,12 +48,21 @@ class NormalForm {
Node cur = nm->mkNode(kind::SINGLETON, *it);
while (++it != elements.end())
{
- cur = nm->mkNode(kind::UNION, cur, nm->mkNode(kind::SINGLETON, *it));
+ cur = nm->mkNode(kind::UNION, nm->mkNode(kind::SINGLETON, *it), cur);
}
return cur;
}
}
+ /**
+ * Returns true if n is considered a to be a (canonical) constant set value.
+ * A canonical set value is one whose AST is:
+ * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n))))
+ * where c1 ... cn are constants and the node identifier of these constants
+ * are such that:
+ * c1 > ... > cn.
+ * Also handles the corner cases of empty set and singleton set.
+ */
static bool checkNormalConstant(TNode n) {
Debug("sets-checknormal") << "[sets-checknormal] checkNormal " << n << " :"
<< std::endl;
@@ -56,46 +71,62 @@ class NormalForm {
} else if (n.getKind() == kind::SINGLETON) {
return n[0].isConst();
} else if (n.getKind() == kind::UNION) {
- // assuming (union ... (union {SmallestNodeID} {BiggerNodeId}) ...
- // {BiggestNodeId})
-
- // store BiggestNodeId in prvs
- if (n[1].getKind() != kind::SINGLETON) return false;
- if (!n[1][0].isConst()) return false;
- Debug("sets-checknormal")
- << "[sets-checknormal] frst element = " << n[1][0] << " "
- << n[1][0].getId() << std::endl;
- TNode prvs = n[1][0];
- n = n[0];
+ // assuming (union {SmallestNodeID} ... (union {BiggerNodeId} ...
+ Node orig = n;
+ TNode prvs;
// check intermediate nodes
- while (n.getKind() == kind::UNION) {
- if (n[1].getKind() != kind::SINGLETON) return false;
- if (!n[1].isConst()) return false;
+ while (n.getKind() == kind::UNION)
+ {
+ if (n[0].getKind() != kind::SINGLETON || !n[0][0].isConst())
+ {
+ // not a constant
+ Trace("sets-isconst") << "sets::isConst: " << orig << " not due to "
+ << n[0] << std::endl;
+ return false;
+ }
Debug("sets-checknormal")
- << "[sets-checknormal] element = " << n[1][0] << " "
- << n[1][0].getId() << std::endl;
- if (n[1][0] >= prvs) return false;
- prvs = n[1][0];
- n = n[0];
+ << "[sets-checknormal] element = " << n[0][0] << " "
+ << n[0][0].getId() << std::endl;
+ if (!prvs.isNull() && n[0][0] >= prvs)
+ {
+ Trace("sets-isconst")
+ << "sets::isConst: " << orig << " not due to compare " << n[0][0]
+ << std::endl;
+ return false;
+ }
+ prvs = n[0][0];
+ n = n[1];
}
// check SmallestNodeID is smallest
- if (n.getKind() != kind::SINGLETON) return false;
- if (!n[0].isConst()) return false;
+ if (n.getKind() != kind::SINGLETON || !n[0].isConst())
+ {
+ Trace("sets-isconst") << "sets::isConst: " << orig
+ << " not due to final " << n << std::endl;
+ return false;
+ }
Debug("sets-checknormal")
<< "[sets-checknormal] lst element = " << n[0] << " "
<< n[0].getId() << std::endl;
- if (n[0] >= prvs) return false;
-
- // we made it
- return true;
-
- } else {
- return false;
+ // compare last ID
+ if (n[0] < prvs)
+ {
+ return true;
+ }
+ Trace("sets-isconst")
+ << "sets::isConst: " << orig << " not due to compare final " << n[0]
+ << std::endl;
}
+ return false;
}
+ /**
+ * Converts a set term to a std::set of its elements. This expects a set of
+ * the form:
+ * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n))))
+ * Also handles the corner cases of empty set and singleton set.
+ */
static std::set<Node> getElementsFromNormalConstant(TNode n) {
Assert(n.isConst());
std::set<Node> ret;
@@ -103,29 +134,15 @@ class NormalForm {
return ret;
}
while (n.getKind() == kind::UNION) {
- Assert(n[1].getKind() == kind::SINGLETON);
- ret.insert(ret.begin(), n[1][0]);
- n = n[0];
+ Assert(n[0].getKind() == kind::SINGLETON);
+ ret.insert(ret.begin(), n[0][0]);
+ n = n[1];
}
Assert(n.getKind() == kind::SINGLETON);
ret.insert(n[0]);
return ret;
}
-
- //AJR
-
- static void getElementsFromBop( Kind k, Node n, std::vector< Node >& els ){
- if( n.getKind()==k ){
- for( unsigned i=0; i<n.getNumChildren(); i++ ){
- getElementsFromBop( k, n[i], els );
- }
- }else{
- if( std::find( els.begin(), els.end(), n )==els.end() ){
- els.push_back( n );
- }
- }
- }
static Node mkBop( Kind k, std::vector< Node >& els, TypeNode tn, unsigned index = 0 ){
if( index>=els.size() ){
return NodeManager::currentNM()->mkConst(EmptySet(tn));
diff --git a/src/theory/sets/theory_sets_private.cpp b/src/theory/sets/theory_sets_private.cpp
index 741f45dd8..b1831f261 100644
--- a/src/theory/sets/theory_sets_private.cpp
+++ b/src/theory/sets/theory_sets_private.cpp
@@ -320,7 +320,7 @@ void TheorySetsPrivate::fullEffortCheck()
Node n = (*eqc_i);
if (n != eqc)
{
- Trace("sets-eqc") << n << " ";
+ Trace("sets-eqc") << n << " (" << n.isConst() << ") ";
}
TypeNode tnn = n.getType();
if (isSet)
diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp
index eb168c6ed..50aa89cc8 100644
--- a/src/theory/sets/theory_sets_rewriter.cpp
+++ b/src/theory/sets/theory_sets_rewriter.cpp
@@ -27,7 +27,7 @@ namespace CVC4 {
namespace theory {
namespace sets {
-bool checkConstantMembership(TNode elementTerm, TNode setTerm)
+bool TheorySetsRewriter::checkConstantMembership(TNode elementTerm, TNode setTerm)
{
if(setTerm.getKind() == kind::EMPTYSET) {
return false;
@@ -38,12 +38,11 @@ bool checkConstantMembership(TNode elementTerm, TNode setTerm)
}
Assert(setTerm.getKind() == kind::UNION
- && setTerm[1].getKind() == kind::SINGLETON)
+ && setTerm[0].getKind() == kind::SINGLETON)
<< "kind was " << setTerm.getKind() << ", term: " << setTerm;
- return
- elementTerm == setTerm[1][0] ||
- checkConstantMembership(elementTerm, setTerm[0]);
+ return elementTerm == setTerm[0][0]
+ || checkConstantMembership(elementTerm, setTerm[1]);
}
// static
@@ -53,6 +52,8 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
Trace("sets-postrewrite") << "Process: " << node << std::endl;
if(node.isConst()) {
+ Trace("sets-rewrite-nf")
+ << "Sets::rewrite: no rewrite (constant) " << node << std::endl;
// Dare you touch the const and mangle it to something else.
return RewriteResponse(REWRITE_DONE, node);
}
@@ -163,23 +164,13 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
Assert(newNode.isConst());
Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
return RewriteResponse(REWRITE_DONE, newNode);
- } else {
- std::vector< Node > els;
- NormalForm::getElementsFromBop( kind::INTERSECTION, node, els );
- std::sort( els.begin(), els.end() );
- Node rew = NormalForm::mkBop( kind::INTERSECTION, els, node.getType() );
- if( rew!=node ){
- Trace("sets-rewrite") << "Sets::rewrite " << node << " -> " << rew << std::endl;
- }
- return RewriteResponse(REWRITE_DONE, rew);
}
- /*
- } else if (node[0] > node[1]) {
+ else if (node[0] > node[1])
+ {
Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
- Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
return RewriteResponse(REWRITE_DONE, newNode);
}
- */
+ // we don't merge non-constant intersections
break;
}//kind::INTERSECION
@@ -200,19 +191,16 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
std::inserter(newSet, newSet.begin()));
Node newNode = NormalForm::elementsToSet(newSet, node.getType());
Assert(newNode.isConst());
- Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
+ Trace("sets-rewrite")
+ << "Sets::rewrite: UNION_CONSTANT_MERGE: " << newNode << std::endl;
return RewriteResponse(REWRITE_DONE, newNode);
- } else {
- std::vector< Node > els;
- NormalForm::getElementsFromBop( kind::UNION, node, els );
- std::sort( els.begin(), els.end() );
- Node rew = NormalForm::mkBop( kind::UNION, els, node.getType() );
- if( rew!=node ){
- Trace("sets-rewrite") << "Sets::rewrite " << node << " -> " << rew << std::endl;
- }
- Trace("sets-rewrite") << "...no rewrite." << std::endl;
- return RewriteResponse(REWRITE_DONE, rew);
}
+ else if (node[0] > node[1])
+ {
+ Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
+ return RewriteResponse(REWRITE_DONE, newNode);
+ }
+ // we don't merge non-constant unions
break;
}//kind::UNION
case kind::COMPLEMENT: {
@@ -491,16 +479,15 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
// static
RewriteResponse TheorySetsRewriter::preRewrite(TNode node) {
NodeManager* nm = NodeManager::currentNM();
-
- if(node.getKind() == kind::EQUAL) {
-
+ Kind k = node.getKind();
+ if (k == kind::EQUAL)
+ {
if(node[0] == node[1]) {
return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
}
-
- }//kind::EQUAL
- else if(node.getKind() == kind::INSERT) {
-
+ }
+ else if (k == kind::INSERT)
+ {
Node insertedElements = nm->mkNode(kind::SINGLETON, node[0]);
size_t setNodeIndex = node.getNumChildren()-1;
for(size_t i = 1; i < setNodeIndex; ++i) {
@@ -512,17 +499,16 @@ RewriteResponse TheorySetsRewriter::preRewrite(TNode node) {
nm->mkNode(kind::UNION,
insertedElements,
node[setNodeIndex]));
-
- }//kind::INSERT
- else if(node.getKind() == kind::SUBSET) {
-
+ }
+ else if (k == kind::SUBSET)
+ {
// rewrite (A subset-or-equal B) as (A union B = B)
return RewriteResponse(REWRITE_AGAIN,
nm->mkNode(kind::EQUAL,
nm->mkNode(kind::UNION, node[0], node[1]),
node[1]) );
-
- }//kind::SUBSET
+ }
+ // could have an efficient normalizer for union here
return RewriteResponse(REWRITE_DONE, node);
}
diff --git a/src/theory/sets/theory_sets_rewriter.h b/src/theory/sets/theory_sets_rewriter.h
index 7d1a6c188..fdc9caefb 100644
--- a/src/theory/sets/theory_sets_rewriter.h
+++ b/src/theory/sets/theory_sets_rewriter.h
@@ -70,7 +70,11 @@ class TheorySetsRewriter : public TheoryRewriter
// often this will suffice
return postRewrite(equality).d_node;
}
-
+private:
+ /**
+ * Returns true if elementTerm is in setTerm, where both terms are constants.
+ */
+ bool checkConstantMembership(TNode elementTerm, TNode setTerm);
}; /* class TheorySetsRewriter */
}/* CVC4::theory::sets namespace */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback