diff options
Diffstat (limited to 'src/theory/bags/inference_generator.cpp')
-rw-r--r-- | src/theory/bags/inference_generator.cpp | 170 |
1 files changed, 89 insertions, 81 deletions
diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp index 734572f7c..e88a7e0ca 100644 --- a/src/theory/bags/inference_generator.cpp +++ b/src/theory/bags/inference_generator.cpp @@ -24,6 +24,8 @@ #include "theory/uf/equality_engine.h" #include "util/rational.h" +using namespace cvc5::kind; + namespace cvc5 { namespace theory { namespace bags { @@ -44,38 +46,49 @@ InferInfo InferenceGenerator::nonNegativeCount(Node n, Node e) Assert(e.getType() == n.getType().getBagElementType()); InferInfo inferInfo(d_im, InferenceId::BAGS_NON_NEGATIVE_COUNT); - Node count = d_nm->mkNode(kind::BAG_COUNT, e, n); + Node count = d_nm->mkNode(BAG_COUNT, e, n); - Node gte = d_nm->mkNode(kind::GEQ, count, d_zero); + Node gte = d_nm->mkNode(GEQ, count, d_zero); inferInfo.d_conclusion = gte; return inferInfo; } InferInfo InferenceGenerator::mkBag(Node n, Node e) { - Assert(n.getKind() == kind::MK_BAG); + Assert(n.getKind() == MK_BAG); Assert(e.getType() == n.getType().getBagElementType()); - if (n[0] == e) + Node x = n[0]; + Node c = n[1]; + Node geq = d_nm->mkNode(GEQ, c, d_one); + if (d_state->areEqual(e, x)) + { + // (= (bag.count e skolem) (ite (>= c 1) c 0))) + InferInfo inferInfo(d_im, InferenceId::BAGS_MK_BAG_SAME_ELEMENT); + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); + Node ite = d_nm->mkNode(ITE, geq, c, d_zero); + inferInfo.d_conclusion = count.eqNode(ite); + return inferInfo; + } + if (d_state->areDisequal(e, x)) { - // TODO issue #78: refactor this with BagRewriter - // (=> true (= (bag.count e (bag e c)) c)) + //(= (bag.count e skolem) 0)) InferInfo inferInfo(d_im, InferenceId::BAGS_MK_BAG_SAME_ELEMENT); Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); - inferInfo.d_conclusion = count.eqNode(n[1]); + inferInfo.d_conclusion = count.eqNode(d_zero); return inferInfo; } else { - // (=> - // true - // (= (bag.count e (bag x c)) (ite (= e x) c 0))) + // (= (bag.count e skolem) (ite (and (= e x) (>= c 1)) c 0))) InferInfo inferInfo(d_im, InferenceId::BAGS_MK_BAG); Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); - Node same = d_nm->mkNode(kind::EQUAL, n[0], e); - Node ite = d_nm->mkNode(kind::ITE, same, n[1], d_zero); + Node same = d_nm->mkNode(EQUAL, e, x); + Node andNode = same.andNode(geq); + Node ite = d_nm->mkNode(ITE, andNode, c, d_zero); Node equal = count.eqNode(ite); inferInfo.d_conclusion = equal; return inferInfo; @@ -110,7 +123,7 @@ typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute; InferInfo InferenceGenerator::bagDisequality(Node n) { - Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag()); + Assert(n.getKind() == EQUAL && n[0].getType().isBag()); Node A = n[0]; Node B = n[1]; @@ -145,7 +158,7 @@ Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo) InferInfo InferenceGenerator::empty(Node n, Node e) { - Assert(n.getKind() == kind::EMPTYBAG); + Assert(n.getKind() == EMPTYBAG); Assert(e.getType() == n.getType().getBagElementType()); InferInfo inferInfo(d_im, InferenceId::BAGS_EMPTY); @@ -159,7 +172,7 @@ InferInfo InferenceGenerator::empty(Node n, Node e) InferInfo InferenceGenerator::unionDisjoint(Node n, Node e) { - Assert(n.getKind() == kind::UNION_DISJOINT && n[0].getType().isBag()); + Assert(n.getKind() == UNION_DISJOINT && n[0].getType().isBag()); Assert(e.getType() == n[0].getType().getBagElementType()); Node A = n[0]; @@ -172,7 +185,7 @@ InferInfo InferenceGenerator::unionDisjoint(Node n, Node e) Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); - Node sum = d_nm->mkNode(kind::PLUS, countA, countB); + Node sum = d_nm->mkNode(PLUS, countA, countB); Node equal = count.eqNode(sum); inferInfo.d_conclusion = equal; @@ -181,7 +194,7 @@ InferInfo InferenceGenerator::unionDisjoint(Node n, Node e) InferInfo InferenceGenerator::unionMax(Node n, Node e) { - Assert(n.getKind() == kind::UNION_MAX && n[0].getType().isBag()); + Assert(n.getKind() == UNION_MAX && n[0].getType().isBag()); Assert(e.getType() == n[0].getType().getBagElementType()); Node A = n[0]; @@ -194,8 +207,8 @@ InferInfo InferenceGenerator::unionMax(Node n, Node e) Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); - Node gt = d_nm->mkNode(kind::GT, countA, countB); - Node max = d_nm->mkNode(kind::ITE, gt, countA, countB); + Node gt = d_nm->mkNode(GT, countA, countB); + Node max = d_nm->mkNode(ITE, gt, countA, countB); Node equal = count.eqNode(max); inferInfo.d_conclusion = equal; @@ -204,7 +217,7 @@ InferInfo InferenceGenerator::unionMax(Node n, Node e) InferInfo InferenceGenerator::intersection(Node n, Node e) { - Assert(n.getKind() == kind::INTERSECTION_MIN && n[0].getType().isBag()); + Assert(n.getKind() == INTERSECTION_MIN && n[0].getType().isBag()); Assert(e.getType() == n[0].getType().getBagElementType()); Node A = n[0]; @@ -216,8 +229,8 @@ InferInfo InferenceGenerator::intersection(Node n, Node e) Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); - Node lt = d_nm->mkNode(kind::LT, countA, countB); - Node min = d_nm->mkNode(kind::ITE, lt, countA, countB); + Node lt = d_nm->mkNode(LT, countA, countB); + Node min = d_nm->mkNode(ITE, lt, countA, countB); Node equal = count.eqNode(min); inferInfo.d_conclusion = equal; return inferInfo; @@ -225,7 +238,7 @@ InferInfo InferenceGenerator::intersection(Node n, Node e) InferInfo InferenceGenerator::differenceSubtract(Node n, Node e) { - Assert(n.getKind() == kind::DIFFERENCE_SUBTRACT && n[0].getType().isBag()); + Assert(n.getKind() == DIFFERENCE_SUBTRACT && n[0].getType().isBag()); Assert(e.getType() == n[0].getType().getBagElementType()); Node A = n[0]; @@ -237,9 +250,9 @@ InferInfo InferenceGenerator::differenceSubtract(Node n, Node e) Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); - Node subtract = d_nm->mkNode(kind::MINUS, countA, countB); - Node gte = d_nm->mkNode(kind::GEQ, countA, countB); - Node difference = d_nm->mkNode(kind::ITE, gte, subtract, d_zero); + Node subtract = d_nm->mkNode(MINUS, countA, countB); + Node gte = d_nm->mkNode(GEQ, countA, countB); + Node difference = d_nm->mkNode(ITE, gte, subtract, d_zero); Node equal = count.eqNode(difference); inferInfo.d_conclusion = equal; return inferInfo; @@ -247,7 +260,7 @@ InferInfo InferenceGenerator::differenceSubtract(Node n, Node e) InferInfo InferenceGenerator::differenceRemove(Node n, Node e) { - Assert(n.getKind() == kind::DIFFERENCE_REMOVE && n[0].getType().isBag()); + Assert(n.getKind() == DIFFERENCE_REMOVE && n[0].getType().isBag()); Assert(e.getType() == n[0].getType().getBagElementType()); Node A = n[0]; @@ -260,8 +273,8 @@ InferInfo InferenceGenerator::differenceRemove(Node n, Node e) Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); - Node notInB = d_nm->mkNode(kind::EQUAL, countB, d_zero); - Node difference = d_nm->mkNode(kind::ITE, notInB, countA, d_zero); + Node notInB = d_nm->mkNode(LEQ, countB, d_zero); + Node difference = d_nm->mkNode(ITE, notInB, countA, d_zero); Node equal = count.eqNode(difference); inferInfo.d_conclusion = equal; return inferInfo; @@ -269,7 +282,7 @@ InferInfo InferenceGenerator::differenceRemove(Node n, Node e) InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e) { - Assert(n.getKind() == kind::DUPLICATE_REMOVAL && n[0].getType().isBag()); + Assert(n.getKind() == DUPLICATE_REMOVAL && n[0].getType().isBag()); Assert(e.getType() == n[0].getType().getBagElementType()); Node A = n[0]; @@ -279,8 +292,8 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e) Node skolem = getSkolem(n, inferInfo); Node count = getMultiplicityTerm(e, skolem); - Node gte = d_nm->mkNode(kind::GEQ, countA, d_one); - Node ite = d_nm->mkNode(kind::ITE, gte, d_one, d_zero); + Node gte = d_nm->mkNode(GEQ, countA, d_one); + Node ite = d_nm->mkNode(ITE, gte, d_one, d_zero); Node equal = count.eqNode(ite); inferInfo.d_conclusion = equal; return inferInfo; @@ -288,14 +301,14 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e) Node InferenceGenerator::getMultiplicityTerm(Node element, Node bag) { - Node count = d_nm->mkNode(kind::BAG_COUNT, element, bag); + Node count = d_nm->mkNode(BAG_COUNT, element, bag); return count; } std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n, Node e) { - Assert(n.getKind() == kind::BAG_MAP && n[1].getType().isBag()); + Assert(n.getKind() == BAG_MAP && n[1].getType().isBag()); Assert(n[0].getType().isFunction() && n[0].getType().getArgTypes().size() == 1); Assert(e.getType() == n[0].getType().getRangeType()); @@ -316,8 +329,8 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n, Node sum = d_sm->mkSkolemFunction(SkolemFunId::BAGS_MAP_SUM, sumType, {n, e}); // (= (sum 0) 0) - Node sum_zero = d_nm->mkNode(kind::APPLY_UF, sum, d_zero); - Node baseCase = d_nm->mkNode(Kind::EQUAL, sum_zero, d_zero); + Node sum_zero = d_nm->mkNode(APPLY_UF, sum, d_zero); + Node baseCase = d_nm->mkNode(EQUAL, sum_zero, d_zero); // guess the size of the preimage of e Node preImageSize = d_sm->mkDummySkolem("preImageSize", d_nm->integerType()); @@ -325,8 +338,8 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n, // (= (sum preImageSize) (bag.count e skolem)) Node mapSkolem = getSkolem(n, inferInfo); Node countE = getMultiplicityTerm(e, mapSkolem); - Node totalSum = d_nm->mkNode(kind::APPLY_UF, sum, preImageSize); - Node totalSumEqualCountE = d_nm->mkNode(kind::EQUAL, totalSum, countE); + Node totalSum = d_nm->mkNode(APPLY_UF, sum, preImageSize); + Node totalSumEqualCountE = d_nm->mkNode(EQUAL, totalSum, countE); // (forall ((i Int)) // (let ((uf_i (uf i))) @@ -347,44 +360,42 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n, Node i = bvm->mkBoundVar<FirstIndexVarAttribute>(n, "i", d_nm->integerType()); Node j = bvm->mkBoundVar<SecondIndexVarAttribute>(n, "j", d_nm->integerType()); - Node iList = d_nm->mkNode(kind::BOUND_VAR_LIST, i); - Node jList = d_nm->mkNode(kind::BOUND_VAR_LIST, j); - Node iPlusOne = d_nm->mkNode(kind::PLUS, i, d_one); - Node iMinusOne = d_nm->mkNode(kind::MINUS, i, d_one); - Node uf_i = d_nm->mkNode(kind::APPLY_UF, uf, i); - Node uf_j = d_nm->mkNode(kind::APPLY_UF, uf, j); - Node f_uf_i = d_nm->mkNode(kind::APPLY_UF, f, uf_i); - Node uf_iPlusOne = d_nm->mkNode(kind::APPLY_UF, uf, iPlusOne); - Node uf_iMinusOne = d_nm->mkNode(kind::APPLY_UF, uf, iMinusOne); - Node interval_i = d_nm->mkNode(kind::AND, - d_nm->mkNode(kind::GEQ, i, d_one), - d_nm->mkNode(kind::LEQ, i, preImageSize)); - Node sum_i = d_nm->mkNode(kind::APPLY_UF, sum, i); - Node sum_iPlusOne = d_nm->mkNode(kind::APPLY_UF, sum, iPlusOne); - Node sum_iMinusOne = d_nm->mkNode(kind::APPLY_UF, sum, iMinusOne); - Node count_iMinusOne = d_nm->mkNode(kind::BAG_COUNT, uf_iMinusOne, A); - Node count_uf_i = d_nm->mkNode(kind::BAG_COUNT, uf_i, A); - Node inductiveCase = d_nm->mkNode( - Kind::EQUAL, sum_i, d_nm->mkNode(kind::PLUS, sum_iMinusOne, count_uf_i)); - Node f_iEqualE = d_nm->mkNode(kind::EQUAL, f_uf_i, e); - Node geqOne = d_nm->mkNode(kind::GEQ, count_uf_i, d_one); + Node iList = d_nm->mkNode(BOUND_VAR_LIST, i); + Node jList = d_nm->mkNode(BOUND_VAR_LIST, j); + Node iPlusOne = d_nm->mkNode(PLUS, i, d_one); + Node iMinusOne = d_nm->mkNode(MINUS, i, d_one); + Node uf_i = d_nm->mkNode(APPLY_UF, uf, i); + Node uf_j = d_nm->mkNode(APPLY_UF, uf, j); + Node f_uf_i = d_nm->mkNode(APPLY_UF, f, uf_i); + Node uf_iPlusOne = d_nm->mkNode(APPLY_UF, uf, iPlusOne); + Node uf_iMinusOne = d_nm->mkNode(APPLY_UF, uf, iMinusOne); + Node interval_i = d_nm->mkNode( + AND, d_nm->mkNode(GEQ, i, d_one), d_nm->mkNode(LEQ, i, preImageSize)); + Node sum_i = d_nm->mkNode(APPLY_UF, sum, i); + Node sum_iPlusOne = d_nm->mkNode(APPLY_UF, sum, iPlusOne); + Node sum_iMinusOne = d_nm->mkNode(APPLY_UF, sum, iMinusOne); + Node count_iMinusOne = d_nm->mkNode(BAG_COUNT, uf_iMinusOne, A); + Node count_uf_i = d_nm->mkNode(BAG_COUNT, uf_i, A); + Node inductiveCase = + d_nm->mkNode(EQUAL, sum_i, d_nm->mkNode(PLUS, sum_iMinusOne, count_uf_i)); + Node f_iEqualE = d_nm->mkNode(EQUAL, f_uf_i, e); + Node geqOne = d_nm->mkNode(GEQ, count_uf_i, d_one); // i < j <= preImageSize - Node interval_j = d_nm->mkNode(kind::AND, - d_nm->mkNode(kind::LT, i, j), - d_nm->mkNode(kind::LEQ, j, preImageSize)); + Node interval_j = d_nm->mkNode( + AND, d_nm->mkNode(LT, i, j), d_nm->mkNode(LEQ, j, preImageSize)); // uf(i) != uf(j) - Node uf_i_equals_uf_j = d_nm->mkNode(kind::EQUAL, uf_i, uf_j); - Node notEqual = d_nm->mkNode(kind::EQUAL, uf_i, uf_j).negate(); - Node body_j = d_nm->mkNode(kind::OR, interval_j.negate(), notEqual); + Node uf_i_equals_uf_j = d_nm->mkNode(EQUAL, uf_i, uf_j); + Node notEqual = d_nm->mkNode(EQUAL, uf_i, uf_j).negate(); + Node body_j = d_nm->mkNode(OR, interval_j.negate(), notEqual); Node forAll_j = quantifiers::BoundedIntegers::mkBoundedForall(jList, body_j); Node andNode = - d_nm->mkNode(kind::AND, {f_iEqualE, geqOne, inductiveCase, forAll_j}); - Node body_i = d_nm->mkNode(kind::OR, interval_i.negate(), andNode); + d_nm->mkNode(AND, {f_iEqualE, geqOne, inductiveCase, forAll_j}); + Node body_i = d_nm->mkNode(OR, interval_i.negate(), andNode); Node forAll_i = quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i); - Node preImageGTE_zero = d_nm->mkNode(kind::GEQ, preImageSize, d_zero); + Node preImageGTE_zero = d_nm->mkNode(GEQ, preImageSize, d_zero); Node conclusion = d_nm->mkNode( - kind::AND, {baseCase, totalSumEqualCountE, forAll_i, preImageGTE_zero}); + AND, {baseCase, totalSumEqualCountE, forAll_i, preImageGTE_zero}); inferInfo.d_conclusion = conclusion; std::map<Node, Node> m; @@ -395,7 +406,7 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n, InferInfo InferenceGenerator::mapUpwards( Node n, Node uf, Node preImageSize, Node y, Node x) { - Assert(n.getKind() == kind::BAG_MAP && n[1].getType().isBag()); + Assert(n.getKind() == BAG_MAP && n[1].getType().isBag()); Assert(n[0].getType().isFunction() && n[0].getType().getArgTypes().size() == 1); @@ -404,19 +415,16 @@ InferInfo InferenceGenerator::mapUpwards( Node A = n[1]; Node countA = getMultiplicityTerm(x, A); - Node xInA = d_nm->mkNode(kind::GEQ, countA, d_one); - Node notEqual = - d_nm->mkNode(kind::EQUAL, d_nm->mkNode(kind::APPLY_UF, f, x), y).negate(); + Node xInA = d_nm->mkNode(GEQ, countA, d_one); + Node notEqual = d_nm->mkNode(EQUAL, d_nm->mkNode(APPLY_UF, f, x), y).negate(); Node k = d_sm->mkDummySkolem("k", d_nm->integerType()); - Node inRange = d_nm->mkNode(kind::AND, - d_nm->mkNode(kind::GEQ, k, d_one), - d_nm->mkNode(kind::LEQ, k, preImageSize)); - Node equal = - d_nm->mkNode(kind::EQUAL, d_nm->mkNode(kind::APPLY_UF, uf, k), x); - Node andNode = d_nm->mkNode(kind::AND, inRange, equal); - Node orNode = d_nm->mkNode(kind::OR, notEqual, andNode); - Node implies = d_nm->mkNode(kind::IMPLIES, xInA, orNode); + Node inRange = d_nm->mkNode( + AND, d_nm->mkNode(GEQ, k, d_one), d_nm->mkNode(LEQ, k, preImageSize)); + Node equal = d_nm->mkNode(EQUAL, d_nm->mkNode(APPLY_UF, uf, k), x); + Node andNode = d_nm->mkNode(AND, inRange, equal); + Node orNode = d_nm->mkNode(OR, notEqual, andNode); + Node implies = d_nm->mkNode(IMPLIES, xInA, orNode); inferInfo.d_conclusion = implies; std::cout << "Upwards conclusion: " << inferInfo.d_conclusion << std::endl << std::endl; |