diff options
Diffstat (limited to 'src/theory/bv/slice_manager.h')
-rw-r--r-- | src/theory/bv/slice_manager.h | 384 |
1 files changed, 316 insertions, 68 deletions
diff --git a/src/theory/bv/slice_manager.h b/src/theory/bv/slice_manager.h index 8fc1e0b9d..96a0067dc 100644 --- a/src/theory/bv/slice_manager.h +++ b/src/theory/bv/slice_manager.h @@ -13,6 +13,7 @@ #include "theory/bv/cd_set_collection.h" #include <map> +#include <set> #include <vector> namespace CVC4 { @@ -134,9 +135,6 @@ private: /** The equality engine */ EqualityEngine& d_equalityEngine; - /** The id of the concatenation function */ - size_t d_concatFunctionId; - /** The collection of backtrackable sets */ set_collection d_setCollection; @@ -153,27 +151,35 @@ public: SliceManager(TheoryBitvector& theoryBitvector, context::Context* context) : d_theoryBitvector(theoryBitvector), d_equalityEngine(theoryBitvector.getEqualityEngine()), d_setCollection(context) { - // We register the concatentation with the equality engine - d_concatFunctionId = d_equalityEngine.newFunction("bv_concat", true); } - inline size_t getConcatFunctionId() const { return d_concatFunctionId; } - /** - * Adds the equality (lhs = rhs) to the slice manager. This will not add the equalities to the equality manager, - * but will slice the equality according to the current slicing in order to align all the slices. The terms that - * get slices get sent to the theory engine as equalities, i.e if we slice x[10:0] into x[10:5]@x[4:0] equality - * engine gets the assertion x[10:0] = concat(x[10:5], x[4:0]). + * Adds the equality (lhs = rhs) to the slice manager. The equality is first normalized according to the equality + * manager, i.e. each base term is taken from the equality manager, replaced in, and then the whole concatenation + * normalized and sliced wrt the current slicing. The method will not add the equalities to the equality manager, + * but instead will slice the equality according to the current slicing in order to align all the slices. + * + * The terms that get sliced get sent to the theory engine as equalities, i.e if we slice x[10:0] into x[10:5]@x[4:0] + * equality engine gets the assertion x[10:0] = concat(x[10:5], x[4:0]). + * + * input output slicing + * -------------------------------------------------------------------------------------------------------------- + * x@y = y@x x = y, y = x empty + * x[31:0]@x[64:32] = x x = x[31:0]@x[63:32] x:{64,32,0} + * x@y = 0000@x@0000 x = 0000@x[7:4], y = x[3:0]@0000 x:{8,4,0} + * */ - inline void addEquality(TNode lhs, TNode rhs, std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices); + inline bool solveEquality(TNode lhs, TNode rhs); private: + inline bool solveEquality(TNode lhs, TNode rhs, const std::set<TNode>& assumptions); + /** - * Slices up lhs and rhs and returns the slices in lhsSlices and rhsSlices + * Slices up lhs and rhs and returns the slices in lhsSlices and rhsSlices. The slices are not atomic, + * they are sliced in order to make one of lhs or rhs atomic, the other one can be a concatenation. */ - inline void slice(std::vector<Node>& lhs, std::vector<Node>& rhs, - std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices); + inline bool sliceAndSolve(std::vector<Node>& lhs, std::vector<Node>& rhs, const std::set<TNode>& assumptions); /** * Returns true if the term is already sliced wrt the current slicing. Note that, for example, even though @@ -184,7 +190,7 @@ private: /** * Slices the term wrt the current slicing. When done, isSliced returns true */ - inline void slice(TNode node, std::vector<Node>& sliced); + inline bool slice(TNode node, std::vector<Node>& sliced); /** * Returns the base term in the core theory of the given term, i.e. @@ -196,20 +202,87 @@ private: static inline TNode baseTerm(TNode node); /** - * Adds a new slice to the slice set of the given base term. + * Adds a new slice to the slice set of the given term. */ - inline void addSlice(Node baseTerm, unsigned slicePoint); + inline bool addSlice(Node term, unsigned slicePoint); }; template <class TheoryBitvector> -void SliceManager<TheoryBitvector>::addSlice(Node baseTerm, unsigned slicePoint) { +bool SliceManager<TheoryBitvector>::addSlice(Node node, unsigned slicePoint) { + Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << ")" << std::endl; + + bool ok = true; + + int low = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractLow(node) : 0; + int high = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractHigh(node) + 1: utils::getSize(node); + slicePoint += low; + + TNode nodeBase = baseTerm(node); + + set_reference sliceSet; + slicing_map::iterator find = d_nodeSlicing.find(nodeBase); + if (find == d_nodeSlicing.end()) { + sliceSet = d_nodeSlicing[nodeBase] = d_setCollection.newSet(slicePoint); + d_setCollection.insert(sliceSet, low); + d_setCollection.insert(sliceSet, high); + } else { + sliceSet = find->second; + } + + // What are the points surrounding the new slice point + int prev = d_setCollection.prev(sliceSet, slicePoint); + int next = d_setCollection.next(sliceSet, slicePoint); + + // Add the slice to the set + d_setCollection.insert(sliceSet, slicePoint); + Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << "): current set " << d_setCollection.toString(sliceSet) << std::endl; + + // Add the terms and the equality to the equality engine + Node t1 = utils::mkExtract(nodeBase, next - 1, slicePoint); + Node t2 = utils::mkExtract(nodeBase, slicePoint - 1, prev); + Node nodeSlice = (next == high && prev == low) ? node : utils::mkExtract(nodeBase, next - 1, prev); + Node concat = utils::mkConcat(t1, t2); + + d_equalityEngine.addTerm(t1); + d_equalityEngine.addTerm(t2); + d_equalityEngine.addTerm(nodeSlice); + d_equalityEngine.addTerm(concat); + + // We are free to add this slice, unless the slice has a representative that's already a concat + TNode nodeSliceRepresentative = d_equalityEngine.getRepresentative(nodeSlice); + if (nodeSliceRepresentative.getKind() != kind::BITVECTOR_CONCAT) { + // Add the slice to the equality engine + ok = d_equalityEngine.addEquality(nodeSlice, concat, utils::mkTrue()); + } else { + // If the representative is a concat, we must solve it + // There is no need do add nodeSlice = concat as we will solve the representative of nodeSlice + std::set<TNode> assumptions; + std::vector<TNode> equalities; + d_equalityEngine.getExplanation(nodeSlice, nodeSliceRepresentative, equalities); + assumptions.insert(equalities.begin(), equalities.end()); + ok = solveEquality(nodeSliceRepresentative, concat, assumptions); + } + + Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << ") => " << d_setCollection.toString(d_nodeSlicing[nodeBase]) << std::endl; + + return ok; } template <class TheoryBitvector> -void SliceManager<TheoryBitvector>::addEquality(TNode lhs, TNode rhs, std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices) { +bool SliceManager<TheoryBitvector>::solveEquality(TNode lhs, TNode rhs) { + std::set<TNode> assumptions; + assumptions.insert(lhs.eqNode(rhs)); + bool ok = solveEquality(lhs, rhs, assumptions); + return ok; +} + +template <class TheoryBitvector> +bool SliceManager<TheoryBitvector>::solveEquality(TNode lhs, TNode rhs, const std::set<TNode>& assumptions) { + + Debug("slicing") << "SliceMagager::solveEquality(" << lhs << "," << rhs << "," << utils::setToString(assumptions) << ")" << push << std::endl; - Debug("slicing") << "SliceMagager::addEquality(" << lhs << "," << rhs << ")" << std::endl; + bool ok; // The concatenations on the left-hand side (reverse order, first is on top) std::vector<Node> lhsTerms; @@ -232,60 +305,213 @@ void SliceManager<TheoryBitvector>::addEquality(TNode lhs, TNode rhs, std::vecto } // Slice the individual terms to align them - slice(lhsTerms, rhsTerms, lhsSlices, rhsSlices); + ok = sliceAndSolve(lhsTerms, rhsTerms, assumptions); + + Debug("slicing") << "SliceMagager::solveEquality(" << lhs << "," << rhs << "," << utils::setToString(assumptions) << ")" << pop << std::endl; + + return ok; } + template <class TheoryBitvector> -void SliceManager<TheoryBitvector>::slice(std::vector<Node>& lhs, std::vector<Node>& rhs, - std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices) { +bool SliceManager<TheoryBitvector>::sliceAndSolve(std::vector<Node>& lhs, std::vector<Node>& rhs, const std::set<TNode>& assumptions) +{ - Debug("slicing") << "SliceManager::slice()" << std::endl; + Debug("slicing") << "SliceManager::sliceAndSolve()" << std::endl; - // Go through the work-list and align + // Go through the work-list, solve and align while (!lhs.empty()) { Assert(!rhs.empty()); + Debug("slicing") << "SliceManager::sliceAndSolve(): lhs " << utils::vectorToString(lhs) << std::endl; + Debug("slicing") << "SliceManager::sliceAndSolve(): rhs " << utils::vectorToString(rhs) << std::endl; + // The terms that we need to slice Node lhsTerm = lhs.back(); Node rhsTerm = rhs.back(); - Debug("slicing") << "slicing: " << lhsTerm << " and " << rhsTerm << std::endl; + + Debug("slicing") << "SliceManager::sliceAndSolve(): " << lhsTerm << " : " << rhsTerm << std::endl; // If the terms are not sliced wrt the current slicing, we have them sliced lhs.pop_back(); if (!isSliced(lhsTerm)) { - slice(lhsTerm, lhs); + if (!slice(lhsTerm, lhs)) return false; + Debug("slicing") << "SliceManager::sliceAndSolve(): lhs sliced" << std::endl; continue; } rhs.pop_back(); if (!isSliced(rhsTerm)) { - slice(rhsTerm, rhs); + if (!slice(rhsTerm, rhs)) return false; + // We also need to put lhs back + lhs.push_back(lhsTerm); + Debug("slicing") << "SliceManager::sliceAndSolve(): rhs sliced" << std::endl; + continue; } + Debug("slicing") << "SliceManager::sliceAndSolve(): both lhs and rhs sliced already" << std::endl; + + // The solving concatenation + std::vector<Node> concatTerms; + // If the slices are of the same size we do the additional work - unsigned lhsSize = utils::getSize(lhsTerm); - unsigned rhsSize = utils::getSize(rhsTerm); - if (lhsSize == rhsSize) { - // If they are over the same base terms, we need to do something - TNode lhsBaseTerm = baseTerm(lhsTerm); - TNode rhsBaseTerm = baseTerm(rhsTerm); - if (lhsBaseTerm == rhsBaseTerm) { - // x[i_1:j_1] vs x[i_2:j_2] + int sizeDifference = utils::getSize(lhsTerm) - utils::getSize(rhsTerm); + + // We slice constants immediately + if (sizeDifference > 0 && lhsTerm.getKind() == kind::CONST_BITVECTOR) { + BitVector low = lhsTerm.getConst<BitVector>().extract(utils::getSize(rhsTerm) - 1, 0); + BitVector high = lhsTerm.getConst<BitVector>().extract(utils::getSize(lhsTerm) - 1, utils::getSize(rhsTerm)); + lhs.push_back(utils::mkConst(low)); + lhs.push_back(utils::mkConst(high)); + rhs.push_back(rhsTerm); + continue; + } + if (sizeDifference < 0 && rhsTerm.getKind() == kind::CONST_BITVECTOR) { + BitVector low = rhsTerm.getConst<BitVector>().extract(utils::getSize(lhsTerm) - 1, 0); + BitVector high = rhsTerm.getConst<BitVector>().extract(utils::getSize(rhsTerm) - 1, utils::getSize(lhsTerm)); + rhs.push_back(utils::mkConst(low)); + rhs.push_back(utils::mkConst(high)); + lhs.push_back(lhsTerm); + continue; + } + + enum SolvingFor { + SOLVING_FOR_LHS, + SOLVING_FOR_RHS + } solvingFor = sizeDifference < 0 || lhsTerm.getKind() == kind::CONST_BITVECTOR ? SOLVING_FOR_RHS : SOLVING_FOR_LHS; + + Debug("slicing") << "SliceManager::sliceAndSolve(): " << (solvingFor == SOLVING_FOR_LHS ? "solving for LHS" : "solving for RHS") << std::endl; + + // When we slice in order to align, we might have to reslice the one we are solving for + bool reslice = false; + + switch (solvingFor) { + case SOLVING_FOR_RHS: { + concatTerms.push_back(lhsTerm); + // Maybe we need to add more lhs to make them equal + while (sizeDifference < 0 && !reslice) { + Assert(lhs.size() > 0); + // Get the next part for lhs + lhsTerm = lhs.back(); + lhs.pop_back(); + // Slice if necessary + if (!isSliced(lhsTerm)) { + if (!slice(lhsTerm, lhs)) return false; + continue; + } + // If we go above 0, we need to cut it + if (sizeDifference + (int)utils::getSize(lhsTerm) > 0) { + // Slice it so it fits + addSlice(lhsTerm, (int)utils::getSize(lhsTerm) + sizeDifference); + if (!slice(lhsTerm, lhs)) return false; + if (!isSliced(rhsTerm)) { + if (!slice(rhsTerm, rhs)) return false; + while(!concatTerms.empty()) { + lhs.push_back(concatTerms.back()); + concatTerms.pop_back(); + } + reslice = true; + } + continue; + } + concatTerms.push_back(lhsTerm); + sizeDifference += utils::getSize(lhsTerm); + } + break; + } + case SOLVING_FOR_LHS: { + concatTerms.push_back(rhsTerm); + // Maybe we need to add more rhs to make them equal + while (sizeDifference > 0 && !reslice) { + Assert(rhs.size() > 0); + // Get the next part for lhs + rhsTerm = rhs.back(); + rhs.pop_back(); + // Slice if necessary + if (!isSliced(rhsTerm)) { + if (!slice(rhsTerm, rhs)) return false; + continue; + } + // If we go below 0, we need to cut it + if (sizeDifference - (int)utils::getSize(rhsTerm) < 0) { + // Slice it so it fits + addSlice(rhsTerm, (int)utils::getSize(rhsTerm) - sizeDifference); + if (!slice(rhsTerm, rhs)) return false; + if (!isSliced(lhsTerm)) { + if (!slice(lhsTerm, lhs)) return false; + while(!concatTerms.empty()) { + rhs.push_back(concatTerms.back()); + concatTerms.pop_back(); + } + reslice = true; + } + continue; + } + concatTerms.push_back(rhsTerm); + sizeDifference -= utils::getSize(rhsTerm); + } + break; + } + } + + // If we need to reslice + if (reslice) { + continue; + } + + Assert(sizeDifference == 0); + + Node concat = utils::mkConcat(concatTerms); + Debug("slicing") << "SliceManager::sliceAndSolve(): concatenation " << concat << std::endl; + + // We have them equal size now. If the base term of the one we are solving is solved into a + // non-trivial concatenation already, we have to normalize. A concatenation is non-trivial if + // it is not a direct slicing, i.e it is a concat, and normalize(x) != x + switch (solvingFor) { + case SOLVING_FOR_LHS: { + TNode lhsTermRepresentative = d_equalityEngine.getRepresentative(lhsTerm); + if (lhsTermRepresentative != lhsTerm && + (lhsTermRepresentative.getKind() == kind::BITVECTOR_CONCAT || lhsTermRepresentative.getKind() == kind::CONST_BITVECTOR)) { + // We need to normalize and solve the normalized equations + std::vector<TNode> explanation; + d_equalityEngine.getExplanation(lhsTerm, lhsTermRepresentative, explanation); + std::set<TNode> additionalAssumptions(assumptions); + additionalAssumptions.insert(explanation.begin(), explanation.end()); + bool ok = solveEquality(lhsTermRepresentative, concat, additionalAssumptions); + if (!ok) return false; } else { - // x[i_1:j_1] vs y[i_2:j_2] + // We're fine, just add the equality + Debug("slicing") << "SliceManager::sliceAndSolve(): adding " << lhsTerm << " = " << concat << " " << utils::setToString(assumptions) << std::endl; + d_equalityEngine.addTerm(concat); + bool ok = d_equalityEngine.addEquality(lhsTerm, concat, utils::mkConjunction(assumptions)); + if (!ok) return false; } - lhsSlices.push_back(lhsTerm); - rhsSlices.push_back(rhsTerm); - continue; - } else { - // They are not of equal sizes, so we slice one - if (lhsSize < rhsSize) { - // We need to cut a piece of rhs + break; + } + case SOLVING_FOR_RHS: { + TNode rhsTermRepresentative = d_equalityEngine.getRepresentative(rhsTerm); + if (rhsTermRepresentative != rhsTerm && + (rhsTermRepresentative.getKind() == kind::BITVECTOR_CONCAT || rhsTermRepresentative.getKind() == kind::CONST_BITVECTOR)) { + // We need to normalize and solve the normalized equations + std::vector<TNode> explanation; + d_equalityEngine.getExplanation(rhsTerm, rhsTermRepresentative, explanation); + std::set<TNode> additionalAssumptions(assumptions); + additionalAssumptions.insert(explanation.begin(), explanation.end()); + bool ok = solveEquality(rhsTermRepresentative, concat, additionalAssumptions); + if (!ok) return false; } else { - // We need to cut a piece of lhs + // We're fine, just add the equality + Debug("slicing") << "SliceManager::sliceAndSolve(): adding " << rhsTerm << " = " << concat << utils::setToString(assumptions) << std::endl; + d_equalityEngine.addTerm(concat); + bool ok = d_equalityEngine.addEquality(rhsTerm, concat, utils::mkConjunction(assumptions)); + if (!ok) return false; } + break; + } } } + + return true; } template <class TheoryBitvector> @@ -315,9 +541,11 @@ bool SliceManager<TheoryBitvector>::isSliced(TNode node) const { if (find == d_nodeSlicing.end()) { result = nodeKind != kind::BITVECTOR_EXTRACT; } else { - // Check whether there is a slice point in [high, low), if there is the term is not sliced. - // Hence, if we look for the upper bound of low, and it is higher than high, it is sliced. - result = d_setCollection.count(find->second, low + 1, high) > 0; + // The term is not sliced if one of the borders is not in the slice set or + // there is a point between the borders + result = + d_setCollection.contains(find->second, low) && d_setCollection.contains(find->second, high + 1) && + (low == high || d_setCollection.count(find->second, low + 1, high) == 0); } } @@ -326,7 +554,7 @@ bool SliceManager<TheoryBitvector>::isSliced(TNode node) const { } template <class TheoryBitvector> -inline void SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>& sliced) { +inline bool SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>& sliced) { Debug("slicing") << "SliceManager::slice(" << node << ")" << std::endl; @@ -335,44 +563,64 @@ inline void SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>& // The indices of the beginning and (one past) end unsigned high = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractHigh(node) + 1 : utils::getSize(node); unsigned low = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractLow(node) : 0; + Debug("slicing") << "SliceManager::slice(" << node << "): low: " << low << std::endl; + Debug("slicing") << "SliceManager::slice(" << node << "): high: " << high << std::endl; // Get the base term TNode nodeBase = baseTerm(node); Assert(nodeBase.getKind() != kind::BITVECTOR_CONCAT); Assert(nodeBase.getKind() != kind::CONST_BITVECTOR); - // Get the base term slice set - set_collection::reference_type nodeSliceSet = d_nodeSlicing[nodeBase]; + // The nodes slice set + set_collection::reference_type nodeSliceSet; + + // Find the current one or construct it + slicing_map::iterator findSliceSet = d_nodeSlicing.find(nodeBase); + if (findSliceSet == d_nodeSlicing.end()) { + nodeSliceSet = d_setCollection.newSet(utils::getSize(nodeBase)); + d_setCollection.insert(nodeSliceSet, 0); + d_nodeSlicing[nodeBase] = nodeSliceSet; + } else { + nodeSliceSet = d_nodeSlicing[nodeBase]; + } + Debug("slicing") << "SliceManager::slice(" << node << "): current: " << d_setCollection.toString(nodeSliceSet) << std::endl; std::vector<size_t> slicePoints; - d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints); - + if (low + 1 < high) { + d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints); + } + // Go through all the points i_0 <= low < i_1 < ... < i_{n-1} < high <= i_n from the slice set // and generate the slices [i_0:low-1][low:i_1-1] [i_1:i2] ... [i_{n-1}:high-1][high:i_n-1]. They are in reverse order, // as they should be size_t i_0 = low == 0 ? 0 : d_setCollection.prev(nodeSliceSet, low + 1); - size_t i_n = high == utils::getSize(nodeBase) ? high: d_setCollection.next(nodeSliceSet, high); + Debug("slicing") << "SliceManager::slice(" << node << "): i_0: " << i_0 << std::endl; + size_t i_n = high == utils::getSize(nodeBase) ? high: d_setCollection.next(nodeSliceSet, high - 1); + Debug("slicing") << "SliceManager::slice(" << node << "): i_n: " << i_n << std::endl; // Add the new points to the slice set (they might be there already) if (high < i_n) { - std::vector<Node> lastTwoSlices; - lastTwoSlices.push_back(utils::mkExtract(nodeBase, i_n-1, high)); - lastTwoSlices.push_back(utils::mkExtract(nodeBase, high-1, slicePoints.back())); - d_equalityEngine.addEquality(utils::mkExtract(nodeBase, i_n-1, slicePoints.back()), utils::mkConcat(lastTwoSlices)); + if (!addSlice(nodeBase, high)) return false; } - - while (!slicePoints.empty()) { + // Construct the actuall slicing + if (slicePoints.size() > 0) { + Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, slicePoints[0] - 1, low) << std::endl; + sliced.push_back(utils::mkExtract(nodeBase, slicePoints[0] - 1, low)); + for (unsigned i = 1; i < slicePoints.size(); ++ i) { + Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, slicePoints[i] - 1, slicePoints[i-1])<< std::endl; + sliced.push_back(utils::mkExtract(nodeBase, slicePoints[i] - 1, slicePoints[i-1])); + } + Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, high-1, slicePoints.back()) << std::endl; sliced.push_back(utils::mkExtract(nodeBase, high-1, slicePoints.back())); - high = slicePoints.back(); - slicePoints.pop_back(); + } else { + sliced.push_back(utils::mkExtract(nodeBase, high - 1, low)); } - + // Add the new points to the slice set (they might be there already) if (i_0 < low) { - std::vector<Node> firstTwoSlices; - firstTwoSlices.push_back(utils::mkExtract(nodeBase, high-1, low)); - firstTwoSlices.push_back(utils::mkExtract(nodeBase, low-1, i_0)); - d_equalityEngine.addEquality(utils::mkExtract(nodeBase, high-1, i_0), utils::mkConcat(firstTwoSlices)); + if (!addSlice(nodeBase, low)) return false; } + + return true; } template <class TheoryBitvector> |