diff options
author | lianah <lianahady@gmail.com> | 2013-01-28 19:04:25 -0500 |
---|---|---|
committer | lianah <lianahady@gmail.com> | 2013-01-28 19:04:25 -0500 |
commit | 5aec0f36fb2e896c24ce122a79bd70678371a249 (patch) | |
tree | e267ea1a4b45af5ab386c1b0fcdebe56d2cb203e /src/theory | |
parent | df01ef792cf9806782b12bc93ddaa139d75346c0 (diff) |
compiling implementation of new slicer finished; need to add debugging information and debug.
Diffstat (limited to 'src/theory')
-rw-r--r-- | src/theory/bv/slicer.cpp | 298 | ||||
-rw-r--r-- | src/theory/bv/slicer.h | 127 | ||||
-rw-r--r-- | src/theory/bv/theory_bv.cpp | 1 | ||||
-rw-r--r-- | src/theory/bv/theory_bv_utils.h | 29 |
4 files changed, 346 insertions, 109 deletions
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp index 001427488..1596a53ee 100644 --- a/src/theory/bv/slicer.cpp +++ b/src/theory/bv/slicer.cpp @@ -17,7 +17,6 @@ **/ #include "theory/bv/slicer.h" -#include "util/utility.h" #include "theory/bv/theory_bv_utils.h" #include "theory/rewriter.h" @@ -26,6 +25,9 @@ using namespace CVC4::theory; using namespace CVC4::theory::bv; using namespace std; + +const TermId CVC4::theory::bv::UndefinedId = -1; + /** * Base * @@ -99,7 +101,24 @@ std::string Base::debugPrint() const { os << "]"; return os.str(); } - + +/** + * NormalForm + * + */ + +TermId NormalForm::getTerm(Index i, const UnionFind& uf) const { + Assert (i < base.getBitwidth()); + Index count = 0; + for (unsigned i = 0; i < decomp.size(); ++i) { + Index size = uf.getBitwidth(decomp[i]); + if ( count + size <= i && count >= i) { + return decomp[i]; + } + count += size; + } + Unreachable(); +} /** * UnionFind @@ -110,15 +129,55 @@ TermId UnionFind::addTerm(Index bitwidth) { d_nodes.push_back(node); TermId id = d_nodes.size() - 1; d_representatives.insert(id); - d_topLevelTerms.insert(id); + return id; } +/** + * At this point we assume the slicings of the two terms are properly aligned. + * + * @param t1 + * @param t2 + */ +void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2) { + Assert (t1.getBitwidth() == t2.getBitwidth()); + + NormalForm nf1(t1.getBitwidth()); + NormalForm nf2(t2.getBitwidth()); + + getNormalForm(t1, nf1); + getNormalForm(t2, nf2); -void UnionFind::merge(TermId t1, TermId t2) { + Assert (nf1.decomp.size() == nf2.decomp.size()); + Assert (nf1.base == nf2.base); + for (unsigned i = 0; i < nf1.decomp.size(); ++i) { + merge (nf1.decomp[i], nf2.decomp[i]); + } } -TermId UnionFind::find(TermId id) { + +/** + * Merge the two terms in the union find. Both t1 and t2 + * should be root terms. + * + * @param t1 + * @param t2 + */ +void UnionFind::merge(TermId t1, TermId t2) { + t1 = find(t1); + t2 = find(t2); + + if (t1 == t2) + return; + + Node n1 = getNode(t1); + Node n2 = getNode(t2); + Assert (! n1.hasChildren() && ! n2.hasChildren()); + n1.setRepr(t2); + d_representatives.erase(t1); +} + +TermId UnionFind::find(TermId id) const { Node node = getNode(id); - if (node.getRepr() != -1) + if (node.getRepr() != UndefinedId) return find(node.getRepr()); return id; } @@ -144,19 +203,18 @@ void UnionFind::split(TermId id, Index i) { // first time we split this term TermId bottom_id = addTerm(i); TermId top_id = addTerm(node.getBitwidth() - i); - node.addChildren(top_id, bottom_id); + node.setChildren(top_id, bottom_id); } else { - Index cut = node.getCutPoint(); + Index cut = node.getCutPoint(*this); if (i < cut ) - split(child1, i); + split(node.getChild(0), i); else split(node.getChild(1), i - cut); } } -void UnionFind::getNormalForm(ExtractTerm term, NormalForm& nf) { - TermId id = find(term.id); +void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) { getDecomposition(term, nf.decomp); // update nf base Index count = 0; @@ -166,7 +224,7 @@ void UnionFind::getNormalForm(ExtractTerm term, NormalForm& nf) { } } -void UnionFind::getDecomposition(ExtractTerm term, Decomposition& decomp) { +void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp) { // making sure the term is aligned TermId id = find(term.id); @@ -179,59 +237,112 @@ void UnionFind::getDecomposition(ExtractTerm term, Decomposition& decomp) { decomp.push_back(id); } - Index cut = node.getCutPoint(); + Index cut = node.getCutPoint(*this); - if (low < cut && high < cut) { + if (term.low < cut && term.high < cut) { // the extract falls entirely on the low child - ExtractTerm child_ex(node.getChild(0), high, low); + ExtractTerm child_ex(node.getChild(0), term.high, term.low); getDecomposition(child_ex, decomp); } - else if (low >= cut && high >= cut){ + else if (term.low >= cut && term.high >= cut){ // the extract falls entirely on the high child - ExtractTerm child_ex(node.getChild(1), high - cut, low - cut); + ExtractTerm child_ex(node.getChild(1), term.high - cut, term.low - cut); getDecomposition(child_ex, decomp); } else { // the extract is split over the two children - ExtractTerm low_child(node.getChild(0), cut - 1, low); + ExtractTerm low_child(node.getChild(0), cut - 1, term.low); getDecomposition(low_child, decomp); - ExtractTerm high_child(node.getChild(1), high, cut); + ExtractTerm high_child(node.getChild(1), term.high, cut); getDecomposition(high_child, decomp); } } +/** + * May cause reslicings of the decompositions. Must not assume the decompositons + * are the current normal form. + * + * @param d1 + * @param d2 + * @param common + */ +void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposition& decomp2, TermId common) { + Index common_size = getBitwidth(common); + + // find starting points of common slice + Index start1 = 0; + for (unsigned j = 0; j < decomp1.size(); ++j) { + if (decomp1[j] == common) + break; + start1 += getBitwidth(decomp1[j]); + } + + Index start2 = 0; + for (unsigned j = 0; j < decomp2.size(); ++j) { + if (decomp2[j] == common) + break; + start2 += getBitwidth(decomp2[j]); + } + start1 = start1 > start2 ? start2 : start1; + start2 = start1 > start2 ? start1 : start2; + + if (start1 + common_size <= start2) { + Index overlap = start1 + common_size - start2; + Assert (overlap > 0); + Index diff = start2 - overlap; + Assert (diff > 0); + Index granularity = utils::gcd(diff, overlap); + // split the common part + for (unsigned i = 0; i < common_size; i+= granularity) { + split(common, i); + } + } + +} + +void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) { + NormalForm nf1(term1.getBitwidth()); + NormalForm nf2(term2.getBitwidth()); + + getNormalForm(term1, nf1); + getNormalForm(term2, nf2); -void UnionFind::alignSlicings(NormalForm& nf1, NormalForm& nf2) { Assert (nf1.base.getBitwidth() == nf2.base.getBitwidth()); - // check if the two have + + // first check if the two have any common slices std::vector<TermId> intersection; - intersection(nf1.decomp, nf2.decomp, intersection); + utils::intersect(nf1.decomp, nf2.decomp, intersection); for (unsigned i = 0; i < intersection.size(); ++i) { - TermId overlap = intersection[i]; - Index start1 = 0; - Decomposition& decomp1 = nf1.decomp; - for (unsigned j = 0; j < decomp1.size(); ++j) { - if (decomp1[j] == overlap) - break; - start1 += getSize(decomp1[j]); - } + // handle common slice may change the normal form + handleCommonSlice(nf1.decomp, nf2.decomp, intersection[i]); } - Base new_cuts1 = nf1.base.diffCutPoints(nf2.base); - Base new_cuts2 = nf2.base.diffCutPoints(nf1.base); - for (unsigned i = 0; i < new_cuts.base.getBitwidth(); ++i) { - if (new_cuts1.isCutPoint(i)) { - + // we need to update the normal form which may have changed + getNormalForm(term1, nf1); + getNormalForm(term2, nf2); + + // align the cuts points of the two slicings + // FIXME: this can be done more efficiently + Base& cuts = nf1.base; + cuts.sliceWith(nf2.base); + for (unsigned i = 0; i < cuts.getBitwidth(); ++i) { + if (cuts.isCutPoint(i)) { + TermId t1 = nf1.getTerm(i, *this); + split(t1, i); + TermId t2 = nf2.getTerm(i, *this); + split(t2, i); } } - } - -void UnionFind::ensureSlicing(ExtractTerm& term) { +/** + * Given an extract term a[i:j] makes sure a is sliced + * at indices i and j. + * + * @param term + */ +void UnionFind::ensureSlicing(const ExtractTerm& term) { TermId id = find(term.id); split(id, term.high); split(id, term.low); - - } /** @@ -239,9 +350,7 @@ void UnionFind::ensureSlicing(ExtractTerm& term) { * */ - - -void Slicer::registerTerm(TNode node) { +ExtractTerm Slicer::registerTerm(TNode node) { Index low = 0, high = utils::getSize(node); TNode n = node; if (node.getKind() == kind::BITVECTOR_EXTRACT) { @@ -250,7 +359,7 @@ void Slicer::registerTerm(TNode node) { low = utils::getExtractLow(node); } if (d_nodeToId.find(n) == d_nodeToId.end()) { - id = d_uf.addTerm(utils::getSize(n)); + TermId id = d_unionFind.addTerm(utils::getSize(n)); d_nodeToId[n] = id; d_idToNode[id] = n; } @@ -259,24 +368,109 @@ void Slicer::registerTerm(TNode node) { return ExtractTerm(id, high, low); } -void Slicer::processSimpleEquality(TNode eq) { +void Slicer::processEquality(TNode eq) { Assert (eq.getKind() == kind::EQUAL); TNode a = eq[0]; TNode b = eq[1]; ExtractTerm a_ex= registerTerm(a); ExtractTerm b_ex= registerTerm(b); - NormalForm a_nf, b_nf; - d_uf.ensureSlicing(a_ex); - d_uf.ensureSlicing(b_ex); + d_unionFind.ensureSlicing(a_ex); + d_unionFind.ensureSlicing(b_ex); - d_uf.getNormalForm(a_ex, a_nf); - d_uf.getNormalForm(b_ex, b_nf); + d_unionFind.alignSlicings(a_ex, b_ex); + d_unionFind.unionTerms(a_ex, b_ex); +} - d_uf.alignSlicings(a_nf, b_nf); +void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) { + Index high = utils::getSize(node); + Index low = 0; + if (node.getKind() == kind::BITVECTOR_EXTRACT) { + high = utils::getExtractHigh(node); + low = utils::getExtractLow(node); + node = node[0]; + } + Assert (d_nodeToId.find(node) != d_nodeToId.end()); + TermId id = d_nodeToId[node]; + NormalForm nf(utils::getSize(node)); + d_unionFind.getNormalForm(ExtractTerm(id, high, low), nf); + + // construct actual extract nodes + Index current_low = 0; + Index current_high = 0; + for (unsigned i = 0; i < nf.decomp.size(); ++i) { + Index current_size = d_unionFind.getBitwidth(nf.decomp[i]); + current_high += current_size; + Node current = utils::mkExtract(node, current_high - 1, current_low); + current_low += current_size; + decomp.push_back(current); + } } -void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) const { +bool Slicer::isCoreTerm(TNode node) { + if (d_coreTermCache.find(node) == d_coreTermCache.end()) { + Kind kind = node.getKind(); + if (kind != kind::BITVECTOR_EXTRACT && + kind != kind::BITVECTOR_CONCAT && + kind != kind::EQUAL && kind != kind::NOT && + node.getMetaKind() != kind::metakind::VARIABLE && + kind != kind::CONST_BITVECTOR) { + d_coreTermCache[node] = false; + return false; + } else { + // we need to recursively check whether the term is a root term or not + bool isCore = true; + for (unsigned i = 0; i < node.getNumChildren(); ++i) { + isCore = isCore && isCoreTerm(node[i]); + } + d_coreTermCache[node] = isCore; + return isCore; + } + } + return d_coreTermCache[node]; } +void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) { + Assert (node.getKind() == kind::EQUAL); + TNode t1 = node[0]; + TNode t2 = node[1]; + + uint32_t width = utils::getSize(t1); + + Base base1(width); + if (t1.getKind() == kind::BITVECTOR_CONCAT) { + int size = -1; + // no need to count the last child since the end cut point is implicit + for (int i = t1.getNumChildren() - 1; i >= 1 ; --i) { + size = size + utils::getSize(t1[i]); + base1.sliceAt(size); + } + } + Base base2(width); + if (t2.getKind() == kind::BITVECTOR_CONCAT) { + unsigned size = -1; + for (int i = t2.getNumChildren() - 1; i >= 1; --i) { + size = size + utils::getSize(t2[i]); + base2.sliceAt(size); + } + } + + base1.sliceWith(base2); + if (!base1.isEmpty()) { + // we split the equalities according to the base + int last = 0; + for (unsigned i = 0; i < utils::getSize(t1); ++i) { + if (base1.isCutPoint(i)) { + Node extract1 = Rewriter::rewrite(utils::mkExtract(t1, i, last)); + Node extract2 = Rewriter::rewrite(utils::mkExtract(t2, i, last)); + last = i + 1; + Assert (utils::getSize(extract1) == utils::getSize(extract2)); + equalities.push_back(utils::mkNode(kind::EQUAL, extract1, extract2)); + } + } + } else { + // just return same equality + equalities.push_back(node); + } +} diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h index 288b72bac..7ab40652d 100644 --- a/src/theory/bv/slicer.h +++ b/src/theory/bv/slicer.h @@ -26,7 +26,7 @@ #include "util/bitvector.h" #include "util/statistics_registry.h" - +#include "util/index.h" #include "expr/node.h" #include "theory/bv/theory_bv_utils.h" #ifndef __CVC4__THEORY__BV__SLICER_BV_H @@ -38,10 +38,11 @@ namespace CVC4 { namespace theory { namespace bv { -typedef uint32_t TermId; -typedef uint32_t Index; +typedef Index TermId; +extern const TermId UndefinedId; + /** * Base @@ -51,13 +52,23 @@ class Base { Index d_size; std::vector<uint32_t> d_repr; public: - Base(uint32_t size); + Base(Index size); void sliceAt(Index index); void sliceWith(const Base& other); bool isCutPoint(Index index) const; void diffCutPoints(const Base& other, Base& res) const; bool isEmpty() const; std::string debugPrint() const; + Index getBitwidth() const { return d_size; } + bool operator==(const Base& other) const { + if (other.getBitwidth() != getBitwidth()) + return false; + for (unsigned i = 0; i < d_repr.size(); ++i) { + if (d_repr[i] != other.d_repr[i]) + return false; + } + return true; + } }; /** @@ -72,88 +83,80 @@ struct ExtractTerm { Index high; Index low; ExtractTerm(TermId i, Index h, Index l) - : id (i) - high(h) + : id (i), + high(h), low(l) { - Assert (h >= l && id != -1); + Assert (h >= l && id != UndefinedId); } + Index getBitwidth() const { return high - low + 1; } }; +class UnionFind; + struct NormalForm { Base base; Decomposition decomp; + NormalForm(Index bitwidth) : base(bitwidth), decomp() {} + /** + * Returns the term in the decomposition on which the index i + * falls in + * @param i + * + * @return + */ + TermId getTerm(Index i, const UnionFind& uf) const; }; class UnionFind { class Node { - TermId d_repr; - TermId d_ch1, d_ch2; Index d_bitwidth; + TermId d_ch1, d_ch2; + TermId d_repr; public: Node(Index b) : d_bitwidth(b), - d_ch1(-1), - d_ch2(-1), - d_repr(-1) + d_ch1(UndefinedId), + d_ch2(UndefinedId), + d_repr(UndefinedId) {} TermId getRepr() const { return d_repr; } Index getBitwidth() const { return d_bitwidth; } - bool hasChildren() const { return d_ch1 != -1 && d_ch2 != -1; } + bool hasChildren() const { return d_ch1 != UndefinedId && d_ch2 != UndefinedId; } TermId getChild(Index i) const { Assert (i < 2); - return i == 0? ch1 : ch2; + return i == 0? d_ch1 : d_ch2; } - Index getCutPoint() const { - Assert (d_ch1 != -1 && d_ch2 != -1); - return getNode(d_ch1).getBitwidth(); + Index getCutPoint(const UnionFind& uf) const { + Assert (d_ch1 != UndefinedId && d_ch2 != UndefinedId); + return uf.getNode(d_ch1).getBitwidth(); } void setRepr(TermId id) { - Assert (d_children.empty()); + Assert (! hasChildren()); d_repr = id; } void setChildren(TermId ch1, TermId ch2) { - Assert (d_repr == -1 && d_children.empty()); - markAsNotTopLevel(ch1); - markAsNotTopLevel(ch2); - d_children.push_back(ch1); - d_children.push_back(ch2); + Assert (d_repr == UndefinedId && !hasChildren()); + d_ch1 = ch1; + d_ch2 = ch2; } - - // void setChildren(TermId ch1, TermId ch2, TermId ch3) { - // Assert (d_repr == -1 && d_children.empty()); - // d_children.push_back(ch1); - // d_children.push_back(ch2); - // d_children.push_back(ch3); - // } - }; + /// map from TermId to the nodes that represent them std::vector<Node> d_nodes; - + /// a term is in this set if it is its own representative TermSet d_representatives; - TermSet d_topLevelTerms; - void markAsNotTopLevel(TermId id) { - if (d_topLevelTerms.find(id) != d_topLevelTerms.end()) - d_topLevelTerms.erase(id); - } - - bool isTopLevel(TermId id) { - return d_topLevelTerms.find(id) != d_topLevelTerms.end(); - } - Index getBitwidth(TermId id) { - Assert (id < d_nodes.size()); - return d_nodes[id].getBitwidth(); - } + void getDecomposition(const ExtractTerm& term, Decomposition& decomp); + void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common); public: UnionFind() @@ -162,32 +165,44 @@ public: {} TermId addTerm(Index bitwidth); + void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2); void merge(TermId t1, TermId t2); - TermId find(TermId t1); - TermId split(TermId term, Index i); - - void getNormalForm(ExtractTerm term, NormalForm& nf); - void alignSlicings(NormalForm& nf1, NormalForm& nf2); + TermId find(TermId t1) const ; + void split(TermId term, Index i); - Node getNode(TermId id) { + void getNormalForm(const ExtractTerm& term, NormalForm& nf); + void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2); + void ensureSlicing(const ExtractTerm& term); + + Node getNode(TermId id) const { Assert (id < d_nodes.size()); return d_nodes[id]; } + Index getBitwidth(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getBitwidth(); + } + }; class Slicer { __gnu_cxx::hash_map<TermId, TNode> d_idToNode; - __gnu_cxx::hash_map<TNode, TermId> d_nodeToId; - UnionFind d_unionFind(); - + __gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction> d_nodeToId; + __gnu_cxx::hash_map<TNode, bool, TNodeHashFunction> d_coreTermCache; + UnionFind d_unionFind; + ExtractTerm registerTerm(TNode node); public: Slicer() - : d_topLevelTerms(), + : d_idToNode(), + d_nodeToId(), + d_coreTermCache(), d_unionFind() {} - void getBaseDecomposition(TNode node, std::vector<Node>& decomp) const; + void getBaseDecomposition(TNode node, std::vector<Node>& decomp); void processEquality(TNode eq); + bool isCoreTerm (TNode node); + static void splitEqualities(TNode node, std::vector<Node>& equalities); }; }/* CVC4::theory::bv namespace */ diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 1c746fafa..64662d5c6 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -210,7 +210,6 @@ Node TheoryBV::ppRewrite(TNode t) void TheoryBV::presolve() { Debug("bitvector") << "TheoryBV::presolve" << endl; - d_slicer.computeCoarsestBase(); } bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory) diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h index 73a13e1ad..f87163e37 100644 --- a/src/theory/bv/theory_bv_utils.h +++ b/src/theory/bv/theory_bv_utils.h @@ -418,6 +418,35 @@ inline std::string vectorToString(const std::vector<Node>& nodes) { return out.str(); } +// FIXME: dumb code +inline void intersect(const std::vector<uint32_t>& v1, + const std::vector<uint32_t>& v2, + std::vector<uint32_t>& intersection) { + for (unsigned i = 0; i < v1.size(); ++i) { + bool found = false; + for (unsigned j = 0; j < v2.size(); ++j) { + if (v2[j] == v1[i]) { + found = true; + break; + } + } + if (found) { + intersection.push_back(v1[i]); + } + } +} + +template <class T> +inline T gcd(T a, T b) { + while (b != 0) { + T t = b; + b = a % t; + a = t; + } + return a; +} + + } } } |