diff options
author | mudathirmahgoub <mudathirmahgoub@gmail.com> | 2021-02-01 08:42:39 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-01 08:42:39 -0600 |
commit | c0937f742479d8a5054e42597da9447d55e876c0 (patch) | |
tree | 386fa3a7a46be381ba61ce132a538fe811cf16f9 | |
parent | c6a71382bad8bae30b8278055995279c36433811 (diff) |
Fix BagsRewriter::rewriteUnionDisjoint (#5840)
This PR fixes the implementation of (union_disjoint (union_max A B) (intersection_min A B)) =(union_disjoint A B).
It also skips processed bags during model building.
-rw-r--r-- | src/theory/bags/bags_rewriter.cpp | 2 | ||||
-rw-r--r-- | src/theory/bags/theory_bags.cpp | 11 | ||||
-rw-r--r-- | test/unit/theory/theory_bags_rewriter_white.h | 14 |
3 files changed, 25 insertions, 2 deletions
diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index 66886bfbf..9f53c29ca 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -246,7 +246,7 @@ BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b) // check if the operands of union_max and intersection_min are the same std::set<Node> left(n[0].begin(), n[0].end()); - std::set<Node> right(n[0].begin(), n[0].end()); + std::set<Node> right(n[1].begin(), n[1].end()); if (left == right) { Node rewritten = d_nm->mkNode(UNION_DISJOINT, n[0][0], n[0][1]); diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 15e8e00e7..6df44295e 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -144,15 +144,26 @@ bool TheoryBags::collectModelValues(TheoryModel* m, Trace("bags-model") << "Term set: " << termSet << std::endl; + std::set<Node> processedBags; + // get the relevant bag equivalence classes for (const Node& n : termSet) { TypeNode tn = n.getType(); if (!tn.isBag()) { + // we are only concerned here about bag terms continue; } Node r = d_state.getRepresentative(n); + if (processedBags.find(r) != processedBags.end()) + { + // skip bags whose representatives are already processed + continue; + } + + processedBags.insert(r); + std::set<Node> solverElements = d_state.getElements(r); std::set<Node> elements; // only consider terms in termSet and ignore other elements in the solver diff --git a/test/unit/theory/theory_bags_rewriter_white.h b/test/unit/theory/theory_bags_rewriter_white.h index 98e3cf887..10a624238 100644 --- a/test/unit/theory/theory_bags_rewriter_white.h +++ b/test/unit/theory/theory_bags_rewriter_white.h @@ -280,16 +280,20 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite void testUnionDisjoint() { int n = 3; - vector<Node> elements = getNStrings(2); + vector<Node> elements = getNStrings(3); Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); Node A = d_nm->mkBag( d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n))); Node B = d_nm->mkBag( d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1))); + Node C = d_nm->mkBag( + d_nm->stringType(), elements[2], d_nm->mkConst(Rational(n + 2))); + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B); + Node unionMaxAC = d_nm->mkNode(UNION_MAX, A, C); Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A); Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B); Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A); @@ -321,6 +325,14 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4); TS_ASSERT(response4.d_node == unionDisjointBA && response4.d_status == REWRITE_AGAIN_FULL); + + // (union_disjoint (intersection_min B A)) (union_max A B) = + // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b) + Node unionDisjoint5 = + d_nm->mkNode(UNION_DISJOINT, unionMaxAC, intersectionAB); + RewriteResponse response5 = d_rewriter->postRewrite(unionDisjoint5); + TS_ASSERT(response5.d_node == unionDisjoint5 + && response5.d_status == REWRITE_DONE); } void testIntersectionMin() |