summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/theory/bv/slicer.cpp50
-rw-r--r--src/theory/bv/slicer.h43
2 files changed, 52 insertions, 41 deletions
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp
index c624b9c5e..80a52525d 100644
--- a/src/theory/bv/slicer.cpp
+++ b/src/theory/bv/slicer.cpp
@@ -207,17 +207,14 @@ void UnionFind::merge(TermId t1, TermId t2) {
if (t1 == t2)
return;
- Node n1 = getNode(t1);
- Node n2 = getNode(t2);
- Assert (! n1.hasChildren() && ! n2.hasChildren());
- n1.setRepr(t2);
+ Assert (! hasChildren(t1) && ! hasChildren(t2));
+ setRepr(t1, t2);
d_representatives.erase(t1);
}
TermId UnionFind::find(TermId id) const {
- Node node = getNode(id);
- if (node.getRepr() != UndefinedId)
- return find(node.getRepr());
+ if (getRepr(id) != UndefinedId)
+ return find(getRepr(id));
return id;
}
/**
@@ -231,27 +228,25 @@ TermId UnionFind::find(TermId id) const {
void UnionFind::split(TermId id, Index i) {
Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl;
id = find(id);
- Node node = getNode(id);
- Debug("bv-slicer-uf") << " node: " << node.debugPrint() << endl;
- Assert (i < node.getBitwidth());
+ Debug("bv-slicer-uf") << " node: " << d_nodes[id].debugPrint() << endl;
- if (i == 0 || i == node.getBitwidth()) {
+ if (i == 0 || i == getBitwidth(id)) {
// nothing to do
return;
}
-
- if (!node.hasChildren()) {
+ Assert (i < getBitwidth(id));
+ if (!hasChildren(id)) {
// first time we split this term
TermId bottom_id = addTerm(i);
- TermId top_id = addTerm(node.getBitwidth() - i);
- node.setChildren(top_id, bottom_id);
+ TermId top_id = addTerm(getBitwidth(id) - i);
+ setChildren(id, top_id, bottom_id);
} else {
- Index cut = node.getCutPoint(*this);
+ Index cut = getCutPoint(id);
if (i < cut )
- split(node.getChild(0), i);
+ split(getChild(id, 1), i);
else
- split(node.getChild(1), i - cut);
+ split(getChild(id, 0), i - cut);
}
}
@@ -271,32 +266,31 @@ void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp)
// making sure the term is aligned
TermId id = find(term.id);
- Node node = getNode(id);
- Assert (term.high < node.getBitwidth());
+ Assert (term.high < getBitwidth(id));
// because we split the node, this must be the whole extract
- if (!node.hasChildren()) {
- Assert (term.high == node.getBitwidth() - 1 &&
+ if (!hasChildren(id)) {
+ Assert (term.high == getBitwidth(id) - 1 &&
term.low == 0);
decomp.push_back(id);
}
- Index cut = node.getCutPoint(*this);
+ Index cut = getCutPoint(id);
if (term.low < cut && term.high < cut) {
// the extract falls entirely on the low child
- ExtractTerm child_ex(node.getChild(0), term.high, term.low);
+ ExtractTerm child_ex(getChild(id, 0), term.high, term.low);
getDecomposition(child_ex, decomp);
}
else if (term.low >= cut && term.high >= cut){
// the extract falls entirely on the high child
- ExtractTerm child_ex(node.getChild(1), term.high - cut, term.low - cut);
+ ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut);
getDecomposition(child_ex, decomp);
}
else {
// the extract is split over the two children
- ExtractTerm low_child(node.getChild(0), cut - 1, term.low);
+ ExtractTerm low_child(getChild(id, 0), cut - 1, term.low);
getDecomposition(low_child, decomp);
- ExtractTerm high_child(node.getChild(1), term.high, cut);
+ ExtractTerm high_child(getChild(id, 1), term.high, cut);
getDecomposition(high_child, decomp);
}
}
@@ -397,7 +391,7 @@ void UnionFind::ensureSlicing(const ExtractTerm& term) {
*/
ExtractTerm Slicer::registerTerm(TNode node) {
- Index low = 0, high = utils::getSize(node);
+ Index low = 0, high = utils::getSize(node) - 1;
TNode n = node;
if (node.getKind() == kind::BITVECTOR_EXTRACT) {
n = node[0];
diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h
index c4b3b06a1..b27b85e65 100644
--- a/src/theory/bv/slicer.h
+++ b/src/theory/bv/slicer.h
@@ -119,7 +119,7 @@ class UnionFind {
class Node {
Index d_bitwidth;
TermId d_ch1, d_ch2;
- TermId d_repr;
+ TermId d_repr;
public:
Node(Index b)
: d_bitwidth(b),
@@ -136,23 +136,18 @@ class UnionFind {
Assert (i < 2);
return i == 0? d_ch1 : d_ch2;
}
- Index getCutPoint(const UnionFind& uf) const {
- Assert (d_ch1 != UndefinedId && d_ch2 != UndefinedId);
- return uf.getNode(d_ch1).getBitwidth();
- }
void setRepr(TermId id) {
Assert (! hasChildren());
d_repr = id;
}
-
void setChildren(TermId ch1, TermId ch2) {
Assert (d_repr == UndefinedId && !hasChildren());
d_ch1 = ch1;
d_ch2 = ch2;
}
- std::string debugPrint() const;
+ std::string debugPrint() const;
};
-
+
/// map from TermId to the nodes that represent them
std::vector<Node> d_nodes;
/// a term is in this set if it is its own representative
@@ -160,6 +155,32 @@ class UnionFind {
void getDecomposition(const ExtractTerm& term, Decomposition& decomp);
void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common);
+ /// getter methods for the internal nodes
+ TermId getRepr(TermId id) const {
+ Assert (id < d_nodes.size());
+ return d_nodes[id].getRepr();
+ }
+ TermId getChild(TermId id, Index i) const {
+ Assert (id < d_nodes.size());
+ return d_nodes[id].getChild(i);
+ }
+ Index getCutPoint(TermId id) const {
+ return getBitwidth(getChild(id, 1));
+ }
+ bool hasChildren(TermId id) const {
+ Assert (id < d_nodes.size());
+ return d_nodes[id].hasChildren();
+ }
+ /// setter methods for the internal nodes
+ void setRepr(TermId id, TermId new_repr) {
+ Assert (id < d_nodes.size());
+ d_nodes[id].setRepr(new_repr);
+ }
+ void setChildren(TermId id, TermId ch1, TermId ch2) {
+ Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch2));
+ d_nodes[id].setChildren(ch1, ch2);
+ }
+
public:
UnionFind()
@@ -176,11 +197,6 @@ public:
void getNormalForm(const ExtractTerm& term, NormalForm& nf);
void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2);
void ensureSlicing(const ExtractTerm& term);
-
- Node getNode(TermId id) const {
- Assert (id < d_nodes.size());
- return d_nodes[id];
- }
Index getBitwidth(TermId id) const {
Assert (id < d_nodes.size());
return d_nodes[id].getBitwidth();
@@ -208,6 +224,7 @@ public:
static void splitEqualities(TNode node, std::vector<Node>& equalities);
};
+
}/* CVC4::theory::bv namespace */
}/* CVC4::theory namespace */
}/* CVC4 namespace */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback