diff options
author | lianah <lianahady@gmail.com> | 2013-03-16 15:48:51 -0400 |
---|---|---|
committer | lianah <lianahady@gmail.com> | 2013-03-16 15:48:51 -0400 |
commit | 25ac2c8f4b45e2b299895e97a30790fbf46cf79f (patch) | |
tree | d7b52003d7157073be554bd9818230f1c3b439d3 /src/theory/bv/slicer.cpp | |
parent | 3fcdb18fe92e5213aa708285c0d7d5e55633492b (diff) |
started work on the inequality bv subtheory
Diffstat (limited to 'src/theory/bv/slicer.cpp')
-rw-r--r-- | src/theory/bv/slicer.cpp | 250 |
1 files changed, 208 insertions, 42 deletions
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp index ac668ab20..92166224b 100644 --- a/src/theory/bv/slicer.cpp +++ b/src/theory/bv/slicer.cpp @@ -19,7 +19,7 @@ #include "theory/bv/slicer.h" #include "theory/bv/theory_bv_utils.h" #include "theory/rewriter.h" - +#include "theory/bv/bv_subtheory_core.h" using namespace CVC4; using namespace CVC4::theory; using namespace CVC4::theory::bv; @@ -159,7 +159,7 @@ std::string NormalForm::debugPrint(const UnionFind& uf) const { std::string UnionFind::Node::debugPrint() const { ostringstream os; - os << "Repr " << d_repr << " ["<< d_bitwidth << "] "; + os << "Repr " << d_edge.repr << " ["<< d_bitwidth << "] "; os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl; return os.str(); } @@ -169,27 +169,44 @@ std::string UnionFind::Node::debugPrint() const { * UnionFind * */ -TermId UnionFind::addTerm(Index bitwidth) { +TermId UnionFind::addNode(Index bitwidth) { + Assert (bitwidth > 0); Node node(bitwidth); d_nodes.push_back(node); + ++(d_statistics.d_numNodes); TermId id = d_nodes.size() - 1; // d_representatives.insert(id); ++(d_statistics.d_numRepresentatives); - Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl; return id; } + +TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) { + ExtractTerm extract(topLevel, high, low); + if (d_extractToId.find(extract) != d_extractToId.end()) { + return d_extractToId[extract]; + } + + Assert (high >= low); + + TermId id = addNode(high - low + 1); + d_idToExtract[id] = extract; + d_extractToId[extract] = id; + return id; +} + /** * At this point we assume the slicings of the two terms are properly aligned. * * @param t1 * @param t2 */ -void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2) { +void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason) { Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n" - << " " << t2.debugPrint() << endl; + << " " << t2.debugPrint() << "\n" + << " with reason " << reason << endl; Assert (t1.getBitwidth() == t2.getBitwidth()); NormalForm nf1(t1.getBitwidth()); @@ -202,10 +219,11 @@ void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2) { Assert (nf1.base == nf2.base); for (unsigned i = 0; i < nf1.decomp.size(); ++i) { - merge (nf1.decomp[i], nf2.decomp[i]); + merge (nf1.decomp[i], nf2.decomp[i], reason); } } + /** * Merge the two terms in the union find. Both t1 and t2 * should be root terms. @@ -213,7 +231,7 @@ void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2) { * @param t1 * @param t2 */ -void UnionFind::merge(TermId t1, TermId t2) { +void UnionFind::merge(TermId t1, TermId t2, TermId reason) { Debug("bv-slicer-uf") << "UnionFind::merge (" << t1 <<", " << t2 << ")" << endl; ++(d_statistics.d_numMerges); t1 = find(t1); @@ -223,7 +241,7 @@ void UnionFind::merge(TermId t1, TermId t2) { return; Assert (! hasChildren(t1) && ! hasChildren(t2)); - setRepr(t1, t2); + setRepr(t1, t2, reason); recordOperation(UnionFind::MERGE, t1); //d_representatives.erase(t1); d_statistics.d_numRepresentatives += -1; @@ -233,11 +251,26 @@ TermId UnionFind::find(TermId id) { TermId repr = getRepr(id); if (repr != UndefinedId) { TermId find_id = find(repr); - // setRepr(id, find_id); return find_id; } return id; } + +TermId UnionFind::findWithExplanation(TermId id, std::vector<ExplanationId>& explanation) { + TermId repr = getRepr(id); + + if (repr != UndefinedId) { + TermId reason = getReason(id); + Assert (reason != UndefinedId); + explanation.push_back(reason); + + TermId find_id = findWithExplanation(repr, explanation); + return find_id; + } + return id; +} + + /** * Splits the representative of the term between i-1 and i * @@ -259,10 +292,14 @@ void UnionFind::split(TermId id, Index i) { Assert (i < getBitwidth(id)); if (!hasChildren(id)) { // first time we split this term - TermId bottom_id = addTerm(i); - TermId top_id = addTerm(getBitwidth(id) - i); + TermId bottom_id = addExtract(getTopLevel(id), i - 1, 0); + TermId top_id = addExtract(getTopLevel(id), getBitwidth(id) - 1, i); setChildren(id, top_id, bottom_id); - recordOperation(UnionFind::SPLIT, id); + recordOperation(UnionFind::SPLIT, id); + + if (d_slicer->termInEqualityEngine(id)) { + d_slicer->enqueueSplit(id, i); + } } else { Index cut = getCutPoint(id); if (i < cut ) @@ -273,6 +310,14 @@ void UnionFind::split(TermId id, Index i) { ++(d_statistics.d_numSplits); } +TermId UnionFind::getTopLevel(TermId id) const { + __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> >::const_iterator it = d_idToExtract.find(id); + if (it != d_idToExtract.end()) { + return (*it).second.id; + } + return id; +} + void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) { nf.clear(); getDecomposition(term, nf.decomp); @@ -319,6 +364,56 @@ void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp) getDecomposition(high_child, decomp); } } + +void UnionFind::getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, + std::vector<ExplanationId>& explanation) { + nf.clear(); + getDecompositionWithExplanation(term, nf.decomp, explanation); + // update nf base + Index count = 0; + for (unsigned i = 0; i < nf.decomp.size(); ++i) { + count += getBitwidth(nf.decomp[i]); + nf.base.sliceAt(count); + } + Debug("bv-slicer-uf") << "UnionFind::getNormalFrom term: " << term.debugPrint() << endl; + Debug("bv-slicer-uf") << " nf: " << nf.debugPrint(*this) << endl; +} + +void UnionFind::getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, + std::vector<ExplanationId>& explanation) { + // making sure the term is aligned + TermId id = findWithExplanation(term.id, explanation); + + Assert (term.high < getBitwidth(id)); + // because we split the node, this must be the whole extract + if (!hasChildren(id)) { + Assert (term.high == getBitwidth(id) - 1 && + term.low == 0); + decomp.push_back(id); + return; + } + + Index cut = getCutPoint(id); + + if (term.low < cut && term.high < cut) { + // the extract falls entirely on the low child + ExtractTerm child_ex(getChild(id, 0), term.high, term.low); + getDecompositionWithExplanation(child_ex, decomp, explanation); + } + else if (term.low >= cut && term.high >= cut){ + // the extract falls entirely on the high child + ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut); + getDecompositionWithExplanation(child_ex, decomp, explanation); + } + else { + // the extract is split over the two children + ExtractTerm low_child(getChild(id, 0), cut - 1, term.low); + getDecompositionWithExplanation(low_child, decomp, explanation); + ExtractTerm high_child(getChild(id, 1), term.high - cut, 0); + getDecompositionWithExplanation(high_child, decomp, explanation); + } +} + /** * May cause reslicings of the decompositions. Must not assume the decompositons * are the current normal form. @@ -428,7 +523,7 @@ void UnionFind::ensureSlicing(const ExtractTerm& term) { void UnionFind::backtrack() { return; int size = d_undoStack.size(); - for (int i = size; i > d_undoStackIndex.get(); --i) { + for (int i = size; i > (int)d_undoStackIndex.get(); --i) { Operation op = d_undoStack.back(); Assert (!d_undoStack.empty()); d_undoStack.pop_back(); @@ -443,8 +538,8 @@ void UnionFind::backtrack() { void UnionFind::undoMerge(TermId id) { TermId repr = getRepr(id); - Assert (repr != id); - setRepr(id, UndefinedId); + Assert (repr != UndefinedId); + setRepr(id, UndefinedId, UndefinedId); } void UnionFind::undoSplit(TermId id) { @@ -453,9 +548,6 @@ void UnionFind::undoSplit(TermId id) { } void UnionFind::recordOperation(OperationKind op, TermId term) { - if (op == SPLIT) { - d_newSplit = true; - } d_undoStackIndex.set(d_undoStackIndex.get() + 1); d_undoStack.push_back(Operation(op, term)); Assert (d_undoStack.size() == d_undoStackIndex); @@ -488,7 +580,7 @@ ExtractTerm Slicer::registerTerm(TNode node) { low = utils::getExtractLow(node); } if (d_nodeToId.find(n) == d_nodeToId.end()) { - TermId id = d_unionFind.addTerm(utils::getSize(n)); + TermId id = d_unionFind.addNode(utils::getSize(n)); d_nodeToId[n] = id; d_idToNode[id] = n; } @@ -500,7 +592,8 @@ ExtractTerm Slicer::registerTerm(TNode node) { void Slicer::processEquality(TNode eq) { Debug("bv-slicer") << "Slicer::processEquality: " << eq << endl; - + + registerEquality(eq); Assert (eq.getKind() == kind::EQUAL); TNode a = eq[0]; TNode b = eq[1]; @@ -511,13 +604,35 @@ void Slicer::processEquality(TNode eq) { d_unionFind.ensureSlicing(b_ex); d_unionFind.alignSlicings(a_ex, b_ex); - d_unionFind.unionTerms(a_ex, b_ex); + Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl; Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl; Debug("bv-slicer") << "Slicer::processEquality done. " << endl; } -void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) { +void Slicer::assertEquality(TNode eq) { + Assert (eq.getKind() == kind::EQUAL); + ExtractTerm a = registerTerm(eq[0]); + ExtractTerm b = registerTerm(eq[1]); + ExplanationId reason = getExplanationId(eq); + d_unionFind.unionTerms(a, b, reason); +} + +TermId Slicer::getId(TNode node) const { + __gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction >::const_iterator it = d_nodeToId.find(node); + Assert (it != d_nodeToId.end()); + return it->second; +} + +void Slicer::registerEquality(TNode eq) { + if (d_explanationToId.find(eq) == d_explanationToId.end()) { + ExplanationId id = d_explanations.size(); + d_explanations.push_back(eq); + d_explanationToId[eq] = id; + } +} + +void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation) { Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl; Index high = utils::getSize(node) - 1; @@ -528,10 +643,18 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) { low = utils::getExtractLow(node); top = node[0]; } + Assert (d_nodeToId.find(top) != d_nodeToId.end()); TermId id = d_nodeToId[top]; - NormalForm nf(high-low+1); - d_unionFind.getNormalForm(ExtractTerm(id, high, low), nf); + NormalForm nf(high-low+1); + std::vector<ExplanationId> explanation_ids; + d_unionFind.getNormalFormWithExplanation(ExtractTerm(id, high, low), nf, explanation_ids); + + for (unsigned i = 0; i < explanation_ids.size(); ++i) { + Assert (hasExplanation(explanation_ids[i])); + TNode exp = getExplanation(explanation_ids[i]); + explanation.push_back(exp); + } // construct actual extract nodes Index current_low = 0; @@ -622,25 +745,68 @@ void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) { d_numAddedEqualities += equalities.size() - 1; } -/** - * Returns the base decomposition of the current term. - * - * @param id - * - * @return - */ -Base Slicer::getTopLevelBase(TNode node) { - if (node.getKind() == kind::BITVECTOR_EXTRACT) { - node = node[0]; + +ExtractTerm UnionFind::getExtractTerm(TermId id) const { + Assert (isExtractTerm(id)); + + return (d_idToExtract.find(id))->second; +} + +bool UnionFind::isExtractTerm(TermId id) const { + return d_idToExtract.find(id) != d_idToExtract.end(); +} + +bool Slicer::hasNode(TermId id) const { + return d_idToNode.find(id) != d_idToNode.end(); +} + +Node Slicer::getNode(TermId id) const { + // if it was an extract + if (d_unionFind.isExtractTerm(id)) { + ExtractTerm extract = d_unionFind.getExtractTerm(id); + Assert (hasNode(extract.id)); + TNode node = d_idToNode.find(extract.id)->second; + Node ex = utils::mkExtract(node, extract.high, extract.low); + return ex; } - // if we haven't seen this node before it must not be sliced yet - if (d_nodeToId.find(node) == d_nodeToId.end()) { - return Base(utils::getSize(node)); + // otherwise must be a top-level term + Assert (hasNode(id)); + return (d_idToNode.find(id))->second; +} + +bool Slicer::termInEqualityEngine(TermId id) { + Node node = getNode(id); + return d_coreSolver->hasTerm(node); +} + +void Slicer::enqueueSplit(TermId id, Index i) { + Node node = getNode(id); + Node bottom = Rewriter::rewrite(utils::mkExtract(node, i -1 , 0)); + Node top = Rewriter::rewrite(utils::mkExtract(node, utils::getSize(node) - 1, i)); + Node eq = utils::mkNode(kind::EQUAL, node, utils::mkConcat(top, bottom)); + d_newSplits.push_back(eq); + Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl; +} + +void Slicer::getNewSplits(std::vector<Node>& splits) { + for (unsigned i = d_newSplitsIndex; i < d_newSplits.size(); ++i) { + splits.push_back(d_newSplits[i]); } - TermId id = d_nodeToId[node]; - Base base(d_unionFind.getBitwidth(id)); - d_unionFind.getBase(id, base, 0); - return base; + d_newSplitsIndex = d_newSplits.size(); +} + +bool Slicer::hasExplanation(ExplanationId id) const { + return id < d_explanations.size(); +} + +TNode Slicer::getExplanation(ExplanationId id) const { + Assert(hasExplanation(id)); + return d_explanations[id]; +} + +ExplanationId Slicer::getExplanationId(TNode reason) const { + Assert (d_explanationToId.find(reason) != d_explanationToId.end()); + return d_explanationToId.find(reason)->second; } std::string UnionFind::debugPrint(TermId id) { |