summaryrefslogtreecommitdiff
path: root/src/theory/bv/slicer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/bv/slicer.cpp')
-rw-r--r--src/theory/bv/slicer.cpp861
1 files changed, 861 insertions, 0 deletions
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp
new file mode 100644
index 000000000..5d376ea50
--- /dev/null
+++ b/src/theory/bv/slicer.cpp
@@ -0,0 +1,861 @@
+/********************* */
+/*! \file slicer.h
+ ** \verbatim
+ ** Original author: lianah
+ ** Major contributors: none
+ ** Minor contributors (to current version): none
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
+ ** Courant Institute of Mathematical Sciences
+ ** New York University
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief Bitvector theory.
+ **
+ ** Bitvector theory.
+ **/
+
+#include "theory/bv/slicer.h"
+#include "theory/bv/theory_bv_utils.h"
+#include "theory/rewriter.h"
+#include "theory/bv/bv_subtheory_core.h"
+using namespace CVC4;
+using namespace CVC4::theory;
+using namespace CVC4::theory::bv;
+using namespace std;
+
+
+const TermId CVC4::theory::bv::UndefinedId = -1;
+
+/**
+ * Base
+ *
+ */
+Base::Base(uint32_t size)
+ : d_size(size),
+ d_repr(size/32 + (size % 32 == 0? 0 : 1), 0)
+{
+ Assert (d_size > 0);
+}
+
+
+void Base::sliceAt(Index index) {
+ if (index == d_size)
+ return;
+ Assert(index < d_size);
+ Index vector_index = index / 32;
+ Assert (vector_index < d_repr.size());
+ Index int_index = index % 32;
+ uint32_t bit_mask = utils::pow2(int_index);
+ d_repr[vector_index] = d_repr[vector_index] | bit_mask;
+}
+
+void Base::undoSliceAt(Index index) {
+ Index vector_index = index / 32;
+ Assert (vector_index < d_size);
+ Index int_index = index % 32;
+ uint32_t bit_mask = utils::pow2(int_index);
+ d_repr[vector_index] = d_repr[vector_index] ^ bit_mask;
+}
+
+void Base::sliceWith(const Base& other) {
+ Assert (d_size == other.d_size);
+ for (unsigned i = 0; i < d_repr.size(); ++i) {
+ d_repr[i] = d_repr[i] | other.d_repr[i];
+ }
+}
+
+bool Base::isCutPoint (Index index) const {
+ // there is an implicit cut point at the end and begining of the bv
+ if (index == d_size || index == 0)
+ return true;
+
+ Index vector_index = index / 32;
+ Assert (vector_index < d_size);
+ Index int_index = index % 32;
+ uint32_t bit_mask = utils::pow2(int_index);
+
+ return (bit_mask & d_repr[vector_index]) != 0;
+}
+
+void Base::diffCutPoints(const Base& other, Base& res) const {
+ Assert (d_size == other.d_size && res.d_size == d_size);
+ for (unsigned i = 0; i < d_repr.size(); ++i) {
+ Assert (res.d_repr[i] == 0);
+ res.d_repr[i] = d_repr[i] ^ other.d_repr[i];
+ }
+}
+
+bool Base::isEmpty() const {
+ for (unsigned i = 0; i< d_repr.size(); ++i) {
+ if (d_repr[i] != 0)
+ return false;
+ }
+ return true;
+}
+
+std::string Base::debugPrint() const {
+ std::ostringstream os;
+ os << "[";
+ bool first = true;
+ for (int i = d_size - 1; i >= 0; --i) {
+ if (isCutPoint(i)) {
+ if (first)
+ first = false;
+ else
+ os <<"| ";
+
+ os << i ;
+ }
+ }
+ os << "]";
+ return os.str();
+}
+
+/**
+ * ExtractTerm
+ *
+ */
+
+std::string ExtractTerm::debugPrint() const {
+ ostringstream os;
+ os << "id" << id << "[" << high << ":" << low <<"] ";
+ return os.str();
+}
+
+/**
+ * NormalForm
+ *
+ */
+
+std::pair<TermId, Index> NormalForm::getTerm(Index index, const UnionFind& uf) const {
+ Assert (index < base.getBitwidth());
+ Index count = 0;
+ for (unsigned i = 0; i < decomp.size(); ++i) {
+ Index size = uf.getBitwidth(decomp[i]);
+ if ( count + size > index && index >= count) {
+ return pair<TermId, Index>(decomp[i], count);
+ }
+ count += size;
+ }
+ Unreachable();
+}
+
+
+
+std::string NormalForm::debugPrint(const UnionFind& uf) const {
+ ostringstream os;
+ os << "NF " << base.debugPrint() << endl;
+ os << "(";
+ for (int i = decomp.size() - 1; i>= 0; --i) {
+ os << decomp[i] << "[" << uf.getBitwidth(decomp[i]) <<"]";
+ os << (i != 0? ", " : "");
+ }
+ os << ") \n";
+ return os.str();
+}
+/**
+ * UnionFind::Node
+ *
+ */
+
+std::string UnionFind::Node::debugPrint() const {
+ ostringstream os;
+ os << "Repr " << d_edge.repr << " ["<< d_bitwidth << "] ";
+ os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl;
+ return os.str();
+}
+
+
+/**
+ * UnionFind
+ *
+ */
+TermId UnionFind::addNode(Index bitwidth) {
+ Assert (bitwidth > 0);
+ Node node(bitwidth);
+ d_nodes.push_back(node);
+
+ ++(d_statistics.d_numNodes);
+
+ TermId id = d_nodes.size() - 1;
+ // d_representatives.insert(id);
+ ++(d_statistics.d_numRepresentatives);
+ Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl;
+ return id;
+}
+
+TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) {
+ if (isExtractTerm(topLevel)) {
+ ExtractTerm top = getExtractTerm(topLevel);
+ Index top_high = top.high;
+ Index top_low = top.low;
+ Assert (top_high - top_low + 1 > high);
+ high += top_low;
+ low += top_low;
+ topLevel = top.id;
+ }
+ ExtractTerm extract(topLevel, high, low);
+ if (d_extractToId.find(extract) != d_extractToId.end()) {
+ return d_extractToId[extract];
+ }
+
+ Assert (high >= low);
+
+ TermId id = addNode(high - low + 1);
+ d_idToExtract[id] = extract;
+ d_extractToId[extract] = id;
+ return id;
+}
+
+/**
+ * At this point we assume the slicings of the two terms are properly aligned.
+ *
+ * @param t1
+ * @param t2
+ */
+void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason) {
+ Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n"
+ << " " << t2.debugPrint() << "\n"
+ << " with reason " << reason << endl;
+ Assert (t1.getBitwidth() == t2.getBitwidth());
+
+ NormalForm nf1(t1.getBitwidth());
+ NormalForm nf2(t2.getBitwidth());
+
+ getNormalForm(t1, nf1);
+ getNormalForm(t2, nf2);
+
+ Assert (nf1.decomp.size() == nf2.decomp.size());
+ Assert (nf1.base == nf2.base);
+
+ for (unsigned i = 0; i < nf1.decomp.size(); ++i) {
+ merge (nf1.decomp[i], nf2.decomp[i], reason);
+ }
+}
+
+
+/**
+ * Merge the two terms in the union find. Both t1 and t2
+ * should be root terms.
+ *
+ * @param t1
+ * @param t2
+ */
+void UnionFind::merge(TermId t1, TermId t2, TermId reason) {
+ Debug("bv-slicer-uf") << "UnionFind::merge (" << t1 <<", " << t2 << ")" << endl;
+ ++(d_statistics.d_numMerges);
+ t1 = find(t1);
+ t2 = find(t2);
+
+ if (t1 == t2)
+ return;
+
+ Assert (! hasChildren(t1) && ! hasChildren(t2));
+ setRepr(t1, t2, reason);
+ recordOperation(UnionFind::MERGE, t1);
+ //d_representatives.erase(t1);
+ d_statistics.d_numRepresentatives += -1;
+}
+
+TermId UnionFind::find(TermId id) {
+ TermId repr = getRepr(id);
+ if (repr != UndefinedId) {
+ TermId find_id = find(repr);
+ return find_id;
+ }
+ return id;
+}
+
+TermId UnionFind::findWithExplanation(TermId id, std::vector<ExplanationId>& explanation) {
+ TermId repr = getRepr(id);
+
+ if (repr != UndefinedId) {
+ TermId reason = getReason(id);
+ Assert (reason != UndefinedId);
+ explanation.push_back(reason);
+
+ TermId find_id = findWithExplanation(repr, explanation);
+ return find_id;
+ }
+ return id;
+}
+
+
+/**
+ * Splits the representative of the term between i-1 and i
+ *
+ * @param id the id of the term
+ * @param i the index we are splitting at
+ *
+ * @return
+ */
+void UnionFind::split(TermId id, Index i) {
+ Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl;
+ id = find(id);
+ Debug("bv-slicer-uf") << " node: " << d_nodes[id].debugPrint() << endl;
+
+ if (i == 0 || i == getBitwidth(id)) {
+ // nothing to do
+ return;
+ }
+
+ Assert (i < getBitwidth(id));
+ if (!hasChildren(id)) {
+ // first time we split this term
+ TermId bottom_id = addExtract(id, i - 1, 0);
+ TermId top_id = addExtract(id, getBitwidth(id) - 1, i);
+ setChildren(id, top_id, bottom_id);
+ recordOperation(UnionFind::SPLIT, id);
+
+ if (d_slicer->termInEqualityEngine(id)) {
+ d_slicer->enqueueSplit(id, i, top_id, bottom_id);
+ }
+ } else {
+ Index cut = getCutPoint(id);
+ if (i < cut )
+ split(getChild(id, 0), i);
+ else
+ split(getChild(id, 1), i - cut);
+ }
+ ++(d_statistics.d_numSplits);
+}
+
+// TermId UnionFind::getTopLevel(TermId id) const {
+// __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> >::const_iterator it = d_idToExtract.find(id);
+// if (it != d_idToExtract.end()) {
+// return (*it).second.id;
+// }
+// return id;
+// }
+
+void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) {
+ nf.clear();
+ getDecomposition(term, nf.decomp);
+ // update nf base
+ Index count = 0;
+ for (unsigned i = 0; i < nf.decomp.size(); ++i) {
+ count += getBitwidth(nf.decomp[i]);
+ nf.base.sliceAt(count);
+ }
+ Debug("bv-slicer-uf") << "UnionFind::getNormalFrom term: " << term.debugPrint() << endl;
+ Debug("bv-slicer-uf") << " nf: " << nf.debugPrint(*this) << endl;
+}
+
+void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp) {
+ // making sure the term is aligned
+ TermId id = find(term.id);
+
+ Assert (term.high < getBitwidth(id));
+ // because we split the node, this must be the whole extract
+ if (!hasChildren(id)) {
+ Assert (term.high == getBitwidth(id) - 1 &&
+ term.low == 0);
+ decomp.push_back(id);
+ return;
+ }
+
+ Index cut = getCutPoint(id);
+
+ if (term.low < cut && term.high < cut) {
+ // the extract falls entirely on the low child
+ ExtractTerm child_ex(getChild(id, 0), term.high, term.low);
+ getDecomposition(child_ex, decomp);
+ }
+ else if (term.low >= cut && term.high >= cut){
+ // the extract falls entirely on the high child
+ ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut);
+ getDecomposition(child_ex, decomp);
+ }
+ else {
+ // the extract is split over the two children
+ ExtractTerm low_child(getChild(id, 0), cut - 1, term.low);
+ getDecomposition(low_child, decomp);
+ ExtractTerm high_child(getChild(id, 1), term.high - cut, 0);
+ getDecomposition(high_child, decomp);
+ }
+}
+
+void UnionFind::getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf,
+ std::vector<ExplanationId>& explanation) {
+ nf.clear();
+ getDecompositionWithExplanation(term, nf.decomp, explanation);
+ // update nf base
+ Index count = 0;
+ for (unsigned i = 0; i < nf.decomp.size(); ++i) {
+ count += getBitwidth(nf.decomp[i]);
+ nf.base.sliceAt(count);
+ }
+ Debug("bv-slicer-uf") << "UnionFind::getNormalFrom term: " << term.debugPrint() << endl;
+ Debug("bv-slicer-uf") << " nf: " << nf.debugPrint(*this) << endl;
+}
+
+void UnionFind::getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp,
+ std::vector<ExplanationId>& explanation) {
+ // making sure the term is aligned
+ TermId id = findWithExplanation(term.id, explanation);
+
+ Assert (term.high < getBitwidth(id));
+ // because we split the node, this must be the whole extract
+ if (!hasChildren(id)) {
+ Assert (term.high == getBitwidth(id) - 1 &&
+ term.low == 0);
+ decomp.push_back(id);
+ return;
+ }
+
+ Index cut = getCutPoint(id);
+
+ if (term.low < cut && term.high < cut) {
+ // the extract falls entirely on the low child
+ ExtractTerm child_ex(getChild(id, 0), term.high, term.low);
+ getDecompositionWithExplanation(child_ex, decomp, explanation);
+ }
+ else if (term.low >= cut && term.high >= cut){
+ // the extract falls entirely on the high child
+ ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut);
+ getDecompositionWithExplanation(child_ex, decomp, explanation);
+ }
+ else {
+ // the extract is split over the two children
+ ExtractTerm low_child(getChild(id, 0), cut - 1, term.low);
+ getDecompositionWithExplanation(low_child, decomp, explanation);
+ ExtractTerm high_child(getChild(id, 1), term.high - cut, 0);
+ getDecompositionWithExplanation(high_child, decomp, explanation);
+ }
+}
+
+/**
+ * May cause reslicings of the decompositions. Must not assume the decompositons
+ * are the current normal form.
+ *
+ * @param d1
+ * @param d2
+ * @param common
+ */
+void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposition& decomp2, TermId common) {
+ Debug("bv-slicer") << "UnionFind::handleCommonSlice common = " << common << endl;
+ Index common_size = getBitwidth(common);
+ // find starting points of common slice
+ Index start1 = 0;
+ for (unsigned j = 0; j < decomp1.size(); ++j) {
+ if (decomp1[j] == common)
+ break;
+ start1 += getBitwidth(decomp1[j]);
+ }
+
+ Index start2 = 0;
+ for (unsigned j = 0; j < decomp2.size(); ++j) {
+ if (decomp2[j] == common)
+ break;
+ start2 += getBitwidth(decomp2[j]);
+ }
+ if (start1 > start2) {
+ Index temp = start1;
+ start1 = start2;
+ start2 = temp;
+ }
+
+ if (start2 - start1 < common_size) {
+ Index overlap = start1 + common_size - start2;
+ Assert (overlap > 0);
+ Index diff = common_size - overlap;
+ Assert (diff >= 0);
+ Index granularity = utils::gcd(diff, overlap);
+ // split the common part
+ for (unsigned i = 0; i < common_size; i+= granularity) {
+ split(common, i);
+ }
+ }
+
+}
+
+void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) {
+ Debug("bv-slicer") << "UnionFind::alignSlicings " << term1.debugPrint() << endl;
+ Debug("bv-slicer") << " " << term2.debugPrint() << endl;
+ NormalForm nf1(term1.getBitwidth());
+ NormalForm nf2(term2.getBitwidth());
+
+ getNormalForm(term1, nf1);
+ getNormalForm(term2, nf2);
+
+ Assert (nf1.base.getBitwidth() == nf2.base.getBitwidth());
+
+ // first check if the two have any common slices
+ std::vector<TermId> intersection;
+ utils::intersect(nf1.decomp, nf2.decomp, intersection);
+ for (unsigned i = 0; i < intersection.size(); ++i) {
+ // handle common slice may change the normal form
+ handleCommonSlice(nf1.decomp, nf2.decomp, intersection[i]);
+ }
+ // propagate cuts to a fixpoint
+ bool changed;
+ Base cuts(term1.getBitwidth());
+ do {
+ changed = false;
+ // we need to update the normal form which may have changed
+ getNormalForm(term1, nf1);
+ getNormalForm(term2, nf2);
+
+ // align the cuts points of the two slicings
+ // FIXME: this can be done more efficiently
+ cuts.sliceWith(nf1.base);
+ cuts.sliceWith(nf2.base);
+
+ for (unsigned i = 0; i < cuts.getBitwidth(); ++i) {
+ if (cuts.isCutPoint(i)) {
+ if (!nf1.base.isCutPoint(i)) {
+ pair<TermId, Index> pair1 = nf1.getTerm(i, *this);
+ split(pair1.first, i - pair1.second);
+ changed = true;
+ }
+ if (!nf2.base.isCutPoint(i)) {
+ pair<TermId, Index> pair2 = nf2.getTerm(i, *this);
+ split(pair2.first, i - pair2.second);
+ changed = true;
+ }
+ }
+ }
+ } while (changed);
+}
+/**
+ * Given an extract term a[i:j] makes sure a is sliced
+ * at indices i and j.
+ *
+ * @param term
+ */
+void UnionFind::ensureSlicing(const ExtractTerm& term) {
+ //Debug("bv-slicer") << "Slicer::ensureSlicing " << term.debugPrint() << endl;
+ TermId id = find(term.id);
+ split(id, term.high + 1);
+ split(id, term.low);
+}
+
+void UnionFind::backtrack() {
+ int size = d_undoStack.size();
+ for (int i = size; i > (int)d_undoStackIndex.get(); --i) {
+ Operation op = d_undoStack.back();
+ Assert (!d_undoStack.empty());
+ d_undoStack.pop_back();
+ if (op.op == UnionFind::MERGE) {
+ undoMerge(op.id);
+ } else {
+ Assert (op.op == UnionFind::SPLIT);
+ undoSplit(op.id);
+ }
+ }
+}
+
+void UnionFind::undoMerge(TermId id) {
+ TermId repr = getRepr(id);
+ Assert (repr != UndefinedId);
+ setRepr(id, UndefinedId, UndefinedId);
+}
+
+void UnionFind::undoSplit(TermId id) {
+ Assert (hasChildren(id));
+ setChildren(id, UndefinedId, UndefinedId);
+}
+
+void UnionFind::recordOperation(OperationKind op, TermId term) {
+ d_undoStackIndex.set(d_undoStackIndex.get() + 1);
+ d_undoStack.push_back(Operation(op, term));
+ Assert (d_undoStack.size() == d_undoStackIndex);
+}
+
+void UnionFind::getBase(TermId id, Base& base, Index offset) {
+ id = find(id);
+ if (!hasChildren(id))
+ return;
+ TermId id1 = find(getChild(id, 1));
+ TermId id0 = find(getChild(id, 0));
+ Index cut = getCutPoint(id);
+ base.sliceAt(cut + offset);
+ getBase(id1, base, cut + offset);
+ getBase(id0, base, offset);
+}
+
+
+/**
+ * Slicer
+ *
+ */
+
+ExtractTerm Slicer::registerTerm(TNode node) {
+ Index low = 0, high = utils::getSize(node) - 1;
+ TNode n = node;
+ if (node.getKind() == kind::BITVECTOR_EXTRACT) {
+ n = node[0];
+ high = utils::getExtractHigh(node);
+ low = utils::getExtractLow(node);
+ }
+ if (d_nodeToId.find(n) == d_nodeToId.end()) {
+ TermId id = d_unionFind.addNode(utils::getSize(n));
+ d_nodeToId[n] = id;
+ d_idToNode[id] = n;
+ }
+ TermId id = d_nodeToId[n];
+ d_unionFind.addExtract(id, high, low);
+ ExtractTerm res(id, high, low);
+ Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl;
+ return res;
+}
+
+void Slicer::processEquality(TNode eq) {
+ Debug("bv-slicer") << "Slicer::processEquality: " << eq << endl;
+
+ registerEquality(eq);
+ Assert (eq.getKind() == kind::EQUAL);
+ TNode a = eq[0];
+ TNode b = eq[1];
+ ExtractTerm a_ex= registerTerm(a);
+ ExtractTerm b_ex= registerTerm(b);
+
+ d_unionFind.ensureSlicing(a_ex);
+ d_unionFind.ensureSlicing(b_ex);
+
+ d_unionFind.alignSlicings(a_ex, b_ex);
+
+ Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl;
+ Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl;
+ Debug("bv-slicer") << "Slicer::processEquality done. " << endl;
+}
+
+void Slicer::assertEquality(TNode eq) {
+ Assert (eq.getKind() == kind::EQUAL);
+ ExtractTerm a = registerTerm(eq[0]);
+ ExtractTerm b = registerTerm(eq[1]);
+ ExplanationId reason = getExplanationId(eq);
+ d_unionFind.unionTerms(a, b, reason);
+}
+
+TermId Slicer::getId(TNode node) const {
+ __gnu_cxx::hash_map<Node, TermId, NodeHashFunction >::const_iterator it = d_nodeToId.find(node);
+ Assert (it != d_nodeToId.end());
+ return it->second;
+}
+
+void Slicer::registerEquality(TNode eq) {
+ if (d_explanationToId.find(eq) == d_explanationToId.end()) {
+ ExplanationId id = d_explanations.size();
+ d_explanations.push_back(eq);
+ d_explanationToId[eq] = id;
+ }
+}
+
+void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation) {
+ Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl;
+
+ Index high = utils::getSize(node) - 1;
+ Index low = 0;
+ TNode top = node;
+ if (node.getKind() == kind::BITVECTOR_EXTRACT) {
+ high = utils::getExtractHigh(node);
+ low = utils::getExtractLow(node);
+ top = node[0];
+ }
+
+ Assert (d_nodeToId.find(top) != d_nodeToId.end());
+ TermId id = d_nodeToId[top];
+ NormalForm nf(high-low+1);
+ std::vector<ExplanationId> explanation_ids;
+ d_unionFind.getNormalFormWithExplanation(ExtractTerm(id, high, low), nf, explanation_ids);
+
+ for (unsigned i = 0; i < explanation_ids.size(); ++i) {
+ Assert (hasExplanation(explanation_ids[i]));
+ TNode exp = getExplanation(explanation_ids[i]);
+ explanation.push_back(exp);
+ }
+
+ for (int i = nf.decomp.size() - 1; i>=0 ; --i) {
+ Node current = getNode(nf.decomp[i]);
+ decomp.push_back(current);
+ }
+
+
+ Debug("bv-slicer") << "as [";
+ for (unsigned i = 0; i < decomp.size(); ++i) {
+ Debug("bv-slicer") << decomp[i] <<" ";
+ }
+ Debug("bv-slicer") << "]" << endl;
+
+}
+
+bool Slicer::isCoreTerm(TNode node) {
+ if (d_coreTermCache.find(node) == d_coreTermCache.end()) {
+ Kind kind = node.getKind();
+ if (kind != kind::BITVECTOR_EXTRACT &&
+ kind != kind::BITVECTOR_CONCAT &&
+ kind != kind::EQUAL && kind != kind::NOT &&
+ node.getMetaKind() != kind::metakind::VARIABLE &&
+ kind != kind::CONST_BITVECTOR) {
+ d_coreTermCache[node] = false;
+ return false;
+ } else {
+ // we need to recursively check whether the term is a root term or not
+ bool isCore = true;
+ for (unsigned i = 0; i < node.getNumChildren(); ++i) {
+ isCore = isCore && isCoreTerm(node[i]);
+ }
+ d_coreTermCache[node] = isCore;
+ return isCore;
+ }
+ }
+ return d_coreTermCache[node];
+}
+unsigned Slicer::d_numAddedEqualities = 0;
+
+void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) {
+ Assert (node.getKind() == kind::EQUAL);
+ TNode t1 = node[0];
+ TNode t2 = node[1];
+
+ uint32_t width = utils::getSize(t1);
+
+ Base base1(width);
+ if (t1.getKind() == kind::BITVECTOR_CONCAT) {
+ int size = 0;
+ // no need to count the last child since the end cut point is implicit
+ for (int i = t1.getNumChildren() - 1; i >= 1 ; --i) {
+ size = size + utils::getSize(t1[i]);
+ base1.sliceAt(size);
+ }
+ }
+
+ Base base2(width);
+ if (t2.getKind() == kind::BITVECTOR_CONCAT) {
+ unsigned size = 0;
+ for (int i = t2.getNumChildren() - 1; i >= 1; --i) {
+ size = size + utils::getSize(t2[i]);
+ base2.sliceAt(size);
+ }
+ }
+
+ base1.sliceWith(base2);
+ if (!base1.isEmpty()) {
+ // we split the equalities according to the base
+ int last = 0;
+ for (unsigned i = 1; i <= utils::getSize(t1); ++i) {
+ if (base1.isCutPoint(i)) {
+ Node extract1 = utils::mkExtract(t1, i-1, last);
+ Node extract2 = utils::mkExtract(t2, i-1, last);
+ last = i;
+ Assert (utils::getSize(extract1) == utils::getSize(extract2));
+ equalities.push_back(utils::mkNode(kind::EQUAL, extract1, extract2));
+ }
+ }
+ } else {
+ // just return same equality
+ equalities.push_back(node);
+ }
+ d_numAddedEqualities += equalities.size() - 1;
+}
+
+
+ExtractTerm UnionFind::getExtractTerm(TermId id) const {
+ Assert (isExtractTerm(id));
+
+ return (d_idToExtract.find(id))->second;
+}
+
+bool UnionFind::isExtractTerm(TermId id) const {
+ return d_idToExtract.find(id) != d_idToExtract.end();
+}
+
+bool Slicer::hasNode(TermId id) const {
+ return d_idToNode.find(id) != d_idToNode.end();
+}
+
+Node Slicer::getNode(TermId id) const {
+ if (hasNode(id)) {
+ return d_idToNode.find(id)->second;
+ }
+ // otherwise must be an extract
+ Assert (d_unionFind.isExtractTerm(id));
+ ExtractTerm extract = d_unionFind.getExtractTerm(id);
+ Assert (hasNode(extract.id));
+ TNode node = d_idToNode.find(extract.id)->second;
+ Node ex = utils::mkExtract(node, extract.high, extract.low);
+ return ex;
+}
+
+bool Slicer::termInEqualityEngine(TermId id) {
+ Node node = getNode(id);
+ return d_coreSolver->hasTerm(node);
+}
+
+void Slicer::enqueueSplit(TermId id, Index i, TermId top_id, TermId bottom_id) {
+ Node node = getNode(id);
+ Node bottom = Rewriter::rewrite(utils::mkExtract(node, i -1 , 0));
+ Node top = Rewriter::rewrite(utils::mkExtract(node, utils::getSize(node) - 1, i));
+ // must add terms to equality engine so we get notified when they get split more
+ d_coreSolver->addTermToEqualityEngine(bottom);
+ d_coreSolver->addTermToEqualityEngine(top);
+
+ Node eq = utils::mkNode(kind::EQUAL, node, utils::mkConcat(top, bottom));
+ d_newSplits.push_back(eq);
+ Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl;
+ Debug("bv-slicer") << " " << id << "=" << top_id << " " << bottom_id << endl;
+}
+
+void Slicer::getNewSplits(std::vector<Node>& splits) {
+ for (unsigned i = d_newSplitsIndex; i < d_newSplits.size(); ++i) {
+ splits.push_back(d_newSplits[i]);
+ }
+ d_newSplitsIndex = d_newSplits.size();
+}
+
+bool Slicer::hasExplanation(ExplanationId id) const {
+ return id < d_explanations.size();
+}
+
+TNode Slicer::getExplanation(ExplanationId id) const {
+ Assert(hasExplanation(id));
+ return d_explanations[id];
+}
+
+ExplanationId Slicer::getExplanationId(TNode reason) const {
+ Assert (d_explanationToId.find(reason) != d_explanationToId.end());
+ return d_explanationToId.find(reason)->second;
+}
+
+std::string UnionFind::debugPrint(TermId id) {
+ ostringstream os;
+ if (hasChildren(id)) {
+ TermId id1 = find(getChild(id, 1));
+ TermId id0 = find(getChild(id, 0));
+ os << debugPrint(id1);
+ os << debugPrint(id0);
+ } else {
+ if (getRepr(id) == UndefinedId) {
+ os <<"id"<< id <<"[" << getBitwidth(id) <<"] ";
+ } else {
+ os << debugPrint(find(id));
+ }
+ }
+ return os.str();
+}
+
+UnionFind::Statistics::Statistics():
+ d_numNodes("theory::bv::slicer::NumberOfNodes", 0),
+ d_numRepresentatives("theory::bv::slicer::NumberOfRepresentatives", 0),
+ d_numSplits("theory::bv::slicer::NumberOfSplits", 0),
+ d_numMerges("theory::bv::slicer::NumberOfMerges", 0),
+ d_avgFindDepth("theory::bv::slicer::AverageFindDepth"),
+ d_numAddedEqualities("theory::bv::slicer::NumberOfEqualitiesAdded", Slicer::d_numAddedEqualities)
+{
+ StatisticsRegistry::registerStat(&d_numRepresentatives);
+ StatisticsRegistry::registerStat(&d_numSplits);
+ StatisticsRegistry::registerStat(&d_numMerges);
+ StatisticsRegistry::registerStat(&d_avgFindDepth);
+ StatisticsRegistry::registerStat(&d_numAddedEqualities);
+}
+
+UnionFind::Statistics::~Statistics() {
+ StatisticsRegistry::unregisterStat(&d_numRepresentatives);
+ StatisticsRegistry::unregisterStat(&d_numSplits);
+ StatisticsRegistry::unregisterStat(&d_numMerges);
+ StatisticsRegistry::unregisterStat(&d_avgFindDepth);
+ StatisticsRegistry::unregisterStat(&d_numAddedEqualities);
+}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback