diff options
Diffstat (limited to 'src/theory/bags/normal_form.cpp')
-rw-r--r-- | src/theory/bags/normal_form.cpp | 180 |
1 files changed, 92 insertions, 88 deletions
diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp index 59344cf0b..12bf513b5 100644 --- a/src/theory/bags/normal_form.cpp +++ b/src/theory/bags/normal_form.cpp @@ -28,19 +28,19 @@ namespace bags { bool NormalForm::isConstant(TNode n) { - if (n.getKind() == EMPTYBAG) + if (n.getKind() == BAG_EMPTY) { // empty bags are already normalized return true; } - if (n.getKind() == MK_BAG) + if (n.getKind() == BAG_MAKE) { // see the implementation in MkBagTypeRule::computeIsConst return n.isConst(); } - if (n.getKind() == UNION_DISJOINT) + if (n.getKind() == BAG_UNION_DISJOINT) { - if (!(n[0].getKind() == kind::MK_BAG && n[0].isConst())) + if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst())) { // the first child is not a constant return false; @@ -48,9 +48,9 @@ bool NormalForm::isConstant(TNode n) // store the previous element to check the ordering of elements Node previousElement = n[0][0]; Node current = n[1]; - while (current.getKind() == UNION_DISJOINT) + while (current.getKind() == BAG_UNION_DISJOINT) { - if (!(current[0].getKind() == kind::MK_BAG && current[0].isConst())) + if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst())) { // the current element is not a constant return false; @@ -64,7 +64,7 @@ bool NormalForm::isConstant(TNode n) current = current[1]; } // check last element - if (!(current.getKind() == kind::MK_BAG && current.isConst())) + if (!(current.getKind() == kind::BAG_MAKE && current.isConst())) { // the last element is not a constant return false; @@ -77,7 +77,7 @@ bool NormalForm::isConstant(TNode n) return true; } - // only nodes with kinds EMPTY_BAG, MK_BAG, and UNION_DISJOINT can be + // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be // constants return false; } @@ -97,14 +97,14 @@ Node NormalForm::evaluate(TNode n) } switch (n.getKind()) { - case MK_BAG: return evaluateMakeBag(n); + case BAG_MAKE: return evaluateMakeBag(n); case BAG_COUNT: return evaluateBagCount(n); - case DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n); - case UNION_DISJOINT: return evaluateUnionDisjoint(n); - case UNION_MAX: return evaluateUnionMax(n); - case INTERSECTION_MIN: return evaluateIntersectionMin(n); - case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n); - case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n); + case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n); + case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n); + case BAG_UNION_MAX: return evaluateUnionMax(n); + case BAG_INTER_MIN: return evaluateIntersectionMin(n); + case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n); + case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n); case BAG_CARD: return evaluateCard(n); case BAG_IS_SINGLETON: return evaluateIsSingleton(n); case BAG_FROM_SET: return evaluateFromSet(n); @@ -172,19 +172,19 @@ std::map<Node, Rational> NormalForm::getBagElements(TNode n) Assert(n.isConst()) << "node " << n << " is not in a normal form" << std::endl; std::map<Node, Rational> elements; - if (n.getKind() == EMPTYBAG) + if (n.getKind() == BAG_EMPTY) { return elements; } - while (n.getKind() == kind::UNION_DISJOINT) + while (n.getKind() == kind::BAG_UNION_DISJOINT) { - Assert(n[0].getKind() == kind::MK_BAG); + Assert(n[0].getKind() == kind::BAG_MAKE); Node element = n[0][0]; Rational count = n[0][1].getConst<Rational>(); elements[element] = count; n = n[1]; } - Assert(n.getKind() == kind::MK_BAG); + Assert(n.getKind() == kind::BAG_MAKE); Node lastElement = n[0]; Rational lastCount = n[1].getConst<Rational>(); elements[lastElement] = lastCount; @@ -202,13 +202,15 @@ Node NormalForm::constructConstantBagFromElements( } TypeNode elementType = t.getBagElementType(); std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin(); - Node bag = - nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second)); + Node bag = nm->mkBag(elementType, + it->first, + nm->mkConst<Rational>(CONST_RATIONAL, it->second)); while (++it != elements.rend()) { - Node n = - nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second)); - bag = nm->mkNode(UNION_DISJOINT, n, bag); + Node n = nm->mkBag(elementType, + it->first, + nm->mkConst<Rational>(CONST_RATIONAL, it->second)); + bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag); } return bag; } @@ -228,7 +230,7 @@ Node NormalForm::constructBagFromElements(TypeNode t, while (++it != elements.rend()) { Node n = nm->mkBag(elementType, it->first, it->second); - bag = nm->mkNode(UNION_DISJOINT, n, bag); + bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag); } return bag; } @@ -237,7 +239,7 @@ Node NormalForm::evaluateMakeBag(TNode n) { // the case where n is const should be handled earlier. // here we handle the case where the multiplicity is zero or negative - Assert(n.getKind() == MK_BAG && !n.isConst() + Assert(n.getKind() == BAG_MAKE && !n.isConst() && n[1].getConst<Rational>().sgn() < 1); Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType())); return emptybag; @@ -248,11 +250,11 @@ Node NormalForm::evaluateBagCount(TNode n) Assert(n.getKind() == BAG_COUNT); // Examples // -------- - // - (bag.count "x" (emptybag String)) = 0 - // - (bag.count "x" (mkBag "y" 5)) = 0 - // - (bag.count "x" (mkBag "x" 4)) = 4 - // - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4 - // - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0 + // - (bag.count "x" (as bag.empty (Bag String))) = 0 + // - (bag.count "x" (bag "y" 5)) = 0 + // - (bag.count "x" (bag "x" 4)) = 4 + // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4 + // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0 std::map<Node, Rational> elements = getBagElements(n[1]); std::map<Node, Rational>::iterator it = elements.find(n[0]); @@ -260,22 +262,23 @@ Node NormalForm::evaluateBagCount(TNode n) NodeManager* nm = NodeManager::currentNM(); if (it != elements.end()) { - Node count = nm->mkConst(it->second); + Node count = nm->mkConst(CONST_RATIONAL, it->second); return count; } - return nm->mkConst(Rational(0)); + return nm->mkConst(CONST_RATIONAL, Rational(0)); } Node NormalForm::evaluateDuplicateRemoval(TNode n) { - Assert(n.getKind() == DUPLICATE_REMOVAL); + Assert(n.getKind() == BAG_DUPLICATE_REMOVAL); // Examples // -------- - // - (duplicate_removal (emptybag String)) = (emptybag String) - // - (duplicate_removal (mkBag "x" 4)) = (emptybag "x" 1) - // - (duplicate_removal (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) = - // (disjoint_union (mkBag "x" 1) (mkBag "y" 1) + // - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag + // String)) + // - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1) + // - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) = + // (bag.disjoint_union (bag "x" 1) (bag "y" 1) std::map<Node, Rational> oldElements = getBagElements(n[0]); // copy elements from the old bag @@ -292,16 +295,16 @@ Node NormalForm::evaluateDuplicateRemoval(TNode n) Node NormalForm::evaluateUnionDisjoint(TNode n) { - Assert(n.getKind() == UNION_DISJOINT); + Assert(n.getKind() == BAG_UNION_DISJOINT); // Example // ------- - // input: (union_disjoint A B) - // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) - // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // input: (bag.union_disjoint A B) + // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) + // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) // output: - // (union_disjoint A B) - // where A = (MK_BAG "x" 7) - // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2))) + // (bag.union_disjoint A B) + // where A = (bag "x" 7) + // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2))) auto equal = [](std::map<Node, Rational>& elements, std::map<Node, Rational>::const_iterator& itA, @@ -352,16 +355,16 @@ Node NormalForm::evaluateUnionDisjoint(TNode n) Node NormalForm::evaluateUnionMax(TNode n) { - Assert(n.getKind() == UNION_MAX); + Assert(n.getKind() == BAG_UNION_MAX); // Example // ------- - // input: (union_max A B) - // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) - // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // input: (bag.union_max A B) + // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) + // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) // output: - // (union_disjoint A B) - // where A = (MK_BAG "x" 4) - // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2))) + // (bag.union_disjoint A B) + // where A = (bag "x" 4) + // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2))) auto equal = [](std::map<Node, Rational>& elements, std::map<Node, Rational>::const_iterator& itA, @@ -412,14 +415,14 @@ Node NormalForm::evaluateUnionMax(TNode n) Node NormalForm::evaluateIntersectionMin(TNode n) { - Assert(n.getKind() == INTERSECTION_MIN); + Assert(n.getKind() == BAG_INTER_MIN); // Example // ------- - // input: (intersectionMin A B) - // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) - // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // input: (bag.inter_min A B) + // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) + // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) // output: - // (MK_BAG "x" 3) + // (bag "x" 3) auto equal = [](std::map<Node, Rational>& elements, std::map<Node, Rational>::const_iterator& itA, @@ -458,14 +461,14 @@ Node NormalForm::evaluateIntersectionMin(TNode n) Node NormalForm::evaluateDifferenceSubtract(TNode n) { - Assert(n.getKind() == DIFFERENCE_SUBTRACT); + Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT); // Example // ------- - // input: (difference_subtract A B) - // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) - // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // input: (bag.difference_subtract A B) + // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) + // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) // output: - // (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2)) + // (bag.union_disjoint (bag "x" 1) (bag "z" 2)) auto equal = [](std::map<Node, Rational>& elements, std::map<Node, Rational>::const_iterator& itA, @@ -510,14 +513,14 @@ Node NormalForm::evaluateDifferenceSubtract(TNode n) Node NormalForm::evaluateDifferenceRemove(TNode n) { - Assert(n.getKind() == DIFFERENCE_REMOVE); + Assert(n.getKind() == BAG_DIFFERENCE_REMOVE); // Example // ------- - // input: (difference_subtract A B) - // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) - // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // input: (bag.difference_remove A B) + // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) + // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) // output: - // (MK_BAG "z" 2) + // (bag "z" 2) auto equal = [](std::map<Node, Rational>& elements, std::map<Node, Rational>::const_iterator& itA, @@ -564,9 +567,9 @@ Node NormalForm::evaluateChoose(TNode n) Assert(n.getKind() == BAG_CHOOSE); // Examples // -------- - // - (bag.choose (MK_BAG "x" 4)) = "x" + // - (bag.choose (bag "x" 4)) = "x" - if (n[0].getKind() == MK_BAG) + if (n[0].getKind() == BAG_MAKE) { return n[0][0]; } @@ -578,9 +581,9 @@ Node NormalForm::evaluateCard(TNode n) Assert(n.getKind() == BAG_CARD); // Examples // -------- - // - (card (emptyBag String)) = 0 - // - (choose (MK_BAG "x" 4)) = 4 - // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5 + // - (card (as bag.empty (Bag String))) = 0 + // - (bag.choose (bag "x" 4)) = 4 + // - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5 std::map<Node, Rational> elements = getBagElements(n[0]); Rational sum(0); @@ -590,7 +593,7 @@ Node NormalForm::evaluateCard(TNode n) } NodeManager* nm = NodeManager::currentNM(); - Node sumNode = nm->mkConst(sum); + Node sumNode = nm->mkConst(CONST_RATIONAL, sum); return sumNode; } @@ -599,12 +602,13 @@ Node NormalForm::evaluateIsSingleton(TNode n) Assert(n.getKind() == BAG_IS_SINGLETON); // Examples // -------- - // - (bag.is_singleton (emptyBag String)) = false - // - (bag.is_singleton (MK_BAG "x" 1)) = true - // - (bag.is_singleton (MK_BAG "x" 4)) = false - // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false + // - (bag.is_singleton (as bag.empty (Bag String))) = false + // - (bag.is_singleton (bag "x" 1)) = true + // - (bag.is_singleton (bag "x" 4)) = false + // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1))) + // = false - if (n[0].getKind() == MK_BAG && n[0][1].getConst<Rational>().isOne()) + if (n[0].getKind() == BAG_MAKE && n[0][1].getConst<Rational>().isOne()) { return NodeManager::currentNM()->mkConst(true); } @@ -617,10 +621,10 @@ Node NormalForm::evaluateFromSet(TNode n) // Examples // -------- - // - (bag.from_set (set.empty String)) = (emptybag String) - // - (bag.from_set (singleton "x")) = (mkBag "x" 1) - // - (bag.from_set (union (singleton "x") (singleton "y"))) = - // (disjoint_union (mkBag "x" 1) (mkBag "y" 1)) + // - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String)) + // - (bag.from_set (set.singleton "x")) = (bag "x" 1) + // - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) = + // (bag.disjoint_union (bag "x" 1) (bag "y" 1)) NodeManager* nm = NodeManager::currentNM(); std::set<Node> setElements = @@ -642,10 +646,10 @@ Node NormalForm::evaluateToSet(TNode n) // Examples // -------- - // - (bag.to_set (emptybag String)) = (set.empty String) - // - (bag.to_set (mkBag "x" 4)) = (singleton "x") - // - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) = - // (union (singleton "x") (singleton "y"))) + // - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String)) + // - (bag.to_set (bag "x" 4)) = (set.singleton "x") + // - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) = + // (set.union (set.singleton "x") (set.singleton "y"))) NodeManager* nm = NodeManager::currentNM(); std::map<Node, Rational> bagElements = getBagElements(n[0]); @@ -667,8 +671,8 @@ Node NormalForm::evaluateBagMap(TNode n) // Examples // -------- // - (bag.map ((lambda ((x String)) "z") - // (union_disjoint (bag "a" 2) (bag "b" 3)) = - // (union_disjoint + // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) = + // (bag.union_disjoint // (bag ((lambda ((x String)) "z") "a") 2) // (bag ((lambda ((x String)) "z") "b") 3)) = // (bag "z" 5) |