diff options
author | mudathirmahgoub <mudathirmahgoub@gmail.com> | 2020-10-21 17:33:57 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-21 17:33:57 -0500 |
commit | 31983bd41f8c6ec736e374946de355fd1a9bc6f1 (patch) | |
tree | 2d93b3b482809ba40b6d5e5f8174f1c3f7b72f02 /src/theory/bags | |
parent | 3c68378f6a87d96b2baadb35988777c06f54727b (diff) |
Implement bags evaluator (#5322)
This PR implements NormalForm::evaluate for bags
Diffstat (limited to 'src/theory/bags')
-rw-r--r-- | src/theory/bags/bags_rewriter.cpp | 2 | ||||
-rw-r--r-- | src/theory/bags/normal_form.cpp | 604 | ||||
-rw-r--r-- | src/theory/bags/normal_form.h | 147 | ||||
-rw-r--r-- | src/theory/bags/theory_bags_type_rules.h | 2 |
4 files changed, 738 insertions, 17 deletions
diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index c413a5e7e..26c54d4ec 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -51,7 +51,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) // no need to rewrite n if it is already in a normal form response = BagsRewriteResponse(n, Rewrite::NONE); } - else if (NormalForm::AreChildrenConstants(n)) + else if (NormalForm::areChildrenConstants(n)) { Node value = NormalForm::evaluate(n); response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION); diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp index facad3c92..f2dea62d3 100644 --- a/src/theory/bags/normal_form.cpp +++ b/src/theory/bags/normal_form.cpp @@ -12,26 +12,620 @@ #include "normal_form.h" +#include "theory/sets/normal_form.h" +#include "theory/type_enumerator.h" + +using namespace CVC4::kind; + namespace CVC4 { namespace theory { namespace bags { -bool NormalForm::checkNormalConstant(TNode n) +bool NormalForm::isConstant(TNode n) { - // TODO(projects#223): complete this function + if (n.getKind() == EMPTYBAG) + { + // empty bags are already normalized + return true; + } + if (n.getKind() == MK_BAG) + { + // see the implementation in MkBagTypeRule::computeIsConst + return n.isConst(); + } + if (n.getKind() == UNION_DISJOINT) + { + if (!(n[0].getKind() == kind::MK_BAG && n[0].isConst())) + { + // the first child is not a constant + return false; + } + // store the previous element to check the ordering of elements + Node previousElement = n[0][0]; + Node current = n[1]; + while (current.getKind() == UNION_DISJOINT) + { + if (!(current[0].getKind() == kind::MK_BAG && current[0].isConst())) + { + // the current element is not a constant + return false; + } + if (previousElement >= current[0][0]) + { + // the ordering is violated + return false; + } + previousElement = current[0][0]; + current = current[1]; + } + // check last element + if (!(current.getKind() == kind::MK_BAG && current.isConst())) + { + // the last element is not a constant + return false; + } + if (previousElement >= current[0]) + { + // the ordering is violated + return false; + } + return true; + } + + // only nodes with kinds EMPTY_BAG, MK_BAG, and UNION_DISJOINT can be + // constants return false; } -bool NormalForm::AreChildrenConstants(TNode n) +bool NormalForm::areChildrenConstants(TNode n) { return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); }); } Node NormalForm::evaluate(TNode n) { - // TODO(projects#223): complete this function - return CVC4::Node(); + Assert(areChildrenConstants(n)); + if (n.isConst()) + { + // a constant node is already in a normal form + return n; + } + switch (n.getKind()) + { + case MK_BAG: return evaluateMakeBag(n); + case BAG_COUNT: return evaluateBagCount(n); + case UNION_DISJOINT: return evaluateUnionDisjoint(n); + case UNION_MAX: return evaluateUnionMax(n); + case INTERSECTION_MIN: return evaluateIntersectionMin(n); + case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n); + case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n); + case BAG_CHOOSE: return evaluateChoose(n); + case BAG_CARD: return evaluateCard(n); + case BAG_IS_SINGLETON: return evaluateIsSingleton(n); + case BAG_FROM_SET: return evaluateFromSet(n); + case BAG_TO_SET: return evaluateToSet(n); + default: break; + } + Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n + << std::endl; +} + +template <typename T1, typename T2, typename T3, typename T4, typename T5> +Node NormalForm::evaluateBinaryOperation(const TNode& n, + T1&& equal, + T2&& less, + T3&& greaterOrEqual, + T4&& remainderOfA, + T5&& remainderOfB) +{ + std::map<Node, Rational> elementsA = getBagElements(n[0]); + std::map<Node, Rational> elementsB = getBagElements(n[1]); + std::map<Node, Rational> elements; + + std::map<Node, Rational>::const_iterator itA = elementsA.begin(); + std::map<Node, Rational>::const_iterator itB = elementsB.begin(); + + Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation " + << n.getKind() << "] " << std::endl + << "elements A: " << elementsA << std::endl + << "elements B: " << elementsB << std::endl; + + while (itA != elementsA.end() && itB != elementsB.end()) + { + if (itA->first == itB->first) + { + equal(elements, itA, itB); + itA++; + itB++; + } + else if (itA->first < itB->first) + { + less(elements, itA, itB); + itA++; + } + else + { + greaterOrEqual(elements, itA, itB); + itB++; + } + } + + // handle the remaining elements from A + remainderOfA(elements, elementsA, itA); + // handle the remaining elements from B + remainderOfA(elements, elementsB, itB); + + Trace("bags-evaluate") << "elements: " << elements << std::endl; + Node bag = constructBagFromElements(n.getType(), elements); + Trace("bags-evaluate") << "bag: " << bag << std::endl; + return bag; +} + +std::map<Node, Rational> NormalForm::getBagElements(TNode n) +{ + Assert(n.isConst()) << "node " << n << " is not in a normal form" + << std::endl; + std::map<Node, Rational> elements; + if (n.getKind() == EMPTYBAG) + { + return elements; + } + while (n.getKind() == kind::UNION_DISJOINT) + { + Assert(n[0].getKind() == kind::MK_BAG); + Node element = n[0][0]; + Rational count = n[0][1].getConst<Rational>(); + elements[element] = count; + n = n[1]; + } + Assert(n.getKind() == kind::MK_BAG); + Node lastElement = n[0]; + Rational lastCount = n[1].getConst<Rational>(); + elements[lastElement] = lastCount; + return elements; +} + +Node NormalForm::constructBagFromElements( + TypeNode t, const std::map<Node, Rational>& elements) +{ + Assert(t.isBag()); + NodeManager* nm = NodeManager::currentNM(); + if (elements.empty()) + { + return nm->mkConst(EmptyBag(t)); + } + TypeNode elementType = t.getBagElementType(); + std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin(); + Node bag = + nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second)); + while (++it != elements.rend()) + { + Node n = + nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second)); + bag = nm->mkNode(UNION_DISJOINT, n, bag); + } + return bag; +} + +Node NormalForm::evaluateMakeBag(TNode n) +{ + // the case where n is const should be handled earlier. + // here we handle the case where the multiplicity is zero or negative + Assert(n.getKind() == MK_BAG && !n.isConst() + && n[1].getConst<Rational>().sgn() < 1); + Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType())); + return emptybag; +} + +Node NormalForm::evaluateBagCount(TNode n) +{ + Assert(n.getKind() == BAG_COUNT); + // Examples + // -------- + // - (bag.count "x" (emptybag String)) = 0 + // - (bag.count "x" (mkBag "y" 5)) = 0 + // - (bag.count "x" (mkBag "x" 4)) = 4 + // - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4 + // - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0 + + std::map<Node, Rational> elements = getBagElements(n[1]); + std::map<Node, Rational>::iterator it = elements.find(n[0]); + + NodeManager* nm = NodeManager::currentNM(); + if (it != elements.end()) + { + Node count = nm->mkConst(it->second); + return count; + } + return nm->mkConst(Rational(0)); +} + +Node NormalForm::evaluateUnionDisjoint(TNode n) +{ + Assert(n.getKind() == UNION_DISJOINT); + // Example + // ------- + // input: (union_disjoint A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (union_disjoint A B) + // where A = (MK_BAG "x" 7) + // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2))) + + auto equal = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // compute the sum of the multiplicities + elements[itA->first] = itA->second + itB->second; + }; + + auto less = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // add the element to the result + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // add the element to the result + elements[itB->first] = itB->second; + }; + + auto remainderOfA = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>& elementsA, + std::map<Node, Rational>::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>& elementsB, + std::map<Node, Rational>::const_iterator& itB) { + // append the remainder of B + while (itB != elementsB.end()) + { + elements[itB->first] = itB->second; + itB++; + } + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node NormalForm::evaluateUnionMax(TNode n) +{ + Assert(n.getKind() == UNION_MAX); + // Example + // ------- + // input: (union_max A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (union_disjoint A B) + // where A = (MK_BAG "x" 4) + // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2))) + + auto equal = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // compute the maximum multiplicity + elements[itA->first] = std::max(itA->second, itB->second); + }; + + auto less = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // add to the result + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // add to the result + elements[itB->first] = itB->second; + }; + + auto remainderOfA = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>& elementsA, + std::map<Node, Rational>::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>& elementsB, + std::map<Node, Rational>::const_iterator& itB) { + // append the remainder of B + while (itB != elementsB.end()) + { + elements[itB->first] = itB->second; + itB++; + } + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); } + +Node NormalForm::evaluateIntersectionMin(TNode n) +{ + Assert(n.getKind() == INTERSECTION_MIN); + // Example + // ------- + // input: (intersectionMin A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (MK_BAG "x" 3) + + auto equal = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // compute the minimum multiplicity + elements[itA->first] = std::min(itA->second, itB->second); + }; + + auto less = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // do nothing + }; + + auto greaterOrEqual = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // do nothing + }; + + auto remainderOfA = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>& elementsA, + std::map<Node, Rational>::const_iterator& itA) { + // do nothing + }; + + auto remainderOfB = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>& elementsB, + std::map<Node, Rational>::const_iterator& itB) { + // do nothing + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node NormalForm::evaluateDifferenceSubtract(TNode n) +{ + Assert(n.getKind() == DIFFERENCE_SUBTRACT); + // Example + // ------- + // input: (difference_subtract A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2)) + + auto equal = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // subtract the multiplicities + elements[itA->first] = itA->second - itB->second; + }; + + auto less = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // itA->first is not in B, so we add it to the difference subtract + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // itB->first is not in A, so we just skip it + }; + + auto remainderOfA = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>& elementsA, + std::map<Node, Rational>::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>& elementsB, + std::map<Node, Rational>::const_iterator& itB) { + // do nothing + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node NormalForm::evaluateDifferenceRemove(TNode n) +{ + Assert(n.getKind() == DIFFERENCE_REMOVE); + // Example + // ------- + // input: (difference_subtract A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (MK_BAG "z" 2) + + auto equal = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // skip the shared element by doing nothing + }; + + auto less = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // itA->first is not in B, so we add it to the difference remove + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>::const_iterator& itA, + std::map<Node, Rational>::const_iterator& itB) { + // itB->first is not in A, so we just skip it + }; + + auto remainderOfA = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>& elementsA, + std::map<Node, Rational>::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map<Node, Rational>& elements, + std::map<Node, Rational>& elementsB, + std::map<Node, Rational>::const_iterator& itB) { + // do nothing + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node NormalForm::evaluateChoose(TNode n) +{ + Assert(n.getKind() == BAG_CHOOSE); + // Examples + // -------- + // - (choose (emptyBag String)) = "" // the empty string which is the first + // element returned by the type enumerator + // - (choose (MK_BAG "x" 4)) = "x" + // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x" + // deterministically return the first element + + if (n[0].getKind() == EMPTYBAG) + { + TypeNode elementType = n[0].getType().getBagElementType(); + TypeEnumerator typeEnumerator(elementType); + // get the first value from the typeEnumerator + Node element = *typeEnumerator; + return element; + } + + if (n[0].getKind() == MK_BAG) + { + return n[0][0]; + } + Assert(n[0].getKind() == UNION_DISJOINT); + // return the first element + // e.g. (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) + return n[0][0][0]; +} + +Node NormalForm::evaluateCard(TNode n) +{ + Assert(n.getKind() == BAG_CARD); + // Examples + // -------- + // - (card (emptyBag String)) = 0 + // - (choose (MK_BAG "x" 4)) = 4 + // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5 + + std::map<Node, Rational> elements = getBagElements(n[0]); + Rational sum(0); + for (std::pair<Node, Rational> element : elements) + { + sum += element.second; + } + + NodeManager* nm = NodeManager::currentNM(); + Node sumNode = nm->mkConst(sum); + return sumNode; +} + +Node NormalForm::evaluateIsSingleton(TNode n) +{ + Assert(n.getKind() == BAG_IS_SINGLETON); + // Examples + // -------- + // - (bag.is_singleton (emptyBag String)) = false + // - (bag.is_singleton (MK_BAG "x" 1)) = true + // - (bag.is_singleton (MK_BAG "x" 4)) = false + // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false + + if (n[0].getKind() == MK_BAG && n[0][1].getConst<Rational>().isOne()) + { + return NodeManager::currentNM()->mkConst(true); + } + return NodeManager::currentNM()->mkConst(false); +} + +Node NormalForm::evaluateFromSet(TNode n) +{ + Assert(n.getKind() == BAG_FROM_SET); + + // Examples + // -------- + // - (bag.from_set (emptyset String)) = (emptybag String) + // - (bag.from_set (singleton "x")) = (mkBag "x" 1) + // - (bag.from_set (union (singleton "x") (singleton "y"))) = + // (disjoint_union (mkBag "x" 1) (mkBag "y" 1)) + + NodeManager* nm = NodeManager::currentNM(); + std::set<Node> setElements = + sets::NormalForm::getElementsFromNormalConstant(n[0]); + Rational one = Rational(1); + std::map<Node, Rational> bagElements; + for (const Node& element : setElements) + { + bagElements[element] = one; + } + TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType()); + Node bag = constructBagFromElements(bagType, bagElements); + return bag; +} + +Node NormalForm::evaluateToSet(TNode n) +{ + Assert(n.getKind() == BAG_TO_SET); + + // Examples + // -------- + // - (bag.to_set (emptybag String)) = (emptyset String) + // - (bag.to_set (mkBag "x" 4)) = (singleton "x") + // - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) = + // (union (singleton "x") (singleton "y"))) + + NodeManager* nm = NodeManager::currentNM(); + std::map<Node, Rational> bagElements = getBagElements(n[0]); + std::set<Node> setElements; + std::map<Node, Rational>::const_reverse_iterator it; + for (it = bagElements.rbegin(); it != bagElements.rend(); it++) + { + setElements.insert(it->first); + } + TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType()); + Node set = sets::NormalForm::elementsToSet(setElements, setType); + return set; +} + } // namespace bags } // namespace theory } // namespace CVC4
\ No newline at end of file diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/normal_form.h index 8c719fe81..ef0edefff 100644 --- a/src/theory/bags/normal_form.h +++ b/src/theory/bags/normal_form.h @@ -29,22 +29,149 @@ class NormalForm /** * Returns true if n is considered a to be a (canonical) constant bag value. * A canonical bag value is one whose AST is: - * (disjoint-union (mk-bag e1 n1) ... - * (disjoint-union (mk-bag e_{n-1} n_{n-1}) (mk-bag e_n n_n)))) - * where c1 ... cn are constants and the node identifier of these constants - * are such that: - * c1 > ... > cn. - * Also handles the corner cases of empty bag and singleton bag. + * (union_disjoint (mkBag e1 c1) ... + * (union_disjoint (mkBag e_{n-1} c_{n-1}) (mkBag e_n c_n)))) + * where c1 ... cn are positive integers, e1 ... en are constants, and the + * node identifier of these constants are such that: e1 < ... < en. + * Also handles the corner cases of empty bag and bag constructed by mkBag */ - static bool checkNormalConstant(TNode n); + static bool isConstant(TNode n); /** - * check whether all children of the given node are in normal form + * check whether all children of the given node are constants */ - static bool AreChildrenConstants(TNode n); + static bool areChildrenConstants(TNode n); /** - * evaluate the node n to a constant value + * evaluate the node n to a constant value. + * As a precondition, children of n should be constants. */ static Node evaluate(TNode n); + + /** + * get the elements along with their multiplicities in a given bag + * @param n a constant node whose type is a bag + * @return a map whose keys are constant elements and values are + * multiplicities + */ + static std::map<Node, Rational> getBagElements(TNode n); + + /** + * construct a constant bag from constant elements + * @param t the type of the returned bag + * @param elements a map whose keys are constant elements and values are + * multiplicities + * @return a constant bag that contains + */ + static Node constructBagFromElements( + TypeNode t, const std::map<Node, Rational>& elements); + + private: + /** + * a high order helper function that return a constant bag that is the result + * of (op A B) where op is a binary operator and A, B are constant bags. + * The result is computed from the elements of A (elementsA with iterator itA) + * and elements of B (elementsB with iterator itB). + * The arguments below specify how these iterators are used to generate the + * elements of the result (elements). + * @param n a node whose kind is a binary operator (union_disjoint, union_max, + * intersection_min, difference_subtract, difference_remove) and whose + * children are constant bags. + * @param equal a lambda expression that receives (elements, itA, itB) and + * specify the action that needs to be taken when the elements of itA, itB are + * equal. + * @param less a lambda expression that receives (elements, itA, itB) and + * specify the action that needs to be taken when the element itA is less than + * the element of itB. + * @param greaterOrEqual less a lambda expression that receives (elements, + * itA, itB) and specify the action that needs to be taken when the element + * itA is greater than or equal than the element of itB. + * @param remainderOfA a lambda expression that receives (elements, elementsA, + * itA) and specify the action that needs to be taken to the remaining + * elements of A when all elements of B are visited. + * @param remainderOfB a lambda expression that receives (elements, elementsB, + * itB) and specify the action that needs to be taken to the remaining + * elements of B when all elements of A are visited. + * @return a constant bag that the result of (op n[0] n[1]) + */ + template <typename T1, typename T2, typename T3, typename T4, typename T5> + static Node evaluateBinaryOperation(const TNode& n, + T1&& equal, + T2&& less, + T3&& greaterOrEqual, + T4&& remainderOfA, + T5&& remainderOfB); + /** + * evaluate n as follows: + * - (mkBag a 0) = (emptybag T) where T is the type of the original bag + * - (mkBag a (-c)) = (emptybag T) where T is the type the original bag, + * and c > 0 is a constant + */ + static Node evaluateMakeBag(TNode n); + + /** + * returns the multiplicity in a constant bag + * @param n has the form (bag.count x A) where x, A are constants + * @return the multiplicity of element x in bag A. + */ + static Node evaluateBagCount(TNode n); + + /** + * evaluates union disjoint node such that the returned node is a canonical + * bag that has the form + * (union_disjoint (mkBag e1 c1) ... + * (union_disjoint * (mkBag e_{n-1} c_{n-1}) (mkBag e_n c_n)))) where + * c1... cn are positive integers, e1 ... en are constants, and the node + * identifier of these constants are such that: e1 < ... < en. + * @param n has the form (union_disjoint A B) where A, B are constant bags + * @return the union disjoint of A and B + */ + static Node evaluateUnionDisjoint(TNode n); + /** + * @param n has the form (union_max A B) where A, B are constant bags + * @return the union max of A and B + */ + static Node evaluateUnionMax(TNode n); + /** + * @param n has the form (intersection_min A B) where A, B are constant bags + * @return the intersection min of A and B + */ + static Node evaluateIntersectionMin(TNode n); + /** + * @param n has the form (difference_subtract A B) where A, B are constant + * bags + * @return the difference subtract of A and B + */ + static Node evaluateDifferenceSubtract(TNode n); + /** + * @param n has the form (difference_remove A B) where A, B are constant bags + * @return the difference remove of A and B + */ + static Node evaluateDifferenceRemove(TNode n); + /** + * @param n has the form (bag.choose A) where A is a constant bag + * @return the first element of A if A is not empty. Otherwise, it returns the + * first element returned by the type enumerator for the elements + */ + static Node evaluateChoose(TNode n); + /** + * @param n has the form (bag.card A) where A is a constant bag + * @return the number of elements in bag A + */ + static Node evaluateCard(TNode n); + /** + * @param n has the form (bag.is_singleton A) where A is a constant bag + * @return whether the bag A has cardinality one. + */ + static Node evaluateIsSingleton(TNode n); + /** + * @param n has the form (bag.from_set A) where A is a constant set + * @return a constant bag that contains exactly the elements in A. + */ + static Node evaluateFromSet(TNode n); + /** + * @param n has the form (bag.to_set A) where A is a constant bag + * @return a constant set constructed from the elements in A. + */ + static Node evaluateToSet(TNode n); }; } // namespace bags } // namespace theory diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index 75f57ec88..7767938ed 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -57,7 +57,7 @@ struct BinaryOperatorTypeRule // only UNION_DISJOINT has a const rule in kinds. // Other binary operators do not have const rules in kinds Assert(n.getKind() == kind::UNION_DISJOINT); - return NormalForm::checkNormalConstant(n); + return NormalForm::isConstant(n); } }; /* struct BinaryOperatorTypeRule */ |