/*********************                                                        */
/*! \file bag_solver.cpp
 ** \verbatim
 ** Top contributors (to current version):
 **   Mudathir Mohamed, Andrew Reynolds
 ** This file is part of the CVC4 project.
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
 ** in the top-level source directory and their institutional affiliations.
 ** All rights reserved.  See the file COPYING in the top-level source
 ** directory for licensing information.\endverbatim
 **
 ** \brief solver for the theory of bags.
 **
 ** solver for the theory of bags.
 **/

#include "theory/bags/bag_solver.h"

#include "theory/bags/inference_generator.h"
#include "theory/bags/inference_manager.h"
#include "theory/bags/normal_form.h"
#include "theory/bags/solver_state.h"
#include "theory/bags/term_registry.h"
#include "theory/uf/equality_engine_iterator.h"

using namespace std;
using namespace cvc5::context;
using namespace cvc5::kind;

namespace cvc5 {
namespace theory {
namespace bags {

/**
 * Stores references to the solver state, inference manager and term
 * registry, builds the inference generator from the state and the inference
 * manager, and caches the constants 0, 1, true and false as nodes.
 */
BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
    : d_state(s), d_ig(&s, &im), d_im(im), d_termReg(tr)
{
  NodeManager* nm = NodeManager::currentNM();
  d_zero = nm->mkConst(Rational(0));
  d_one = nm->mkConst(Rational(1));
  d_true = nm->mkConst(true);
  d_false = nm->mkConst(false);
}

BagSolver::~BagSolver() {}

/**
 * Runs the bag checks in the following order:
 *  1. initializes the solver state,
 *  2. sends the inferences for disequal bag terms,
 *  3. for every term in the equivalence class of every bag reported by the
 *     solver state, dispatches on the term's kind to the matching check
 *     method (terms of any other kind are skipped),
 *  4. sends a non-negative count inference for every (bag, element) pair
 *     reported by the solver state.
 */
void BagSolver::postCheck()
{
  d_state.initialize();

  checkDisequalBagTerms();

  // At this point, all bag and count representatives should be in the solver
  // state.
  for (const Node& bag : d_state.getBags())
  {
    // iterate through all bag terms in each equivalence class
    for (eq::EqClassIterator it(bag, d_state.getEqualityEngine());
         !it.isFinished();
         ++it)
    {
      Node n = *it;
      switch (n.getKind())
      {
        case kind::EMPTYBAG: checkEmpty(n); break;
        case kind::MK_BAG: checkMkBag(n); break;
        case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
        case kind::UNION_MAX: checkUnionMax(n); break;
        case kind::INTERSECTION_MIN: checkIntersectionMin(n); break;
        case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
        case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
        case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
        default: break;
      }
    }
  }

  // add non negative constraints for all multiplicities
  for (const Node& n : d_state.getBags())
  {
    for (const Node& e : d_state.getElements(n))
    {
      checkNonNegativeCountTerms(n, e);
    }
  }
}

/**
 * Collects the elements relevant to a binary bag operator application.
 *
 * @param n a term with (at least) two children n[0] and n[1]
 * @return the union of the elements the solver state reports for n, n[0]
 * and n[1]
 */
std::set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
{
  const std::set<Node>& downwards = d_state.getElements(n);
  const std::set<Node>& upwards0 = d_state.getElements(n[0]);
  const std::set<Node>& upwards1 = d_state.getElements(n[1]);

  // std::set discards duplicates, so three range insertions yield the union
  std::set<Node> elements;
  elements.insert(downwards.begin(), downwards.end());
  elements.insert(upwards0.begin(), upwards0.end());
  elements.insert(upwards1.begin(), upwards1.end());
  return elements;
}

/**
 * For every element the solver state reports for the EMPTYBAG term n, sends
 * the inference generator's `empty` inference as a lemma.
 */
void BagSolver::checkEmpty(const Node& n)
{
  Assert(n.getKind() == EMPTYBAG);
  for (const Node& e : d_state.getElements(n))
  {
    InferInfo i = d_ig.empty(n, e);
    d_im.lemmaTheoryInference(&i);
  }
}

/**
 * For every element collected by getElementsForBinaryOperator for the
 * UNION_DISJOINT term n, sends the `unionDisjoint` inference as a lemma.
 */
void BagSolver::checkUnionDisjoint(const Node& n)
{
  Assert(n.getKind() == UNION_DISJOINT);
  std::set<Node> elements = getElementsForBinaryOperator(n);
  for (const Node& e : elements)
  {
    InferInfo i = d_ig.unionDisjoint(n, e);
    d_im.lemmaTheoryInference(&i);
  }
}

/**
 * For every element collected by getElementsForBinaryOperator for the
 * UNION_MAX term n, sends the `unionMax` inference as a lemma.
 */
void BagSolver::checkUnionMax(const Node& n)
{
  Assert(n.getKind() == UNION_MAX);
  std::set<Node> elements = getElementsForBinaryOperator(n);
  for (const Node& e : elements)
  {
    InferInfo i = d_ig.unionMax(n, e);
    d_im.lemmaTheoryInference(&i);
  }
}

/**
 * For every element collected by getElementsForBinaryOperator for the
 * INTERSECTION_MIN term n, sends the `intersection` inference as a lemma.
 */
void BagSolver::checkIntersectionMin(const Node& n)
{
  Assert(n.getKind() == INTERSECTION_MIN);
  std::set<Node> elements = getElementsForBinaryOperator(n);
  for (const Node& e : elements)
  {
    InferInfo i = d_ig.intersection(n, e);
    d_im.lemmaTheoryInference(&i);
  }
}

/**
 * For every element collected by getElementsForBinaryOperator for the
 * DIFFERENCE_SUBTRACT term n, sends the `differenceSubtract` inference as a
 * lemma.
 */
void BagSolver::checkDifferenceSubtract(const Node& n)
{
  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
  std::set<Node> elements = getElementsForBinaryOperator(n);
  for (const Node& e : elements)
  {
    InferInfo i = d_ig.differenceSubtract(n, e);
    d_im.lemmaTheoryInference(&i);
  }
}

/**
 * Traces the elements the solver state reports for the MK_BAG term n, then
 * sends the `mkBag` inference as a lemma for each of them.
 */
void BagSolver::checkMkBag(const Node& n)
{
  Assert(n.getKind() == MK_BAG);
  Trace("bags::BagSolver::postCheck")
      << "BagSolver::checkMkBag Elements of " << n
      << " are: " << d_state.getElements(n) << std::endl;
  for (const Node& e : d_state.getElements(n))
  {
    InferInfo i = d_ig.mkBag(n, e);
    d_im.lemmaTheoryInference(&i);
  }
}

/**
 * Sends the `nonNegativeCount` inference for the given bag and element as a
 * lemma.
 */
void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
{
  InferInfo i = d_ig.nonNegativeCount(bag, element);
  d_im.lemmaTheoryInference(&i);
}

/**
 * For every element collected by getElementsForBinaryOperator for the
 * DIFFERENCE_REMOVE term n, sends the `differenceRemove` inference as a
 * lemma.
 */
void BagSolver::checkDifferenceRemove(const Node& n)
{
  Assert(n.getKind() == DIFFERENCE_REMOVE);
  std::set<Node> elements = getElementsForBinaryOperator(n);
  for (const Node& e : elements)
  {
    InferInfo i = d_ig.differenceRemove(n, e);
    d_im.lemmaTheoryInference(&i);
  }
}

/**
 * For every element the solver state reports for the DUPLICATE_REMOVAL term
 * n or for its single argument n[0], sends the `duplicateRemoval` inference
 * as a lemma.
 */
void BagSolver::checkDuplicateRemoval(Node n)
{
  Assert(n.getKind() == DUPLICATE_REMOVAL);
  const std::set<Node>& downwards = d_state.getElements(n);
  const std::set<Node>& upwards = d_state.getElements(n[0]);

  std::set<Node> elements;
  elements.insert(downwards.begin(), downwards.end());
  elements.insert(upwards.begin(), upwards.end());

  for (const Node& e : elements)
  {
    InferInfo i = d_ig.duplicateRemoval(n, e);
    d_im.lemmaTheoryInference(&i);
  }
}

/**
 * For every term the solver state reports as a disequal bag term, sends the
 * `bagDisequality` inference as a lemma.
 */
void BagSolver::checkDisequalBagTerms()
{
  for (const Node& n : d_state.getDisequalBagTerms())
  {
    InferInfo info = d_ig.bagDisequality(n);
    d_im.lemmaTheoryInference(&info);
  }
}

}  // namespace bags
}  // namespace theory
}  // namespace cvc5