summaryrefslogtreecommitdiff
path: root/src/theory/bags/inference_generator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/bags/inference_generator.cpp')
-rw-r--r--src/theory/bags/inference_generator.cpp170
1 files changed, 89 insertions, 81 deletions
diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp
index 734572f7c..e88a7e0ca 100644
--- a/src/theory/bags/inference_generator.cpp
+++ b/src/theory/bags/inference_generator.cpp
@@ -24,6 +24,8 @@
#include "theory/uf/equality_engine.h"
#include "util/rational.h"
+using namespace cvc5::kind;
+
namespace cvc5 {
namespace theory {
namespace bags {
@@ -44,38 +46,49 @@ InferInfo InferenceGenerator::nonNegativeCount(Node n, Node e)
Assert(e.getType() == n.getType().getBagElementType());
InferInfo inferInfo(d_im, InferenceId::BAGS_NON_NEGATIVE_COUNT);
- Node count = d_nm->mkNode(kind::BAG_COUNT, e, n);
+ Node count = d_nm->mkNode(BAG_COUNT, e, n);
- Node gte = d_nm->mkNode(kind::GEQ, count, d_zero);
+ Node gte = d_nm->mkNode(GEQ, count, d_zero);
inferInfo.d_conclusion = gte;
return inferInfo;
}
InferInfo InferenceGenerator::mkBag(Node n, Node e)
{
- Assert(n.getKind() == kind::MK_BAG);
+ Assert(n.getKind() == MK_BAG);
Assert(e.getType() == n.getType().getBagElementType());
- if (n[0] == e)
+ Node x = n[0];
+ Node c = n[1];
+ Node geq = d_nm->mkNode(GEQ, c, d_one);
+ if (d_state->areEqual(e, x))
+ {
+ // (= (bag.count e skolem) (ite (>= c 1) c 0)))
+ InferInfo inferInfo(d_im, InferenceId::BAGS_MK_BAG_SAME_ELEMENT);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
+ Node ite = d_nm->mkNode(ITE, geq, c, d_zero);
+ inferInfo.d_conclusion = count.eqNode(ite);
+ return inferInfo;
+ }
+ if (d_state->areDisequal(e, x))
{
- // TODO issue #78: refactor this with BagRewriter
- // (=> true (= (bag.count e (bag e c)) c))
+ //(= (bag.count e skolem) 0))
InferInfo inferInfo(d_im, InferenceId::BAGS_MK_BAG_SAME_ELEMENT);
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
- inferInfo.d_conclusion = count.eqNode(n[1]);
+ inferInfo.d_conclusion = count.eqNode(d_zero);
return inferInfo;
}
else
{
- // (=>
- // true
- // (= (bag.count e (bag x c)) (ite (= e x) c 0)))
+ // (= (bag.count e skolem) (ite (and (= e x) (>= c 1)) c 0)))
InferInfo inferInfo(d_im, InferenceId::BAGS_MK_BAG);
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
- Node same = d_nm->mkNode(kind::EQUAL, n[0], e);
- Node ite = d_nm->mkNode(kind::ITE, same, n[1], d_zero);
+ Node same = d_nm->mkNode(EQUAL, e, x);
+ Node andNode = same.andNode(geq);
+ Node ite = d_nm->mkNode(ITE, andNode, c, d_zero);
Node equal = count.eqNode(ite);
inferInfo.d_conclusion = equal;
return inferInfo;
@@ -110,7 +123,7 @@ typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
InferInfo InferenceGenerator::bagDisequality(Node n)
{
- Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag());
+ Assert(n.getKind() == EQUAL && n[0].getType().isBag());
Node A = n[0];
Node B = n[1];
@@ -145,7 +158,7 @@ Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo)
InferInfo InferenceGenerator::empty(Node n, Node e)
{
- Assert(n.getKind() == kind::EMPTYBAG);
+ Assert(n.getKind() == EMPTYBAG);
Assert(e.getType() == n.getType().getBagElementType());
InferInfo inferInfo(d_im, InferenceId::BAGS_EMPTY);
@@ -159,7 +172,7 @@ InferInfo InferenceGenerator::empty(Node n, Node e)
InferInfo InferenceGenerator::unionDisjoint(Node n, Node e)
{
- Assert(n.getKind() == kind::UNION_DISJOINT && n[0].getType().isBag());
+ Assert(n.getKind() == UNION_DISJOINT && n[0].getType().isBag());
Assert(e.getType() == n[0].getType().getBagElementType());
Node A = n[0];
@@ -172,7 +185,7 @@ InferInfo InferenceGenerator::unionDisjoint(Node n, Node e)
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
- Node sum = d_nm->mkNode(kind::PLUS, countA, countB);
+ Node sum = d_nm->mkNode(PLUS, countA, countB);
Node equal = count.eqNode(sum);
inferInfo.d_conclusion = equal;
@@ -181,7 +194,7 @@ InferInfo InferenceGenerator::unionDisjoint(Node n, Node e)
InferInfo InferenceGenerator::unionMax(Node n, Node e)
{
- Assert(n.getKind() == kind::UNION_MAX && n[0].getType().isBag());
+ Assert(n.getKind() == UNION_MAX && n[0].getType().isBag());
Assert(e.getType() == n[0].getType().getBagElementType());
Node A = n[0];
@@ -194,8 +207,8 @@ InferInfo InferenceGenerator::unionMax(Node n, Node e)
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
- Node gt = d_nm->mkNode(kind::GT, countA, countB);
- Node max = d_nm->mkNode(kind::ITE, gt, countA, countB);
+ Node gt = d_nm->mkNode(GT, countA, countB);
+ Node max = d_nm->mkNode(ITE, gt, countA, countB);
Node equal = count.eqNode(max);
inferInfo.d_conclusion = equal;
@@ -204,7 +217,7 @@ InferInfo InferenceGenerator::unionMax(Node n, Node e)
InferInfo InferenceGenerator::intersection(Node n, Node e)
{
- Assert(n.getKind() == kind::INTERSECTION_MIN && n[0].getType().isBag());
+ Assert(n.getKind() == INTERSECTION_MIN && n[0].getType().isBag());
Assert(e.getType() == n[0].getType().getBagElementType());
Node A = n[0];
@@ -216,8 +229,8 @@ InferInfo InferenceGenerator::intersection(Node n, Node e)
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
- Node lt = d_nm->mkNode(kind::LT, countA, countB);
- Node min = d_nm->mkNode(kind::ITE, lt, countA, countB);
+ Node lt = d_nm->mkNode(LT, countA, countB);
+ Node min = d_nm->mkNode(ITE, lt, countA, countB);
Node equal = count.eqNode(min);
inferInfo.d_conclusion = equal;
return inferInfo;
@@ -225,7 +238,7 @@ InferInfo InferenceGenerator::intersection(Node n, Node e)
InferInfo InferenceGenerator::differenceSubtract(Node n, Node e)
{
- Assert(n.getKind() == kind::DIFFERENCE_SUBTRACT && n[0].getType().isBag());
+ Assert(n.getKind() == DIFFERENCE_SUBTRACT && n[0].getType().isBag());
Assert(e.getType() == n[0].getType().getBagElementType());
Node A = n[0];
@@ -237,9 +250,9 @@ InferInfo InferenceGenerator::differenceSubtract(Node n, Node e)
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
- Node subtract = d_nm->mkNode(kind::MINUS, countA, countB);
- Node gte = d_nm->mkNode(kind::GEQ, countA, countB);
- Node difference = d_nm->mkNode(kind::ITE, gte, subtract, d_zero);
+ Node subtract = d_nm->mkNode(MINUS, countA, countB);
+ Node gte = d_nm->mkNode(GEQ, countA, countB);
+ Node difference = d_nm->mkNode(ITE, gte, subtract, d_zero);
Node equal = count.eqNode(difference);
inferInfo.d_conclusion = equal;
return inferInfo;
@@ -247,7 +260,7 @@ InferInfo InferenceGenerator::differenceSubtract(Node n, Node e)
InferInfo InferenceGenerator::differenceRemove(Node n, Node e)
{
- Assert(n.getKind() == kind::DIFFERENCE_REMOVE && n[0].getType().isBag());
+ Assert(n.getKind() == DIFFERENCE_REMOVE && n[0].getType().isBag());
Assert(e.getType() == n[0].getType().getBagElementType());
Node A = n[0];
@@ -260,8 +273,8 @@ InferInfo InferenceGenerator::differenceRemove(Node n, Node e)
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
- Node notInB = d_nm->mkNode(kind::EQUAL, countB, d_zero);
- Node difference = d_nm->mkNode(kind::ITE, notInB, countA, d_zero);
+ Node notInB = d_nm->mkNode(LEQ, countB, d_zero);
+ Node difference = d_nm->mkNode(ITE, notInB, countA, d_zero);
Node equal = count.eqNode(difference);
inferInfo.d_conclusion = equal;
return inferInfo;
@@ -269,7 +282,7 @@ InferInfo InferenceGenerator::differenceRemove(Node n, Node e)
InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e)
{
- Assert(n.getKind() == kind::DUPLICATE_REMOVAL && n[0].getType().isBag());
+ Assert(n.getKind() == DUPLICATE_REMOVAL && n[0].getType().isBag());
Assert(e.getType() == n[0].getType().getBagElementType());
Node A = n[0];
@@ -279,8 +292,8 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e)
Node skolem = getSkolem(n, inferInfo);
Node count = getMultiplicityTerm(e, skolem);
- Node gte = d_nm->mkNode(kind::GEQ, countA, d_one);
- Node ite = d_nm->mkNode(kind::ITE, gte, d_one, d_zero);
+ Node gte = d_nm->mkNode(GEQ, countA, d_one);
+ Node ite = d_nm->mkNode(ITE, gte, d_one, d_zero);
Node equal = count.eqNode(ite);
inferInfo.d_conclusion = equal;
return inferInfo;
@@ -288,14 +301,14 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e)
Node InferenceGenerator::getMultiplicityTerm(Node element, Node bag)
{
- Node count = d_nm->mkNode(kind::BAG_COUNT, element, bag);
+ Node count = d_nm->mkNode(BAG_COUNT, element, bag);
return count;
}
std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n,
Node e)
{
- Assert(n.getKind() == kind::BAG_MAP && n[1].getType().isBag());
+ Assert(n.getKind() == BAG_MAP && n[1].getType().isBag());
Assert(n[0].getType().isFunction()
&& n[0].getType().getArgTypes().size() == 1);
Assert(e.getType() == n[0].getType().getRangeType());
@@ -316,8 +329,8 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n,
Node sum = d_sm->mkSkolemFunction(SkolemFunId::BAGS_MAP_SUM, sumType, {n, e});
// (= (sum 0) 0)
- Node sum_zero = d_nm->mkNode(kind::APPLY_UF, sum, d_zero);
- Node baseCase = d_nm->mkNode(Kind::EQUAL, sum_zero, d_zero);
+ Node sum_zero = d_nm->mkNode(APPLY_UF, sum, d_zero);
+ Node baseCase = d_nm->mkNode(EQUAL, sum_zero, d_zero);
// guess the size of the preimage of e
Node preImageSize = d_sm->mkDummySkolem("preImageSize", d_nm->integerType());
@@ -325,8 +338,8 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n,
// (= (sum preImageSize) (bag.count e skolem))
Node mapSkolem = getSkolem(n, inferInfo);
Node countE = getMultiplicityTerm(e, mapSkolem);
- Node totalSum = d_nm->mkNode(kind::APPLY_UF, sum, preImageSize);
- Node totalSumEqualCountE = d_nm->mkNode(kind::EQUAL, totalSum, countE);
+ Node totalSum = d_nm->mkNode(APPLY_UF, sum, preImageSize);
+ Node totalSumEqualCountE = d_nm->mkNode(EQUAL, totalSum, countE);
// (forall ((i Int))
// (let ((uf_i (uf i)))
@@ -347,44 +360,42 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n,
Node i = bvm->mkBoundVar<FirstIndexVarAttribute>(n, "i", d_nm->integerType());
Node j =
bvm->mkBoundVar<SecondIndexVarAttribute>(n, "j", d_nm->integerType());
- Node iList = d_nm->mkNode(kind::BOUND_VAR_LIST, i);
- Node jList = d_nm->mkNode(kind::BOUND_VAR_LIST, j);
- Node iPlusOne = d_nm->mkNode(kind::PLUS, i, d_one);
- Node iMinusOne = d_nm->mkNode(kind::MINUS, i, d_one);
- Node uf_i = d_nm->mkNode(kind::APPLY_UF, uf, i);
- Node uf_j = d_nm->mkNode(kind::APPLY_UF, uf, j);
- Node f_uf_i = d_nm->mkNode(kind::APPLY_UF, f, uf_i);
- Node uf_iPlusOne = d_nm->mkNode(kind::APPLY_UF, uf, iPlusOne);
- Node uf_iMinusOne = d_nm->mkNode(kind::APPLY_UF, uf, iMinusOne);
- Node interval_i = d_nm->mkNode(kind::AND,
- d_nm->mkNode(kind::GEQ, i, d_one),
- d_nm->mkNode(kind::LEQ, i, preImageSize));
- Node sum_i = d_nm->mkNode(kind::APPLY_UF, sum, i);
- Node sum_iPlusOne = d_nm->mkNode(kind::APPLY_UF, sum, iPlusOne);
- Node sum_iMinusOne = d_nm->mkNode(kind::APPLY_UF, sum, iMinusOne);
- Node count_iMinusOne = d_nm->mkNode(kind::BAG_COUNT, uf_iMinusOne, A);
- Node count_uf_i = d_nm->mkNode(kind::BAG_COUNT, uf_i, A);
- Node inductiveCase = d_nm->mkNode(
- Kind::EQUAL, sum_i, d_nm->mkNode(kind::PLUS, sum_iMinusOne, count_uf_i));
- Node f_iEqualE = d_nm->mkNode(kind::EQUAL, f_uf_i, e);
- Node geqOne = d_nm->mkNode(kind::GEQ, count_uf_i, d_one);
+ Node iList = d_nm->mkNode(BOUND_VAR_LIST, i);
+ Node jList = d_nm->mkNode(BOUND_VAR_LIST, j);
+ Node iPlusOne = d_nm->mkNode(PLUS, i, d_one);
+ Node iMinusOne = d_nm->mkNode(MINUS, i, d_one);
+ Node uf_i = d_nm->mkNode(APPLY_UF, uf, i);
+ Node uf_j = d_nm->mkNode(APPLY_UF, uf, j);
+ Node f_uf_i = d_nm->mkNode(APPLY_UF, f, uf_i);
+ Node uf_iPlusOne = d_nm->mkNode(APPLY_UF, uf, iPlusOne);
+ Node uf_iMinusOne = d_nm->mkNode(APPLY_UF, uf, iMinusOne);
+ Node interval_i = d_nm->mkNode(
+ AND, d_nm->mkNode(GEQ, i, d_one), d_nm->mkNode(LEQ, i, preImageSize));
+ Node sum_i = d_nm->mkNode(APPLY_UF, sum, i);
+ Node sum_iPlusOne = d_nm->mkNode(APPLY_UF, sum, iPlusOne);
+ Node sum_iMinusOne = d_nm->mkNode(APPLY_UF, sum, iMinusOne);
+ Node count_iMinusOne = d_nm->mkNode(BAG_COUNT, uf_iMinusOne, A);
+ Node count_uf_i = d_nm->mkNode(BAG_COUNT, uf_i, A);
+ Node inductiveCase =
+ d_nm->mkNode(EQUAL, sum_i, d_nm->mkNode(PLUS, sum_iMinusOne, count_uf_i));
+ Node f_iEqualE = d_nm->mkNode(EQUAL, f_uf_i, e);
+ Node geqOne = d_nm->mkNode(GEQ, count_uf_i, d_one);
// i < j <= preImageSize
- Node interval_j = d_nm->mkNode(kind::AND,
- d_nm->mkNode(kind::LT, i, j),
- d_nm->mkNode(kind::LEQ, j, preImageSize));
+ Node interval_j = d_nm->mkNode(
+ AND, d_nm->mkNode(LT, i, j), d_nm->mkNode(LEQ, j, preImageSize));
// uf(i) != uf(j)
- Node uf_i_equals_uf_j = d_nm->mkNode(kind::EQUAL, uf_i, uf_j);
- Node notEqual = d_nm->mkNode(kind::EQUAL, uf_i, uf_j).negate();
- Node body_j = d_nm->mkNode(kind::OR, interval_j.negate(), notEqual);
+ Node uf_i_equals_uf_j = d_nm->mkNode(EQUAL, uf_i, uf_j);
+ Node notEqual = d_nm->mkNode(EQUAL, uf_i, uf_j).negate();
+ Node body_j = d_nm->mkNode(OR, interval_j.negate(), notEqual);
Node forAll_j = quantifiers::BoundedIntegers::mkBoundedForall(jList, body_j);
Node andNode =
- d_nm->mkNode(kind::AND, {f_iEqualE, geqOne, inductiveCase, forAll_j});
- Node body_i = d_nm->mkNode(kind::OR, interval_i.negate(), andNode);
+ d_nm->mkNode(AND, {f_iEqualE, geqOne, inductiveCase, forAll_j});
+ Node body_i = d_nm->mkNode(OR, interval_i.negate(), andNode);
Node forAll_i = quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i);
- Node preImageGTE_zero = d_nm->mkNode(kind::GEQ, preImageSize, d_zero);
+ Node preImageGTE_zero = d_nm->mkNode(GEQ, preImageSize, d_zero);
Node conclusion = d_nm->mkNode(
- kind::AND, {baseCase, totalSumEqualCountE, forAll_i, preImageGTE_zero});
+ AND, {baseCase, totalSumEqualCountE, forAll_i, preImageGTE_zero});
inferInfo.d_conclusion = conclusion;
std::map<Node, Node> m;
@@ -395,7 +406,7 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n,
InferInfo InferenceGenerator::mapUpwards(
Node n, Node uf, Node preImageSize, Node y, Node x)
{
- Assert(n.getKind() == kind::BAG_MAP && n[1].getType().isBag());
+ Assert(n.getKind() == BAG_MAP && n[1].getType().isBag());
Assert(n[0].getType().isFunction()
&& n[0].getType().getArgTypes().size() == 1);
@@ -404,19 +415,16 @@ InferInfo InferenceGenerator::mapUpwards(
Node A = n[1];
Node countA = getMultiplicityTerm(x, A);
- Node xInA = d_nm->mkNode(kind::GEQ, countA, d_one);
- Node notEqual =
- d_nm->mkNode(kind::EQUAL, d_nm->mkNode(kind::APPLY_UF, f, x), y).negate();
+ Node xInA = d_nm->mkNode(GEQ, countA, d_one);
+ Node notEqual = d_nm->mkNode(EQUAL, d_nm->mkNode(APPLY_UF, f, x), y).negate();
Node k = d_sm->mkDummySkolem("k", d_nm->integerType());
- Node inRange = d_nm->mkNode(kind::AND,
- d_nm->mkNode(kind::GEQ, k, d_one),
- d_nm->mkNode(kind::LEQ, k, preImageSize));
- Node equal =
- d_nm->mkNode(kind::EQUAL, d_nm->mkNode(kind::APPLY_UF, uf, k), x);
- Node andNode = d_nm->mkNode(kind::AND, inRange, equal);
- Node orNode = d_nm->mkNode(kind::OR, notEqual, andNode);
- Node implies = d_nm->mkNode(kind::IMPLIES, xInA, orNode);
+ Node inRange = d_nm->mkNode(
+ AND, d_nm->mkNode(GEQ, k, d_one), d_nm->mkNode(LEQ, k, preImageSize));
+ Node equal = d_nm->mkNode(EQUAL, d_nm->mkNode(APPLY_UF, uf, k), x);
+ Node andNode = d_nm->mkNode(AND, inRange, equal);
+ Node orNode = d_nm->mkNode(OR, notEqual, andNode);
+ Node implies = d_nm->mkNode(IMPLIES, xInA, orNode);
inferInfo.d_conclusion = implies;
std::cout << "Upwards conclusion: " << inferInfo.d_conclusion << std::endl
<< std::endl;
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback