/******************************************************************************
 * Top contributors (to current version):
 *   Mudathir Mohamed, Aina Niemetz
 *
 * This file is part of the cvc5 project.
 *
 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
 * in the top-level source directory and their institutional affiliations.
 * All rights reserved.  See the file COPYING in the top-level source
 * directory for licensing information.
 * ****************************************************************************
 *
 * Normal form for bag constants.
 */
#include "normal_form.h"

#include <algorithm>

#include "expr/emptybag.h"
#include "smt/logic_exception.h"
#include "theory/sets/normal_form.h"
#include "theory/type_enumerator.h"
#include "util/rational.h"

using namespace cvc5::kind;

namespace cvc5 {
namespace theory {
namespace bags {

/**
 * Returns true if n is a bag constant in normal form, i.e. one of:
 * - BAG_EMPTY;
 * - a constant BAG_MAKE (see MkBagTypeRule::computeIsConst);
 * - a right-nested chain of BAG_UNION_DISJOINT whose left children (and the
 *   final right child) are constant BAG_MAKE nodes with strictly increasing
 *   elements with respect to the Node ordering.
 */
bool NormalForm::isConstant(TNode n)
{
  if (n.getKind() == BAG_EMPTY)
  {
    // empty bags are already normalized
    return true;
  }
  if (n.getKind() == BAG_MAKE)
  {
    // see the implementation in MkBagTypeRule::computeIsConst
    return n.isConst();
  }
  if (n.getKind() == BAG_UNION_DISJOINT)
  {
    if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst()))
    {
      // the first child is not a constant
      return false;
    }
    // store the previous element to check the ordering of elements
    Node previousElement = n[0][0];
    Node current = n[1];
    while (current.getKind() == BAG_UNION_DISJOINT)
    {
      if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst()))
      {
        // the current element is not a constant
        return false;
      }
      if (previousElement >= current[0][0])
      {
        // the ordering is violated
        return false;
      }
      previousElement = current[0][0];
      current = current[1];
    }
    // check last element
    if (!(current.getKind() == kind::BAG_MAKE && current.isConst()))
    {
      // the last element is not a constant
      return false;
    }
    if (previousElement >= current[0])
    {
      // the ordering is violated
      return false;
    }
    return true;
  }

  // only nodes with kinds BAG_EMPTY, BAG_MAKE, and BAG_UNION_DISJOINT can be
  // constants
  return false;
}

/** Returns true if every child of n is a constant. */
bool NormalForm::areChildrenConstants(TNode n)
{
  return std::all_of(n.begin(), n.end(), [](TNode c) { return c.isConst(); });
}

/**
 * Evaluates a bag term whose children are all constants, returning a
 * constant in normal form (or the scalar result for BAG_COUNT / BAG_CARD /
 * BAG_IS_SINGLETON, and a set constant for BAG_TO_SET).
 */
Node NormalForm::evaluate(TNode n)
{
  Assert(areChildrenConstants(n));
  if (n.isConst())
  {
    // a constant node is already in a normal form
    return n;
  }
  switch (n.getKind())
  {
    case BAG_MAKE: return evaluateMakeBag(n);
    case BAG_COUNT: return evaluateBagCount(n);
    case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
    case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n);
    case BAG_UNION_MAX: return evaluateUnionMax(n);
    case BAG_INTER_MIN: return evaluateIntersectionMin(n);
    case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
    case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
    case BAG_CARD: return evaluateCard(n);
    case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
    case BAG_FROM_SET: return evaluateFromSet(n);
    case BAG_TO_SET: return evaluateToSet(n);
    case BAG_MAP: return evaluateBagMap(n);
    default: break;
  }
  Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
              << std::endl;
}

/**
 * Generic merge of the element maps of the two constant children of n.
 * Both maps are ordered by Node, so they are walked in lock step and one of
 * the callbacks is invoked per step:
 * - equal:          the element occurs in both A and B;
 * - less:           the element of A is smaller (it does not occur in B);
 * - greaterOrEqual: the element of B is smaller (it does not occur in A);
 * - remainderOfA / remainderOfB: whatever is left once the other side is
 *   exhausted.
 * The callbacks fill a result map, which is then turned into a constant bag
 * of the same type as n.
 */
template <typename T1, typename T2, typename T3, typename T4, typename T5>
Node NormalForm::evaluateBinaryOperation(const TNode& n,
                                         T1&& equal,
                                         T2&& less,
                                         T3&& greaterOrEqual,
                                         T4&& remainderOfA,
                                         T5&& remainderOfB)
{
  std::map<Node, Rational> elementsA = getBagElements(n[0]);
  std::map<Node, Rational> elementsB = getBagElements(n[1]);
  std::map<Node, Rational> elements;

  std::map<Node, Rational>::const_iterator itA = elementsA.begin();
  std::map<Node, Rational>::const_iterator itB = elementsB.begin();

  Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
                         << n.getKind() << "] " << std::endl
                         << "elements A: " << elementsA << std::endl
                         << "elements B: " << elementsB << std::endl;

  while (itA != elementsA.end() && itB != elementsB.end())
  {
    if (itA->first == itB->first)
    {
      equal(elements, itA, itB);
      ++itA;
      ++itB;
    }
    else if (itA->first < itB->first)
    {
      less(elements, itA, itB);
      ++itA;
    }
    else
    {
      greaterOrEqual(elements, itA, itB);
      ++itB;
    }
  }

  // handle the remaining elements from A
  remainderOfA(elements, elementsA, itA);
  // handle the remaining elements from B
  remainderOfB(elements, elementsB, itB);

  Trace("bags-evaluate") << "elements: " << elements << std::endl;
  Node bag = constructConstantBagFromElements(n.getType(), elements);
  Trace("bags-evaluate") << "bag: " << bag << std::endl;
  return bag;
}

/**
 * Returns the map from elements to multiplicities of the constant bag n.
 * n must be in normal form (BAG_EMPTY, BAG_MAKE, or a BAG_UNION_DISJOINT
 * chain of BAG_MAKE nodes).
 */
std::map<Node, Rational> NormalForm::getBagElements(TNode n)
{
  Assert(n.isConst()) << "node " << n << " is not in a normal form"
                      << std::endl;
  std::map<Node, Rational> elements;
  if (n.getKind() == BAG_EMPTY)
  {
    return elements;
  }
  while (n.getKind() == kind::BAG_UNION_DISJOINT)
  {
    Assert(n[0].getKind() == kind::BAG_MAKE);
    Node element = n[0][0];
    Rational count = n[0][1].getConst<Rational>();
    elements[element] = count;
    n = n[1];
  }
  Assert(n.getKind() == kind::BAG_MAKE);
  Node lastElement = n[0];
  Rational lastCount = n[1].getConst<Rational>();
  elements[lastElement] = lastCount;
  return elements;
}

/**
 * Builds a bag of type t from a map of elements to multiplicities. The result
 * is a right-nested BAG_UNION_DISJOINT chain ordered by element (or BAG_EMPTY
 * when the map is empty). Callers should pass strictly positive
 * multiplicities; otherwise the result is not a constant.
 */
Node NormalForm::constructConstantBagFromElements(
    TypeNode t, const std::map<Node, Rational>& elements)
{
  Assert(t.isBag());
  NodeManager* nm = NodeManager::currentNM();
  if (elements.empty())
  {
    return nm->mkConst(EmptyBag(t));
  }
  TypeNode elementType = t.getBagElementType();
  std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
  Node bag = nm->mkBag(
      elementType, it->first, nm->mkConst(CONST_RATIONAL, it->second));
  while (++it != elements.rend())
  {
    Node n = nm->mkBag(
        elementType, it->first, nm->mkConst(CONST_RATIONAL, it->second));
    bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
  }
  return bag;
}

/**
 * Same as constructConstantBagFromElements, but the multiplicities are
 * arbitrary (not necessarily constant) terms, so the result need not be a
 * constant.
 */
Node NormalForm::constructBagFromElements(TypeNode t,
                                          const std::map<Node, Node>& elements)
{
  Assert(t.isBag());
  NodeManager* nm = NodeManager::currentNM();
  if (elements.empty())
  {
    return nm->mkConst(EmptyBag(t));
  }
  TypeNode elementType = t.getBagElementType();
  std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
  Node bag = nm->mkBag(elementType, it->first, it->second);
  while (++it != elements.rend())
  {
    Node n = nm->mkBag(elementType, it->first, it->second);
    bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
  }
  return bag;
}

/** Evaluates a non-constant BAG_MAKE, i.e. one with multiplicity <= 0. */
Node NormalForm::evaluateMakeBag(TNode n)
{
  // the case where n is const should be handled earlier.
  // here we handle the case where the multiplicity is zero or negative
  Assert(n.getKind() == BAG_MAKE && !n.isConst()
         && n[1].getConst<Rational>().sgn() < 1);
  Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
  return emptybag;
}

/** Evaluates (bag.count e B) to the multiplicity of e in constant B. */
Node NormalForm::evaluateBagCount(TNode n)
{
  Assert(n.getKind() == BAG_COUNT);
  // Examples
  // --------
  // - (bag.count "x" (as bag.empty (Bag String))) = 0
  // - (bag.count "x" (bag "y" 5)) = 0
  // - (bag.count "x" (bag "x" 4)) = 4
  // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4
  // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0

  std::map<Node, Rational> elements = getBagElements(n[1]);
  std::map<Node, Rational>::const_iterator it = elements.find(n[0]);

  NodeManager* nm = NodeManager::currentNM();
  if (it != elements.end())
  {
    Node count = nm->mkConst(CONST_RATIONAL, it->second);
    return count;
  }
  return nm->mkConst(CONST_RATIONAL, Rational(0));
}

/** Evaluates bag.duplicate_removal by setting every multiplicity to 1. */
Node NormalForm::evaluateDuplicateRemoval(TNode n)
{
  Assert(n.getKind() == BAG_DUPLICATE_REMOVAL);

  // Examples
  // --------
  //  - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag
  //  String))
  //  - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1)
  //  - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
  //     (bag.disjoint_union (bag "x" 1) (bag "y" 1)

  // copy elements from the old bag
  std::map<Node, Rational> newElements = getBagElements(n[0]);
  const Rational one = Rational(1);
  for (std::pair<const Node, Rational>& element : newElements)
  {
    element.second = one;
  }
  Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
  return bag;
}

/** Evaluates bag.union_disjoint: multiplicities of shared elements add up. */
Node NormalForm::evaluateUnionDisjoint(TNode n)
{
  Assert(n.getKind() == BAG_UNION_DISJOINT);
  // Example
  // -------
  // input: (bag.union_disjoint A B)
  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
  // output:
  //    (bag.union_disjoint A B)
  //        where A = (bag "x" 7)
  //              B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))

  auto equal = [](std::map<Node, Rational>& elements,
                  std::map<Node, Rational>::const_iterator& itA,
                  std::map<Node, Rational>::const_iterator& itB) {
    // compute the sum of the multiplicities
    elements[itA->first] = itA->second + itB->second;
  };

  auto less = [](std::map<Node, Rational>& elements,
                 std::map<Node, Rational>::const_iterator& itA,
                 std::map<Node, Rational>::const_iterator& itB) {
    // add the element to the result
    elements[itA->first] = itA->second;
  };

  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
                           std::map<Node, Rational>::const_iterator& itA,
                           std::map<Node, Rational>::const_iterator& itB) {
    // add the element to the result
    elements[itB->first] = itB->second;
  };

  auto remainderOfA = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsA,
                         std::map<Node, Rational>::const_iterator& itA) {
    // append the remainder of A
    while (itA != elementsA.end())
    {
      elements[itA->first] = itA->second;
      ++itA;
    }
  };

  auto remainderOfB = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsB,
                         std::map<Node, Rational>::const_iterator& itB) {
    // append the remainder of B
    while (itB != elementsB.end())
    {
      elements[itB->first] = itB->second;
      ++itB;
    }
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}

/** Evaluates bag.union_max: shared elements take the larger multiplicity. */
Node NormalForm::evaluateUnionMax(TNode n)
{
  Assert(n.getKind() == BAG_UNION_MAX);
  // Example
  // -------
  // input: (bag.union_max A B)
  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
  // output:
  //    (bag.union_disjoint A B)
  //        where A = (bag "x" 4)
  //              B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))

  auto equal = [](std::map<Node, Rational>& elements,
                  std::map<Node, Rational>::const_iterator& itA,
                  std::map<Node, Rational>::const_iterator& itB) {
    // compute the maximum multiplicity
    elements[itA->first] = std::max(itA->second, itB->second);
  };

  auto less = [](std::map<Node, Rational>& elements,
                 std::map<Node, Rational>::const_iterator& itA,
                 std::map<Node, Rational>::const_iterator& itB) {
    // add to the result
    elements[itA->first] = itA->second;
  };

  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
                           std::map<Node, Rational>::const_iterator& itA,
                           std::map<Node, Rational>::const_iterator& itB) {
    // add to the result
    elements[itB->first] = itB->second;
  };

  auto remainderOfA = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsA,
                         std::map<Node, Rational>::const_iterator& itA) {
    // append the remainder of A
    while (itA != elementsA.end())
    {
      elements[itA->first] = itA->second;
      ++itA;
    }
  };

  auto remainderOfB = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsB,
                         std::map<Node, Rational>::const_iterator& itB) {
    // append the remainder of B
    while (itB != elementsB.end())
    {
      elements[itB->first] = itB->second;
      ++itB;
    }
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}

/**
 * Evaluates bag.inter_min: only shared elements survive, with the smaller
 * multiplicity.
 */
Node NormalForm::evaluateIntersectionMin(TNode n)
{
  Assert(n.getKind() == BAG_INTER_MIN);
  // Example
  // -------
  // input: (bag.inter_min A B)
  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
  // output:
  //        (bag "x" 3)

  auto equal = [](std::map<Node, Rational>& elements,
                  std::map<Node, Rational>::const_iterator& itA,
                  std::map<Node, Rational>::const_iterator& itB) {
    // compute the minimum multiplicity
    elements[itA->first] = std::min(itA->second, itB->second);
  };

  auto less = [](std::map<Node, Rational>& elements,
                 std::map<Node, Rational>::const_iterator& itA,
                 std::map<Node, Rational>::const_iterator& itB) {
    // do nothing
  };

  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
                           std::map<Node, Rational>::const_iterator& itA,
                           std::map<Node, Rational>::const_iterator& itB) {
    // do nothing
  };

  auto remainderOfA = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsA,
                         std::map<Node, Rational>::const_iterator& itA) {
    // do nothing
  };

  auto remainderOfB = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsB,
                         std::map<Node, Rational>::const_iterator& itB) {
    // do nothing
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}

/**
 * Evaluates bag.difference_subtract: the multiplicity of each element of A is
 * decreased by its multiplicity in B; elements whose resulting multiplicity is
 * not positive are dropped.
 */
Node NormalForm::evaluateDifferenceSubtract(TNode n)
{
  Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT);
  // Example
  // -------
  // input: (bag.difference_subtract A B)
  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
  // output:
  //    (bag.union_disjoint (bag "x" 1) (bag "z" 2))
  //
  // - (bag.difference_subtract (bag "x" 3) (bag "x" 4))
  //     = (as bag.empty (Bag String))

  auto equal = [](std::map<Node, Rational>& elements,
                  std::map<Node, Rational>::const_iterator& itA,
                  std::map<Node, Rational>::const_iterator& itB) {
    // subtract the multiplicities. If B has at least as many copies as A, the
    // element does not occur in the difference. Storing a zero or negative
    // count would produce (bag x c) with c <= 0, which is not a constant.
    if (itA->second > itB->second)
    {
      elements[itA->first] = itA->second - itB->second;
    }
  };

  auto less = [](std::map<Node, Rational>& elements,
                 std::map<Node, Rational>::const_iterator& itA,
                 std::map<Node, Rational>::const_iterator& itB) {
    // itA->first is not in B, so we add it to the difference subtract
    elements[itA->first] = itA->second;
  };

  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
                           std::map<Node, Rational>::const_iterator& itA,
                           std::map<Node, Rational>::const_iterator& itB) {
    // itB->first is not in A, so we just skip it
  };

  auto remainderOfA = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsA,
                         std::map<Node, Rational>::const_iterator& itA) {
    // append the remainder of A
    while (itA != elementsA.end())
    {
      elements[itA->first] = itA->second;
      ++itA;
    }
  };

  auto remainderOfB = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsB,
                         std::map<Node, Rational>::const_iterator& itB) {
    // do nothing
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}

/**
 * Evaluates bag.difference_remove: every element of A that also occurs in B
 * is removed entirely; other elements keep their multiplicity.
 */
Node NormalForm::evaluateDifferenceRemove(TNode n)
{
  Assert(n.getKind() == BAG_DIFFERENCE_REMOVE);
  // Example
  // -------
  // input: (bag.difference_remove A B)
  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
  // output:
  //    (bag "z" 2)

  auto equal = [](std::map<Node, Rational>& elements,
                  std::map<Node, Rational>::const_iterator& itA,
                  std::map<Node, Rational>::const_iterator& itB) {
    // skip the shared element by doing nothing
  };

  auto less = [](std::map<Node, Rational>& elements,
                 std::map<Node, Rational>::const_iterator& itA,
                 std::map<Node, Rational>::const_iterator& itB) {
    // itA->first is not in B, so we add it to the difference remove
    elements[itA->first] = itA->second;
  };

  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
                           std::map<Node, Rational>::const_iterator& itA,
                           std::map<Node, Rational>::const_iterator& itB) {
    // itB->first is not in A, so we just skip it
  };

  auto remainderOfA = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsA,
                         std::map<Node, Rational>::const_iterator& itA) {
    // append the remainder of A
    while (itA != elementsA.end())
    {
      elements[itA->first] = itA->second;
      ++itA;
    }
  };

  auto remainderOfB = [](std::map<Node, Rational>& elements,
                         std::map<Node, Rational>& elementsB,
                         std::map<Node, Rational>::const_iterator& itB) {
    // do nothing
  };

  return evaluateBinaryOperation(
      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}

/**
 * Evaluates bag.choose when its argument is a single BAG_MAKE; any other
 * argument raises a LogicException.
 */
Node NormalForm::evaluateChoose(TNode n)
{
  Assert(n.getKind() == BAG_CHOOSE);
  // Examples
  // --------
  // - (bag.choose (bag "x" 4)) = "x"

  if (n[0].getKind() == BAG_MAKE)
  {
    return n[0][0];
  }
  throw LogicException("BAG_CHOOSE_TOTAL is not supported yet");
}

/** Evaluates bag.card as the sum of all multiplicities of a constant bag. */
Node NormalForm::evaluateCard(TNode n)
{
  Assert(n.getKind() == BAG_CARD);
  // Examples
  // --------
  //  - (bag.card (as bag.empty (Bag String))) = 0
  //  - (bag.card (bag "x" 4)) = 4
  //  - (bag.card (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5

  std::map<Node, Rational> elements = getBagElements(n[0]);
  Rational sum(0);
  // bind by reference: iterating by value would copy every (Node, Rational)
  for (const std::pair<const Node, Rational>& element : elements)
  {
    sum += element.second;
  }

  NodeManager* nm = NodeManager::currentNM();
  Node sumNode = nm->mkConst(CONST_RATIONAL, sum);
  return sumNode;
}

/** Evaluates bag.is_singleton: true only for a single BAG_MAKE with count 1. */
Node NormalForm::evaluateIsSingleton(TNode n)
{
  Assert(n.getKind() == BAG_IS_SINGLETON);
  // Examples
  // --------
  // - (bag.is_singleton (as bag.empty (Bag String))) = false
  // - (bag.is_singleton (bag "x" 1)) = true
  // - (bag.is_singleton (bag "x" 4)) = false
  // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1)))
  //     = false

  if (n[0].getKind() == BAG_MAKE && n[0][1].getConst<Rational>().isOne())
  {
    return NodeManager::currentNM()->mkConst(true);
  }
  return NodeManager::currentNM()->mkConst(false);
}

/** Evaluates bag.from_set: each set element gets multiplicity 1. */
Node NormalForm::evaluateFromSet(TNode n)
{
  Assert(n.getKind() == BAG_FROM_SET);

  // Examples
  // --------
  //  - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String))
  //  - (bag.from_set (set.singleton "x")) = (bag "x" 1)
  //  - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) =
  //     (bag.disjoint_union (bag "x" 1) (bag "y" 1))

  NodeManager* nm = NodeManager::currentNM();
  std::set<Node> setElements =
      sets::NormalForm::getElementsFromNormalConstant(n[0]);
  const Rational one = Rational(1);
  std::map<Node, Rational> bagElements;
  for (const Node& element : setElements)
  {
    bagElements[element] = one;
  }
  TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
  Node bag = constructConstantBagFromElements(bagType, bagElements);
  return bag;
}

/** Evaluates bag.to_set: the set of distinct elements of a constant bag. */
Node NormalForm::evaluateToSet(TNode n)
{
  Assert(n.getKind() == BAG_TO_SET);

  // Examples
  // --------
  //  - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String))
  //  - (bag.to_set (bag "x" 4)) = (set.singleton "x")
  //  - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
  //     (set.union (set.singleton "x") (set.singleton "y")))

  NodeManager* nm = NodeManager::currentNM();
  std::map<Node, Rational> bagElements = getBagElements(n[0]);
  std::set<Node> setElements;
  // std::set orders its contents itself, so the iteration order is irrelevant
  for (const std::pair<const Node, Rational>& element : bagElements)
  {
    setElements.insert(element.first);
  }
  TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
  Node set = sets::NormalForm::elementsToSet(setElements, setType);
  return set;
}

/**
 * Evaluates (bag.map f B) by applying f to every element of the constant bag
 * B and keeping the element's multiplicity.
 */
Node NormalForm::evaluateBagMap(TNode n)
{
  Assert(n.getKind() == BAG_MAP);

  // Examples
  // --------
  // - (bag.map ((lambda ((x String)) "z")
  //            (bag.union_disjoint (bag "a" 2) (bag "b" 3)) =
  //     (bag.union_disjoint
  //       (bag ((lambda ((x String)) "z") "a") 2)
  //       (bag ((lambda ((x String)) "z") "b") 3)) =
  //     (bag "z" 5)

  std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
  std::map<Node, Rational> mappedElements;
  NodeManager* nm = NodeManager::currentNM();
  for (const std::pair<const Node, Rational>& element : elements)
  {
    // NOTE(review): the mapped element is an unrewritten APPLY_UF term, so the
    // resulting bag is not necessarily a constant; confirm callers rewrite it.
    Node mappedElement = nm->mkNode(APPLY_UF, n[0], element.first);
    // accumulate: if two elements ever map to the same node, their
    // multiplicities add up (as in the example above); a missing entry is
    // value-initialized to 0.
    mappedElements[mappedElement] += element.second;
  }
  TypeNode t = nm->mkBagType(n[0].getType().getRangeType());
  Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements);
  return ret;
}

}  // namespace bags
}  // namespace theory
}  // namespace cvc5