summaryrefslogtreecommitdiff
path: root/src/theory/bv
diff options
context:
space:
mode:
authorlianah <lianahady@gmail.com>2013-01-28 19:04:25 -0500
committerlianah <lianahady@gmail.com>2013-01-28 19:04:25 -0500
commit5aec0f36fb2e896c24ce122a79bd70678371a249 (patch)
treee267ea1a4b45af5ab386c1b0fcdebe56d2cb203e /src/theory/bv
parentdf01ef792cf9806782b12bc93ddaa139d75346c0 (diff)
compiling implementation of new slicer finished; need to add debugging information and debug.
Diffstat (limited to 'src/theory/bv')
-rw-r--r--src/theory/bv/slicer.cpp298
-rw-r--r--src/theory/bv/slicer.h127
-rw-r--r--src/theory/bv/theory_bv.cpp1
-rw-r--r--src/theory/bv/theory_bv_utils.h29
4 files changed, 346 insertions, 109 deletions
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp
index 001427488..1596a53ee 100644
--- a/src/theory/bv/slicer.cpp
+++ b/src/theory/bv/slicer.cpp
@@ -17,7 +17,6 @@
**/
#include "theory/bv/slicer.h"
-#include "util/utility.h"
#include "theory/bv/theory_bv_utils.h"
#include "theory/rewriter.h"
@@ -26,6 +25,9 @@ using namespace CVC4::theory;
using namespace CVC4::theory::bv;
using namespace std;
+
+const TermId CVC4::theory::bv::UndefinedId = -1;
+
/**
* Base
*
@@ -99,7 +101,24 @@ std::string Base::debugPrint() const {
os << "]";
return os.str();
}
-
+
+/**
+ * NormalForm
+ *
+ */
+
+TermId NormalForm::getTerm(Index i, const UnionFind& uf) const {
+ Assert (i < base.getBitwidth());
+ Index count = 0;
+ for (unsigned i = 0; i < decomp.size(); ++i) {
+ Index size = uf.getBitwidth(decomp[i]);
+ if ( count + size <= i && count >= i) {
+ return decomp[i];
+ }
+ count += size;
+ }
+ Unreachable();
+}
/**
* UnionFind
@@ -110,15 +129,55 @@ TermId UnionFind::addTerm(Index bitwidth) {
d_nodes.push_back(node);
TermId id = d_nodes.size() - 1;
d_representatives.insert(id);
- d_topLevelTerms.insert(id);
+ return id;
}
+/**
+ * At this point we assume the slicings of the two terms are properly aligned.
+ *
+ * @param t1
+ * @param t2
+ */
+void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2) {
+ Assert (t1.getBitwidth() == t2.getBitwidth());
+
+ NormalForm nf1(t1.getBitwidth());
+ NormalForm nf2(t2.getBitwidth());
+
+ getNormalForm(t1, nf1);
+ getNormalForm(t2, nf2);
-void UnionFind::merge(TermId t1, TermId t2) {
+ Assert (nf1.decomp.size() == nf2.decomp.size());
+ Assert (nf1.base == nf2.base);
+ for (unsigned i = 0; i < nf1.decomp.size(); ++i) {
+ merge (nf1.decomp[i], nf2.decomp[i]);
+ }
}
-TermId UnionFind::find(TermId id) {
+
+/**
+ * Merge the two terms in the union find. Both t1 and t2
+ * should be root terms.
+ *
+ * @param t1
+ * @param t2
+ */
+void UnionFind::merge(TermId t1, TermId t2) {
+ t1 = find(t1);
+ t2 = find(t2);
+
+ if (t1 == t2)
+ return;
+
+ Node n1 = getNode(t1);
+ Node n2 = getNode(t2);
+ Assert (! n1.hasChildren() && ! n2.hasChildren());
+ n1.setRepr(t2);
+ d_representatives.erase(t1);
+}
+
+TermId UnionFind::find(TermId id) const {
Node node = getNode(id);
- if (node.getRepr() != -1)
+ if (node.getRepr() != UndefinedId)
return find(node.getRepr());
return id;
}
@@ -144,19 +203,18 @@ void UnionFind::split(TermId id, Index i) {
// first time we split this term
TermId bottom_id = addTerm(i);
TermId top_id = addTerm(node.getBitwidth() - i);
- node.addChildren(top_id, bottom_id);
+ node.setChildren(top_id, bottom_id);
} else {
- Index cut = node.getCutPoint();
+ Index cut = node.getCutPoint(*this);
if (i < cut )
- split(child1, i);
+ split(node.getChild(0), i);
else
split(node.getChild(1), i - cut);
}
}
-void UnionFind::getNormalForm(ExtractTerm term, NormalForm& nf) {
- TermId id = find(term.id);
+void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) {
getDecomposition(term, nf.decomp);
// update nf base
Index count = 0;
@@ -166,7 +224,7 @@ void UnionFind::getNormalForm(ExtractTerm term, NormalForm& nf) {
}
}
-void UnionFind::getDecomposition(ExtractTerm term, Decomposition& decomp) {
+void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp) {
// making sure the term is aligned
TermId id = find(term.id);
@@ -179,59 +237,112 @@ void UnionFind::getDecomposition(ExtractTerm term, Decomposition& decomp) {
decomp.push_back(id);
}
- Index cut = node.getCutPoint();
+ Index cut = node.getCutPoint(*this);
- if (low < cut && high < cut) {
+ if (term.low < cut && term.high < cut) {
// the extract falls entirely on the low child
- ExtractTerm child_ex(node.getChild(0), high, low);
+ ExtractTerm child_ex(node.getChild(0), term.high, term.low);
getDecomposition(child_ex, decomp);
}
- else if (low >= cut && high >= cut){
+ else if (term.low >= cut && term.high >= cut){
// the extract falls entirely on the high child
- ExtractTerm child_ex(node.getChild(1), high - cut, low - cut);
+ ExtractTerm child_ex(node.getChild(1), term.high - cut, term.low - cut);
getDecomposition(child_ex, decomp);
}
else {
// the extract is split over the two children
- ExtractTerm low_child(node.getChild(0), cut - 1, low);
+ ExtractTerm low_child(node.getChild(0), cut - 1, term.low);
getDecomposition(low_child, decomp);
- ExtractTerm high_child(node.getChild(1), high, cut);
+ ExtractTerm high_child(node.getChild(1), term.high, cut);
getDecomposition(high_child, decomp);
}
}
+/**
+ * May cause reslicings of the decompositions. Must not assume the decompositons
+ * are the current normal form.
+ *
+ * @param d1
+ * @param d2
+ * @param common
+ */
+void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposition& decomp2, TermId common) {
+ Index common_size = getBitwidth(common);
+
+ // find starting points of common slice
+ Index start1 = 0;
+ for (unsigned j = 0; j < decomp1.size(); ++j) {
+ if (decomp1[j] == common)
+ break;
+ start1 += getBitwidth(decomp1[j]);
+ }
+
+ Index start2 = 0;
+ for (unsigned j = 0; j < decomp2.size(); ++j) {
+ if (decomp2[j] == common)
+ break;
+ start2 += getBitwidth(decomp2[j]);
+ }
+ start1 = start1 > start2 ? start2 : start1;
+ start2 = start1 > start2 ? start1 : start2;
+
+ if (start1 + common_size <= start2) {
+ Index overlap = start1 + common_size - start2;
+ Assert (overlap > 0);
+ Index diff = start2 - overlap;
+ Assert (diff > 0);
+ Index granularity = utils::gcd(diff, overlap);
+ // split the common part
+ for (unsigned i = 0; i < common_size; i+= granularity) {
+ split(common, i);
+ }
+ }
+
+}
+
+void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) {
+ NormalForm nf1(term1.getBitwidth());
+ NormalForm nf2(term2.getBitwidth());
+
+ getNormalForm(term1, nf1);
+ getNormalForm(term2, nf2);
-void UnionFind::alignSlicings(NormalForm& nf1, NormalForm& nf2) {
Assert (nf1.base.getBitwidth() == nf2.base.getBitwidth());
- // check if the two have
+
+ // first check if the two have any common slices
std::vector<TermId> intersection;
- intersection(nf1.decomp, nf2.decomp, intersection);
+ utils::intersect(nf1.decomp, nf2.decomp, intersection);
for (unsigned i = 0; i < intersection.size(); ++i) {
- TermId overlap = intersection[i];
- Index start1 = 0;
- Decomposition& decomp1 = nf1.decomp;
- for (unsigned j = 0; j < decomp1.size(); ++j) {
- if (decomp1[j] == overlap)
- break;
- start1 += getSize(decomp1[j]);
- }
+ // handle common slice may change the normal form
+ handleCommonSlice(nf1.decomp, nf2.decomp, intersection[i]);
}
- Base new_cuts1 = nf1.base.diffCutPoints(nf2.base);
- Base new_cuts2 = nf2.base.diffCutPoints(nf1.base);
- for (unsigned i = 0; i < new_cuts.base.getBitwidth(); ++i) {
- if (new_cuts1.isCutPoint(i)) {
-
+ // we need to update the normal form which may have changed
+ getNormalForm(term1, nf1);
+ getNormalForm(term2, nf2);
+
+ // align the cuts points of the two slicings
+ // FIXME: this can be done more efficiently
+ Base& cuts = nf1.base;
+ cuts.sliceWith(nf2.base);
+ for (unsigned i = 0; i < cuts.getBitwidth(); ++i) {
+ if (cuts.isCutPoint(i)) {
+ TermId t1 = nf1.getTerm(i, *this);
+ split(t1, i);
+ TermId t2 = nf2.getTerm(i, *this);
+ split(t2, i);
}
}
-
}
-
-void UnionFind::ensureSlicing(ExtractTerm& term) {
+/**
+ * Given an extract term a[i:j] makes sure a is sliced
+ * at indices i and j.
+ *
+ * @param term
+ */
+void UnionFind::ensureSlicing(const ExtractTerm& term) {
TermId id = find(term.id);
split(id, term.high);
split(id, term.low);
-
-
}
/**
@@ -239,9 +350,7 @@ void UnionFind::ensureSlicing(ExtractTerm& term) {
*
*/
-
-
-void Slicer::registerTerm(TNode node) {
+ExtractTerm Slicer::registerTerm(TNode node) {
Index low = 0, high = utils::getSize(node);
TNode n = node;
if (node.getKind() == kind::BITVECTOR_EXTRACT) {
@@ -250,7 +359,7 @@ void Slicer::registerTerm(TNode node) {
low = utils::getExtractLow(node);
}
if (d_nodeToId.find(n) == d_nodeToId.end()) {
- id = d_uf.addTerm(utils::getSize(n));
+ TermId id = d_unionFind.addTerm(utils::getSize(n));
d_nodeToId[n] = id;
d_idToNode[id] = n;
}
@@ -259,24 +368,109 @@ void Slicer::registerTerm(TNode node) {
return ExtractTerm(id, high, low);
}
-void Slicer::processSimpleEquality(TNode eq) {
+void Slicer::processEquality(TNode eq) {
Assert (eq.getKind() == kind::EQUAL);
TNode a = eq[0];
TNode b = eq[1];
ExtractTerm a_ex= registerTerm(a);
ExtractTerm b_ex= registerTerm(b);
- NormalForm a_nf, b_nf;
- d_uf.ensureSlicing(a_ex);
- d_uf.ensureSlicing(b_ex);
+ d_unionFind.ensureSlicing(a_ex);
+ d_unionFind.ensureSlicing(b_ex);
- d_uf.getNormalForm(a_ex, a_nf);
- d_uf.getNormalForm(b_ex, b_nf);
+ d_unionFind.alignSlicings(a_ex, b_ex);
+ d_unionFind.unionTerms(a_ex, b_ex);
+}
- d_uf.alignSlicings(a_nf, b_nf);
+void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) {
+ Index high = utils::getSize(node);
+ Index low = 0;
+ if (node.getKind() == kind::BITVECTOR_EXTRACT) {
+ high = utils::getExtractHigh(node);
+ low = utils::getExtractLow(node);
+ node = node[0];
+ }
+ Assert (d_nodeToId.find(node) != d_nodeToId.end());
+ TermId id = d_nodeToId[node];
+ NormalForm nf(utils::getSize(node));
+ d_unionFind.getNormalForm(ExtractTerm(id, high, low), nf);
+
+ // construct actual extract nodes
+ Index current_low = 0;
+ Index current_high = 0;
+ for (unsigned i = 0; i < nf.decomp.size(); ++i) {
+ Index current_size = d_unionFind.getBitwidth(nf.decomp[i]);
+ current_high += current_size;
+ Node current = utils::mkExtract(node, current_high - 1, current_low);
+ current_low += current_size;
+ decomp.push_back(current);
+ }
}
-void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) const {
+bool Slicer::isCoreTerm(TNode node) {
+ if (d_coreTermCache.find(node) == d_coreTermCache.end()) {
+ Kind kind = node.getKind();
+ if (kind != kind::BITVECTOR_EXTRACT &&
+ kind != kind::BITVECTOR_CONCAT &&
+ kind != kind::EQUAL && kind != kind::NOT &&
+ node.getMetaKind() != kind::metakind::VARIABLE &&
+ kind != kind::CONST_BITVECTOR) {
+ d_coreTermCache[node] = false;
+ return false;
+ } else {
+ // we need to recursively check whether the term is a root term or not
+ bool isCore = true;
+ for (unsigned i = 0; i < node.getNumChildren(); ++i) {
+ isCore = isCore && isCoreTerm(node[i]);
+ }
+ d_coreTermCache[node] = isCore;
+ return isCore;
+ }
+ }
+ return d_coreTermCache[node];
}
+void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) {
+ Assert (node.getKind() == kind::EQUAL);
+ TNode t1 = node[0];
+ TNode t2 = node[1];
+
+ uint32_t width = utils::getSize(t1);
+
+ Base base1(width);
+ if (t1.getKind() == kind::BITVECTOR_CONCAT) {
+ int size = -1;
+ // no need to count the last child since the end cut point is implicit
+ for (int i = t1.getNumChildren() - 1; i >= 1 ; --i) {
+ size = size + utils::getSize(t1[i]);
+ base1.sliceAt(size);
+ }
+ }
+ Base base2(width);
+ if (t2.getKind() == kind::BITVECTOR_CONCAT) {
+ unsigned size = -1;
+ for (int i = t2.getNumChildren() - 1; i >= 1; --i) {
+ size = size + utils::getSize(t2[i]);
+ base2.sliceAt(size);
+ }
+ }
+
+ base1.sliceWith(base2);
+ if (!base1.isEmpty()) {
+ // we split the equalities according to the base
+ int last = 0;
+ for (unsigned i = 0; i < utils::getSize(t1); ++i) {
+ if (base1.isCutPoint(i)) {
+ Node extract1 = Rewriter::rewrite(utils::mkExtract(t1, i, last));
+ Node extract2 = Rewriter::rewrite(utils::mkExtract(t2, i, last));
+ last = i + 1;
+ Assert (utils::getSize(extract1) == utils::getSize(extract2));
+ equalities.push_back(utils::mkNode(kind::EQUAL, extract1, extract2));
+ }
+ }
+ } else {
+ // just return same equality
+ equalities.push_back(node);
+ }
+}
diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h
index 288b72bac..7ab40652d 100644
--- a/src/theory/bv/slicer.h
+++ b/src/theory/bv/slicer.h
@@ -26,7 +26,7 @@
#include "util/bitvector.h"
#include "util/statistics_registry.h"
-
+#include "util/index.h"
#include "expr/node.h"
#include "theory/bv/theory_bv_utils.h"
#ifndef __CVC4__THEORY__BV__SLICER_BV_H
@@ -38,10 +38,11 @@ namespace CVC4 {
namespace theory {
namespace bv {
-typedef uint32_t TermId;
-typedef uint32_t Index;
+typedef Index TermId;
+extern const TermId UndefinedId;
+
/**
* Base
@@ -51,13 +52,23 @@ class Base {
Index d_size;
std::vector<uint32_t> d_repr;
public:
- Base(uint32_t size);
+ Base(Index size);
void sliceAt(Index index);
void sliceWith(const Base& other);
bool isCutPoint(Index index) const;
void diffCutPoints(const Base& other, Base& res) const;
bool isEmpty() const;
std::string debugPrint() const;
+ Index getBitwidth() const { return d_size; }
+ bool operator==(const Base& other) const {
+ if (other.getBitwidth() != getBitwidth())
+ return false;
+ for (unsigned i = 0; i < d_repr.size(); ++i) {
+ if (d_repr[i] != other.d_repr[i])
+ return false;
+ }
+ return true;
+ }
};
/**
@@ -72,88 +83,80 @@ struct ExtractTerm {
Index high;
Index low;
ExtractTerm(TermId i, Index h, Index l)
- : id (i)
- high(h)
+ : id (i),
+ high(h),
low(l)
{
- Assert (h >= l && id != -1);
+ Assert (h >= l && id != UndefinedId);
}
+ Index getBitwidth() const { return high - low + 1; }
};
+class UnionFind;
+
struct NormalForm {
Base base;
Decomposition decomp;
+
NormalForm(Index bitwidth)
: base(bitwidth),
decomp()
{}
+ /**
+ * Returns the term in the decomposition on which the index i
+ * falls in
+ * @param i
+ *
+ * @return
+ */
+ TermId getTerm(Index i, const UnionFind& uf) const;
};
class UnionFind {
class Node {
- TermId d_repr;
- TermId d_ch1, d_ch2;
Index d_bitwidth;
+ TermId d_ch1, d_ch2;
+ TermId d_repr;
public:
Node(Index b)
: d_bitwidth(b),
- d_ch1(-1),
- d_ch2(-1),
- d_repr(-1)
+ d_ch1(UndefinedId),
+ d_ch2(UndefinedId),
+ d_repr(UndefinedId)
{}
TermId getRepr() const { return d_repr; }
Index getBitwidth() const { return d_bitwidth; }
- bool hasChildren() const { return d_ch1 != -1 && d_ch2 != -1; }
+ bool hasChildren() const { return d_ch1 != UndefinedId && d_ch2 != UndefinedId; }
TermId getChild(Index i) const {
Assert (i < 2);
- return i == 0? ch1 : ch2;
+ return i == 0? d_ch1 : d_ch2;
}
- Index getCutPoint() const {
- Assert (d_ch1 != -1 && d_ch2 != -1);
- return getNode(d_ch1).getBitwidth();
+ Index getCutPoint(const UnionFind& uf) const {
+ Assert (d_ch1 != UndefinedId && d_ch2 != UndefinedId);
+ return uf.getNode(d_ch1).getBitwidth();
}
void setRepr(TermId id) {
- Assert (d_children.empty());
+ Assert (! hasChildren());
d_repr = id;
}
void setChildren(TermId ch1, TermId ch2) {
- Assert (d_repr == -1 && d_children.empty());
- markAsNotTopLevel(ch1);
- markAsNotTopLevel(ch2);
- d_children.push_back(ch1);
- d_children.push_back(ch2);
+ Assert (d_repr == UndefinedId && !hasChildren());
+ d_ch1 = ch1;
+ d_ch2 = ch2;
}
-
- // void setChildren(TermId ch1, TermId ch2, TermId ch3) {
- // Assert (d_repr == -1 && d_children.empty());
- // d_children.push_back(ch1);
- // d_children.push_back(ch2);
- // d_children.push_back(ch3);
- // }
-
};
+ /// map from TermId to the nodes that represent them
std::vector<Node> d_nodes;
-
+ /// a term is in this set if it is its own representative
TermSet d_representatives;
- TermSet d_topLevelTerms;
- void markAsNotTopLevel(TermId id) {
- if (d_topLevelTerms.find(id) != d_topLevelTerms.end())
- d_topLevelTerms.erase(id);
- }
-
- bool isTopLevel(TermId id) {
- return d_topLevelTerms.find(id) != d_topLevelTerms.end();
- }
- Index getBitwidth(TermId id) {
- Assert (id < d_nodes.size());
- return d_nodes[id].getBitwidth();
- }
+ void getDecomposition(const ExtractTerm& term, Decomposition& decomp);
+ void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common);
public:
UnionFind()
@@ -162,32 +165,44 @@ public:
{}
TermId addTerm(Index bitwidth);
+ void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2);
void merge(TermId t1, TermId t2);
- TermId find(TermId t1);
- TermId split(TermId term, Index i);
-
- void getNormalForm(ExtractTerm term, NormalForm& nf);
- void alignSlicings(NormalForm& nf1, NormalForm& nf2);
+ TermId find(TermId t1) const ;
+ void split(TermId term, Index i);
- Node getNode(TermId id) {
+ void getNormalForm(const ExtractTerm& term, NormalForm& nf);
+ void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2);
+ void ensureSlicing(const ExtractTerm& term);
+
+ Node getNode(TermId id) const {
Assert (id < d_nodes.size());
return d_nodes[id];
}
+ Index getBitwidth(TermId id) const {
+ Assert (id < d_nodes.size());
+ return d_nodes[id].getBitwidth();
+ }
+
};
class Slicer {
__gnu_cxx::hash_map<TermId, TNode> d_idToNode;
- __gnu_cxx::hash_map<TNode, TermId> d_nodeToId;
- UnionFind d_unionFind();
-
+ __gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction> d_nodeToId;
+ __gnu_cxx::hash_map<TNode, bool, TNodeHashFunction> d_coreTermCache;
+ UnionFind d_unionFind;
+ ExtractTerm registerTerm(TNode node);
public:
Slicer()
- : d_topLevelTerms(),
+ : d_idToNode(),
+ d_nodeToId(),
+ d_coreTermCache(),
d_unionFind()
{}
- void getBaseDecomposition(TNode node, std::vector<Node>& decomp) const;
+ void getBaseDecomposition(TNode node, std::vector<Node>& decomp);
void processEquality(TNode eq);
+ bool isCoreTerm (TNode node);
+ static void splitEqualities(TNode node, std::vector<Node>& equalities);
};
}/* CVC4::theory::bv namespace */
diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp
index 1c746fafa..64662d5c6 100644
--- a/src/theory/bv/theory_bv.cpp
+++ b/src/theory/bv/theory_bv.cpp
@@ -210,7 +210,6 @@ Node TheoryBV::ppRewrite(TNode t)
void TheoryBV::presolve() {
Debug("bitvector") << "TheoryBV::presolve" << endl;
- d_slicer.computeCoarsestBase();
}
bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory)
diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h
index 73a13e1ad..f87163e37 100644
--- a/src/theory/bv/theory_bv_utils.h
+++ b/src/theory/bv/theory_bv_utils.h
@@ -418,6 +418,35 @@ inline std::string vectorToString(const std::vector<Node>& nodes) {
return out.str();
}
+// FIXME: dumb code
+inline void intersect(const std::vector<uint32_t>& v1,
+ const std::vector<uint32_t>& v2,
+ std::vector<uint32_t>& intersection) {
+ for (unsigned i = 0; i < v1.size(); ++i) {
+ bool found = false;
+ for (unsigned j = 0; j < v2.size(); ++j) {
+ if (v2[j] == v1[i]) {
+ found = true;
+ break;
+ }
+ }
+ if (found) {
+ intersection.push_back(v1[i]);
+ }
+ }
+}
+
+template <class T>
+inline T gcd(T a, T b) {
+ while (b != 0) {
+ T t = b;
+ b = a % t;
+ a = t;
+ }
+ return a;
+}
+
+
}
}
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback