summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormudathirmahgoub <mudathirmahgoub@gmail.com>2021-02-01 08:42:39 -0600
committerGitHub <noreply@github.com>2021-02-01 08:42:39 -0600
commitc0937f742479d8a5054e42597da9447d55e876c0 (patch)
tree386fa3a7a46be381ba61ce132a538fe811cf16f9
parentc6a71382bad8bae30b8278055995279c36433811 (diff)
Fix BagsRewriter::rewriteUnionDisjoint (#5840)
This PR fixes the implementation of (union_disjoint (union_max A B) (intersection_min A B)) =(union_disjoint A B). It also skips processed bags during model building.
-rw-r--r--src/theory/bags/bags_rewriter.cpp2
-rw-r--r--src/theory/bags/theory_bags.cpp11
-rw-r--r--test/unit/theory/theory_bags_rewriter_white.h14
3 files changed, 25 insertions, 2 deletions
diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp
index 66886bfbf..9f53c29ca 100644
--- a/src/theory/bags/bags_rewriter.cpp
+++ b/src/theory/bags/bags_rewriter.cpp
@@ -246,7 +246,7 @@ BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
// (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
// check if the operands of union_max and intersection_min are the same
std::set<Node> left(n[0].begin(), n[0].end());
- std::set<Node> right(n[0].begin(), n[0].end());
+ std::set<Node> right(n[1].begin(), n[1].end());
if (left == right)
{
Node rewritten = d_nm->mkNode(UNION_DISJOINT, n[0][0], n[0][1]);
diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp
index 15e8e00e7..6df44295e 100644
--- a/src/theory/bags/theory_bags.cpp
+++ b/src/theory/bags/theory_bags.cpp
@@ -144,15 +144,26 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
Trace("bags-model") << "Term set: " << termSet << std::endl;
+ std::set<Node> processedBags;
+
// get the relevant bag equivalence classes
for (const Node& n : termSet)
{
TypeNode tn = n.getType();
if (!tn.isBag())
{
+ // we are only concerned here about bag terms
continue;
}
Node r = d_state.getRepresentative(n);
+ if (processedBags.find(r) != processedBags.end())
+ {
+ // skip bags whose representatives are already processed
+ continue;
+ }
+
+ processedBags.insert(r);
+
std::set<Node> solverElements = d_state.getElements(r);
std::set<Node> elements;
// only consider terms in termSet and ignore other elements in the solver
diff --git a/test/unit/theory/theory_bags_rewriter_white.h b/test/unit/theory/theory_bags_rewriter_white.h
index 98e3cf887..10a624238 100644
--- a/test/unit/theory/theory_bags_rewriter_white.h
+++ b/test/unit/theory/theory_bags_rewriter_white.h
@@ -280,16 +280,20 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite
void testUnionDisjoint()
{
int n = 3;
- vector<Node> elements = getNStrings(2);
+ vector<Node> elements = getNStrings(3);
Node emptyBag =
d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
Node A = d_nm->mkBag(
d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
Node B = d_nm->mkBag(
d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node C = d_nm->mkBag(
+ d_nm->stringType(), elements[2], d_nm->mkConst(Rational(n + 2)));
+
Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxAC = d_nm->mkNode(UNION_MAX, A, C);
Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
@@ -321,6 +325,14 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite
RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
TS_ASSERT(response4.d_node == unionDisjointBA
&& response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_disjoint (intersection_min B A)) (union_max A B) =
+ // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
+ Node unionDisjoint5 =
+ d_nm->mkNode(UNION_DISJOINT, unionMaxAC, intersectionAB);
+ RewriteResponse response5 = d_rewriter->postRewrite(unionDisjoint5);
+ TS_ASSERT(response5.d_node == unionDisjoint5
+ && response5.d_status == REWRITE_DONE);
}
void testIntersectionMin()
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback