summaryrefslogtreecommitdiff
path: root/src/theory/bags/normal_form.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/bags/normal_form.cpp')
-rw-r--r--src/theory/bags/normal_form.cpp180
1 files changed, 92 insertions, 88 deletions
diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp
index 59344cf0b..12bf513b5 100644
--- a/src/theory/bags/normal_form.cpp
+++ b/src/theory/bags/normal_form.cpp
@@ -28,19 +28,19 @@ namespace bags {
bool NormalForm::isConstant(TNode n)
{
- if (n.getKind() == EMPTYBAG)
+ if (n.getKind() == BAG_EMPTY)
{
// empty bags are already normalized
return true;
}
- if (n.getKind() == MK_BAG)
+ if (n.getKind() == BAG_MAKE)
{
// see the implementation in MkBagTypeRule::computeIsConst
return n.isConst();
}
- if (n.getKind() == UNION_DISJOINT)
+ if (n.getKind() == BAG_UNION_DISJOINT)
{
- if (!(n[0].getKind() == kind::MK_BAG && n[0].isConst()))
+ if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst()))
{
// the first child is not a constant
return false;
@@ -48,9 +48,9 @@ bool NormalForm::isConstant(TNode n)
// store the previous element to check the ordering of elements
Node previousElement = n[0][0];
Node current = n[1];
- while (current.getKind() == UNION_DISJOINT)
+ while (current.getKind() == BAG_UNION_DISJOINT)
{
- if (!(current[0].getKind() == kind::MK_BAG && current[0].isConst()))
+ if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst()))
{
// the current element is not a constant
return false;
@@ -64,7 +64,7 @@ bool NormalForm::isConstant(TNode n)
current = current[1];
}
// check last element
- if (!(current.getKind() == kind::MK_BAG && current.isConst()))
+ if (!(current.getKind() == kind::BAG_MAKE && current.isConst()))
{
// the last element is not a constant
return false;
@@ -77,7 +77,7 @@ bool NormalForm::isConstant(TNode n)
return true;
}
- // only nodes with kinds EMPTY_BAG, MK_BAG, and UNION_DISJOINT can be
+ // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be
// constants
return false;
}
@@ -97,14 +97,14 @@ Node NormalForm::evaluate(TNode n)
}
switch (n.getKind())
{
- case MK_BAG: return evaluateMakeBag(n);
+ case BAG_MAKE: return evaluateMakeBag(n);
case BAG_COUNT: return evaluateBagCount(n);
- case DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
- case UNION_DISJOINT: return evaluateUnionDisjoint(n);
- case UNION_MAX: return evaluateUnionMax(n);
- case INTERSECTION_MIN: return evaluateIntersectionMin(n);
- case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
- case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
+ case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
+ case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n);
+ case BAG_UNION_MAX: return evaluateUnionMax(n);
+ case BAG_INTER_MIN: return evaluateIntersectionMin(n);
+ case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
+ case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
case BAG_CARD: return evaluateCard(n);
case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
case BAG_FROM_SET: return evaluateFromSet(n);
@@ -172,19 +172,19 @@ std::map<Node, Rational> NormalForm::getBagElements(TNode n)
Assert(n.isConst()) << "node " << n << " is not in a normal form"
<< std::endl;
std::map<Node, Rational> elements;
- if (n.getKind() == EMPTYBAG)
+ if (n.getKind() == BAG_EMPTY)
{
return elements;
}
- while (n.getKind() == kind::UNION_DISJOINT)
+ while (n.getKind() == kind::BAG_UNION_DISJOINT)
{
- Assert(n[0].getKind() == kind::MK_BAG);
+ Assert(n[0].getKind() == kind::BAG_MAKE);
Node element = n[0][0];
Rational count = n[0][1].getConst<Rational>();
elements[element] = count;
n = n[1];
}
- Assert(n.getKind() == kind::MK_BAG);
+ Assert(n.getKind() == kind::BAG_MAKE);
Node lastElement = n[0];
Rational lastCount = n[1].getConst<Rational>();
elements[lastElement] = lastCount;
@@ -202,13 +202,15 @@ Node NormalForm::constructConstantBagFromElements(
}
TypeNode elementType = t.getBagElementType();
std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
- Node bag =
- nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
+ Node bag = nm->mkBag(elementType,
+ it->first,
+ nm->mkConst<Rational>(CONST_RATIONAL, it->second));
while (++it != elements.rend())
{
- Node n =
- nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
- bag = nm->mkNode(UNION_DISJOINT, n, bag);
+ Node n = nm->mkBag(elementType,
+ it->first,
+ nm->mkConst<Rational>(CONST_RATIONAL, it->second));
+ bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
}
return bag;
}
@@ -228,7 +230,7 @@ Node NormalForm::constructBagFromElements(TypeNode t,
while (++it != elements.rend())
{
Node n = nm->mkBag(elementType, it->first, it->second);
- bag = nm->mkNode(UNION_DISJOINT, n, bag);
+ bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
}
return bag;
}
@@ -237,7 +239,7 @@ Node NormalForm::evaluateMakeBag(TNode n)
{
// the case where n is const should be handled earlier.
// here we handle the case where the multiplicity is zero or negative
- Assert(n.getKind() == MK_BAG && !n.isConst()
+ Assert(n.getKind() == BAG_MAKE && !n.isConst()
&& n[1].getConst<Rational>().sgn() < 1);
Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
return emptybag;
@@ -248,11 +250,11 @@ Node NormalForm::evaluateBagCount(TNode n)
Assert(n.getKind() == BAG_COUNT);
// Examples
// --------
- // - (bag.count "x" (emptybag String)) = 0
- // - (bag.count "x" (mkBag "y" 5)) = 0
- // - (bag.count "x" (mkBag "x" 4)) = 4
- // - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
- // - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0
+ // - (bag.count "x" (as bag.empty (Bag String))) = 0
+ // - (bag.count "x" (bag "y" 5)) = 0
+ // - (bag.count "x" (bag "x" 4)) = 4
+ // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4
+ // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0
std::map<Node, Rational> elements = getBagElements(n[1]);
std::map<Node, Rational>::iterator it = elements.find(n[0]);
@@ -260,22 +262,23 @@ Node NormalForm::evaluateBagCount(TNode n)
NodeManager* nm = NodeManager::currentNM();
if (it != elements.end())
{
- Node count = nm->mkConst(it->second);
+ Node count = nm->mkConst(CONST_RATIONAL, it->second);
return count;
}
- return nm->mkConst(Rational(0));
+ return nm->mkConst(CONST_RATIONAL, Rational(0));
}
Node NormalForm::evaluateDuplicateRemoval(TNode n)
{
- Assert(n.getKind() == DUPLICATE_REMOVAL);
+ Assert(n.getKind() == BAG_DUPLICATE_REMOVAL);
// Examples
// --------
- // - (duplicate_removal (emptybag String)) = (emptybag String)
- // - (duplicate_removal (mkBag "x" 4)) = (emptybag "x" 1)
- // - (duplicate_removal (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
- // (disjoint_union (mkBag "x" 1) (mkBag "y" 1)
+ // - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag
+ // String))
+ // - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1)
+ // - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
+ // (bag.disjoint_union (bag "x" 1) (bag "y" 1)
std::map<Node, Rational> oldElements = getBagElements(n[0]);
// copy elements from the old bag
@@ -292,16 +295,16 @@ Node NormalForm::evaluateDuplicateRemoval(TNode n)
Node NormalForm::evaluateUnionDisjoint(TNode n)
{
- Assert(n.getKind() == UNION_DISJOINT);
+ Assert(n.getKind() == BAG_UNION_DISJOINT);
// Example
// -------
- // input: (union_disjoint A B)
- // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
- // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // input: (bag.union_disjoint A B)
+ // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+ // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
// output:
- // (union_disjoint A B)
- // where A = (MK_BAG "x" 7)
- // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
+ // (bag.union_disjoint A B)
+ // where A = (bag "x" 7)
+ // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
auto equal = [](std::map<Node, Rational>& elements,
std::map<Node, Rational>::const_iterator& itA,
@@ -352,16 +355,16 @@ Node NormalForm::evaluateUnionDisjoint(TNode n)
Node NormalForm::evaluateUnionMax(TNode n)
{
- Assert(n.getKind() == UNION_MAX);
+ Assert(n.getKind() == BAG_UNION_MAX);
// Example
// -------
- // input: (union_max A B)
- // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
- // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // input: (bag.union_max A B)
+ // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+ // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
// output:
- // (union_disjoint A B)
- // where A = (MK_BAG "x" 4)
- // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
+ // (bag.union_disjoint A B)
+ // where A = (bag "x" 4)
+ // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
auto equal = [](std::map<Node, Rational>& elements,
std::map<Node, Rational>::const_iterator& itA,
@@ -412,14 +415,14 @@ Node NormalForm::evaluateUnionMax(TNode n)
Node NormalForm::evaluateIntersectionMin(TNode n)
{
- Assert(n.getKind() == INTERSECTION_MIN);
+ Assert(n.getKind() == BAG_INTER_MIN);
// Example
// -------
- // input: (intersectionMin A B)
- // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
- // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // input: (bag.inter_min A B)
+ // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+ // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
// output:
- // (MK_BAG "x" 3)
+ // (bag "x" 3)
auto equal = [](std::map<Node, Rational>& elements,
std::map<Node, Rational>::const_iterator& itA,
@@ -458,14 +461,14 @@ Node NormalForm::evaluateIntersectionMin(TNode n)
Node NormalForm::evaluateDifferenceSubtract(TNode n)
{
- Assert(n.getKind() == DIFFERENCE_SUBTRACT);
+ Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT);
// Example
// -------
- // input: (difference_subtract A B)
- // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
- // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // input: (bag.difference_subtract A B)
+ // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+ // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
// output:
- // (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))
+ // (bag.union_disjoint (bag "x" 1) (bag "z" 2))
auto equal = [](std::map<Node, Rational>& elements,
std::map<Node, Rational>::const_iterator& itA,
@@ -510,14 +513,14 @@ Node NormalForm::evaluateDifferenceSubtract(TNode n)
Node NormalForm::evaluateDifferenceRemove(TNode n)
{
- Assert(n.getKind() == DIFFERENCE_REMOVE);
+ Assert(n.getKind() == BAG_DIFFERENCE_REMOVE);
// Example
// -------
- // input: (difference_subtract A B)
- // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
- // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // input: (bag.difference_remove A B)
+ // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+ // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
// output:
- // (MK_BAG "z" 2)
+ // (bag "z" 2)
auto equal = [](std::map<Node, Rational>& elements,
std::map<Node, Rational>::const_iterator& itA,
@@ -564,9 +567,9 @@ Node NormalForm::evaluateChoose(TNode n)
Assert(n.getKind() == BAG_CHOOSE);
// Examples
// --------
- // - (bag.choose (MK_BAG "x" 4)) = "x"
+ // - (bag.choose (bag "x" 4)) = "x"
- if (n[0].getKind() == MK_BAG)
+ if (n[0].getKind() == BAG_MAKE)
{
return n[0][0];
}
@@ -578,9 +581,9 @@ Node NormalForm::evaluateCard(TNode n)
Assert(n.getKind() == BAG_CARD);
// Examples
// --------
- // - (card (emptyBag String)) = 0
- // - (choose (MK_BAG "x" 4)) = 4
- // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5
+ // - (card (as bag.empty (Bag String))) = 0
+ // - (bag.choose (bag "x" 4)) = 4
+ // - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5
std::map<Node, Rational> elements = getBagElements(n[0]);
Rational sum(0);
@@ -590,7 +593,7 @@ Node NormalForm::evaluateCard(TNode n)
}
NodeManager* nm = NodeManager::currentNM();
- Node sumNode = nm->mkConst(sum);
+ Node sumNode = nm->mkConst(CONST_RATIONAL, sum);
return sumNode;
}
@@ -599,12 +602,13 @@ Node NormalForm::evaluateIsSingleton(TNode n)
Assert(n.getKind() == BAG_IS_SINGLETON);
// Examples
// --------
- // - (bag.is_singleton (emptyBag String)) = false
- // - (bag.is_singleton (MK_BAG "x" 1)) = true
- // - (bag.is_singleton (MK_BAG "x" 4)) = false
- // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false
+ // - (bag.is_singleton (as bag.empty (Bag String))) = false
+ // - (bag.is_singleton (bag "x" 1)) = true
+ // - (bag.is_singleton (bag "x" 4)) = false
+ // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1)))
+ // = false
- if (n[0].getKind() == MK_BAG && n[0][1].getConst<Rational>().isOne())
+ if (n[0].getKind() == BAG_MAKE && n[0][1].getConst<Rational>().isOne())
{
return NodeManager::currentNM()->mkConst(true);
}
@@ -617,10 +621,10 @@ Node NormalForm::evaluateFromSet(TNode n)
// Examples
// --------
- // - (bag.from_set (set.empty String)) = (emptybag String)
- // - (bag.from_set (singleton "x")) = (mkBag "x" 1)
- // - (bag.from_set (union (singleton "x") (singleton "y"))) =
- // (disjoint_union (mkBag "x" 1) (mkBag "y" 1))
+ // - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String))
+ // - (bag.from_set (set.singleton "x")) = (bag "x" 1)
+ // - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) =
+ // (bag.disjoint_union (bag "x" 1) (bag "y" 1))
NodeManager* nm = NodeManager::currentNM();
std::set<Node> setElements =
@@ -642,10 +646,10 @@ Node NormalForm::evaluateToSet(TNode n)
// Examples
// --------
- // - (bag.to_set (emptybag String)) = (set.empty String)
- // - (bag.to_set (mkBag "x" 4)) = (singleton "x")
- // - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
- // (union (singleton "x") (singleton "y")))
+ // - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String))
+ // - (bag.to_set (bag "x" 4)) = (set.singleton "x")
+ // - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
+ // (set.union (set.singleton "x") (set.singleton "y")))
NodeManager* nm = NodeManager::currentNM();
std::map<Node, Rational> bagElements = getBagElements(n[0]);
@@ -667,8 +671,8 @@ Node NormalForm::evaluateBagMap(TNode n)
// Examples
// --------
// - (bag.map ((lambda ((x String)) "z")
- // (union_disjoint (bag "a" 2) (bag "b" 3)) =
- // (union_disjoint
+ // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) =
+ // (bag.union_disjoint
// (bag ((lambda ((x String)) "z") "a") 2)
// (bag ((lambda ((x String)) "z") "b") 3)) =
// (bag "z" 5)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback