summaryrefslogtreecommitdiff
path: root/src/theory/bv/slicer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/bv/slicer.cpp')
-rw-r--r--src/theory/bv/slicer.cpp250
1 files changed, 175 insertions, 75 deletions
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp
index 5d376ea50..b24702635 100644
--- a/src/theory/bv/slicer.cpp
+++ b/src/theory/bv/slicer.cpp
@@ -156,11 +156,11 @@ std::string NormalForm::debugPrint(const UnionFind& uf) const {
return os.str();
}
/**
- * UnionFind::Node
+ * UnionFind::EqualityNode
*
*/
-std::string UnionFind::Node::debugPrint() const {
+std::string UnionFind::EqualityNode::debugPrint() const {
ostringstream os;
os << "Repr " << d_edge.repr << " ["<< d_bitwidth << "] ";
os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl;
@@ -172,41 +172,80 @@ std::string UnionFind::Node::debugPrint() const {
* UnionFind
*
*/
-TermId UnionFind::addNode(Index bitwidth) {
+
+TermId UnionFind::registerTopLevelTerm(Index bitwidth) {
+ TermId id = mkEqualityNode(bitwidth);
+ d_topLevelIds.insert(id);
+ return id;
+}
+
+TermId UnionFind::mkEqualityNode(Index bitwidth) {
Assert (bitwidth > 0);
- Node node(bitwidth);
- d_nodes.push_back(node);
+ EqualityNode node(bitwidth);
+ d_equalityNodes.push_back(node);
++(d_statistics.d_numNodes);
- TermId id = d_nodes.size() - 1;
+ TermId id = d_equalityNodes.size() - 1;
// d_representatives.insert(id);
++(d_statistics.d_numRepresentatives);
Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl;
return id;
}
-
-TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) {
- if (isExtractTerm(topLevel)) {
- ExtractTerm top = getExtractTerm(topLevel);
- Index top_high = top.high;
- Index top_low = top.low;
- Assert (top_high - top_low + 1 > high);
- high += top_low;
- low += top_low;
- topLevel = top.id;
+/**
+ * Create an extract term making sure there are no nested extracts.
+ *
+ * @param id
+ * @param high
+ * @param low
+ *
+ * @return
+ */
+ExtractTerm UnionFind::mkExtractTerm(TermId id, Index high, Index low) {
+ if (d_topLevelIds.find(id) != d_topLevelIds.end()) {
+ return ExtractTerm(id, high, low);
}
- ExtractTerm extract(topLevel, high, low);
+ Assert (isExtractTerm(id));
+ ExtractTerm top = getExtractTerm(id);
+ Assert (d_topLevelIds.find(top.id) != d_topLevelIds.end());
+
+ Index top_high = top.high;
+ Index top_low = top.low;
+ Assert (top_high - top_low + 1 > high);
+ high += top_low;
+ low += top_low;
+ id = top.id;
+ return ExtractTerm(id, high, low);
+}
+
+/**
+ * Associate the given extract term with the given id.
+ *
+ * @param id
+ * @param extract
+ */
+void UnionFind::storeExtractTerm(TermId id, const ExtractTerm& extract) {
if (d_extractToId.find(extract) != d_extractToId.end()) {
- return d_extractToId[extract];
+ Assert (d_extractToId[extract] == id);
+ return;
}
-
- Assert (high >= low);
-
- TermId id = addNode(high - low + 1);
+ Debug("bv-slicer") << "UnionFind::storeExtract " << extract.debugPrint() << " => id" << id << "\n";
d_idToExtract[id] = extract;
d_extractToId[extract] = id;
- return id;
+ }
+
+TermId UnionFind::addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low) {
+ ExtractTerm extract(id, high, low);
+ if (d_extractToId.find(extract) != d_extractToId.end()) {
+ // if the extract already exists we don't need to make a new node
+ TermId extract_id = d_extractToId[extract];
+ Assert (extract_id < d_equalityNodes.size());
+ return extract_id;
+ }
+ // otherwise make an equality node for it and store the extract
+ TermId node_id = mkEqualityNode(bitwidth);
+ storeExtractTerm(node_id, extract);
+ return node_id;
}
/**
@@ -215,7 +254,10 @@ TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) {
* @param t1
* @param t2
*/
-void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason) {
+void UnionFind::unionTerms(TermId id1, TermId id2, TermId reason) {
+ const ExtractTerm& t1 = getExtractTerm(id1);
+ const ExtractTerm& t2 = getExtractTerm(id2);
+
Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n"
<< " " << t2.debugPrint() << "\n"
<< " with reason " << reason << endl;
@@ -294,7 +336,7 @@ TermId UnionFind::findWithExplanation(TermId id, std::vector<ExplanationId>& exp
void UnionFind::split(TermId id, Index i) {
Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl;
id = find(id);
- Debug("bv-slicer-uf") << " node: " << d_nodes[id].debugPrint() << endl;
+ Debug("bv-slicer-uf") << " node: " << d_equalityNodes[id].debugPrint() << endl;
if (i == 0 || i == getBitwidth(id)) {
// nothing to do
@@ -303,9 +345,15 @@ void UnionFind::split(TermId id, Index i) {
Assert (i < getBitwidth(id));
if (!hasChildren(id)) {
- // first time we split this term
- TermId bottom_id = addExtract(id, i - 1, 0);
- TermId top_id = addExtract(id, getBitwidth(id) - 1, i);
+ // first time we split this term
+ ExtractTerm bottom_extract = mkExtractTerm(id, i-1, 0);
+ ExtractTerm top_extract = mkExtractTerm(id, getBitwidth(id) - 1, i);
+
+ TermId bottom_id = extractHasId(bottom_extract)? getExtractId(bottom_extract) : mkEqualityNode(i);
+ TermId top_id = extractHasId(top_extract)? getExtractId(top_extract) : mkEqualityNode(getBitwidth(id) - i);
+ storeExtractTerm(bottom_id, bottom_extract);
+ storeExtractTerm(top_id, top_extract);
+
setChildren(id, top_id, bottom_id);
recordOperation(UnionFind::SPLIT, id);
@@ -471,7 +519,10 @@ void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposit
}
-void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) {
+void UnionFind::alignSlicings(TermId id1, TermId id2) {
+ const ExtractTerm& term1 = getExtractTerm(id1);
+ const ExtractTerm& term2 = getExtractTerm(id2);
+
Debug("bv-slicer") << "UnionFind::alignSlicings " << term1.debugPrint() << endl;
Debug("bv-slicer") << " " << term2.debugPrint() << endl;
NormalForm nf1(term1.getBitwidth());
@@ -519,15 +570,18 @@ void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2
}
} while (changed);
}
+
+
/**
* Given an extract term a[i:j] makes sure a is sliced
* at indices i and j.
*
* @param term
*/
-void UnionFind::ensureSlicing(const ExtractTerm& term) {
+void UnionFind::ensureSlicing(TermId t) {
+ ExtractTerm term = getExtractTerm(t);
//Debug("bv-slicer") << "Slicer::ensureSlicing " << term.debugPrint() << endl;
- TermId id = find(term.id);
+ TermId id = term.id;
split(id, term.high + 1);
split(id, term.low);
}
@@ -576,30 +630,69 @@ void UnionFind::getBase(TermId id, Base& base, Index offset) {
getBase(id0, base, offset);
}
+/// getter methods for the internal nodes
+TermId UnionFind::getRepr(TermId id) const {
+ Assert (id < d_equalityNodes.size());
+ return d_equalityNodes[id].getRepr();
+}
+ExplanationId UnionFind::getReason(TermId id) const {
+ Assert (id < d_equalityNodes.size());
+ return d_equalityNodes[id].getReason();
+}
+TermId UnionFind::getChild(TermId id, Index i) const {
+ Assert (id < d_equalityNodes.size());
+ return d_equalityNodes[id].getChild(i);
+}
+Index UnionFind::getCutPoint(TermId id) const {
+ return getBitwidth(getChild(id, 0));
+}
+bool UnionFind::hasChildren(TermId id) const {
+ Assert (id < d_equalityNodes.size());
+ return d_equalityNodes[id].hasChildren();
+}
+
+/// setter methods for the internal nodes
+void UnionFind::setRepr(TermId id, TermId new_repr, ExplanationId reason) {
+ Assert (id < d_equalityNodes.size());
+ d_equalityNodes[id].setRepr(new_repr, reason);
+}
+void UnionFind::setChildren(TermId id, TermId ch1, TermId ch0) {
+ Assert ((ch1 == UndefinedId && ch0 == UndefinedId) ||
+ (id < d_equalityNodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0)));
+ d_equalityNodes[id].setChildren(ch1, ch0);
+}
+
/**
* Slicer
*
*/
-ExtractTerm Slicer::registerTerm(TNode node) {
- Index low = 0, high = utils::getSize(node) - 1;
- TNode n = node;
+TermId Slicer::registerTerm(TNode node) {
if (node.getKind() == kind::BITVECTOR_EXTRACT) {
- n = node[0];
- high = utils::getExtractHigh(node);
- low = utils::getExtractLow(node);
- }
- if (d_nodeToId.find(n) == d_nodeToId.end()) {
- TermId id = d_unionFind.addNode(utils::getSize(n));
- d_nodeToId[n] = id;
- d_idToNode[id] = n;
+ TNode n = node[0];
+ TermId top_id = registerTopLevelTerm(n);
+ Index high = utils::getExtractHigh(node);
+ Index low = utils::getExtractLow(node);
+ TermId id = d_unionFind.addEqualityNode(utils::getSize(node), top_id, high, low);
+ return id;
+ }
+ TermId id = registerTopLevelTerm(node);
+ return id;
+}
+
+TermId Slicer::registerTopLevelTerm(TNode node) {
+ Assert (node.getKind() != kind::BITVECTOR_EXTRACT ||
+ node.getKind() != kind::BITVECTOR_CONCAT);
+
+ if (d_nodeToId.find(node) == d_nodeToId.end()) {
+ TermId id = d_unionFind.registerTopLevelTerm(utils::getSize(node));
+ d_idToNode[id] = node;
+ d_nodeToId[node] = id;
+ Debug("bv-slicer") << "Slicer::registerTopLevelTerm " << node << " => id" << id << endl;
+ return id;
}
- TermId id = d_nodeToId[n];
- d_unionFind.addExtract(id, high, low);
- ExtractTerm res(id, high, low);
- Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl;
- return res;
+ return d_nodeToId[node];
}
void Slicer::processEquality(TNode eq) {
@@ -609,42 +702,38 @@ void Slicer::processEquality(TNode eq) {
Assert (eq.getKind() == kind::EQUAL);
TNode a = eq[0];
TNode b = eq[1];
- ExtractTerm a_ex= registerTerm(a);
- ExtractTerm b_ex= registerTerm(b);
+ TermId a_id = registerTerm(a);
+ TermId b_id = registerTerm(b);
- d_unionFind.ensureSlicing(a_ex);
- d_unionFind.ensureSlicing(b_ex);
+ d_unionFind.ensureSlicing(a_id);
+ d_unionFind.ensureSlicing(b_id);
- d_unionFind.alignSlicings(a_ex, b_ex);
+ d_unionFind.alignSlicings(a_id, b_id);
- Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl;
- Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl;
- Debug("bv-slicer") << "Slicer::processEquality done. " << endl;
+ // Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl;
+ // Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl;
+ // Debug("bv-slicer") << "Slicer::processEquality done. " << endl;
}
void Slicer::assertEquality(TNode eq) {
Assert (eq.getKind() == kind::EQUAL);
- ExtractTerm a = registerTerm(eq[0]);
- ExtractTerm b = registerTerm(eq[1]);
+ TermId a = registerTerm(eq[0]);
+ TermId b = registerTerm(eq[1]);
ExplanationId reason = getExplanationId(eq);
d_unionFind.unionTerms(a, b, reason);
}
-TermId Slicer::getId(TNode node) const {
- __gnu_cxx::hash_map<Node, TermId, NodeHashFunction >::const_iterator it = d_nodeToId.find(node);
- Assert (it != d_nodeToId.end());
- return it->second;
-}
void Slicer::registerEquality(TNode eq) {
if (d_explanationToId.find(eq) == d_explanationToId.end()) {
ExplanationId id = d_explanations.size();
d_explanations.push_back(eq);
- d_explanationToId[eq] = id;
+ d_explanationToId[eq] = id;
+ Debug("bv-slicer-explanation") << "Slicer::registerEquality " << eq << " => id"<< id << "\n";
}
}
-void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation) {
+void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation) {
Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl;
Index high = utils::getSize(node) - 1;
@@ -672,13 +761,18 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::ve
Node current = getNode(nf.decomp[i]);
decomp.push_back(current);
}
-
-
- Debug("bv-slicer") << "as [";
- for (unsigned i = 0; i < decomp.size(); ++i) {
- Debug("bv-slicer") << decomp[i] <<" ";
+ if (Debug.isOn("bv-slicer-explanation")) {
+ Debug("bv-slicer-explanation") << "Slicer::getBaseDecomposition for " << node << "\n"
+ << "as ";
+ for (unsigned i = 0; i < decomp.size(); ++i) {
+ Debug("bv-slicer-explanation") << decomp[i] <<" " ;
+ }
+ Debug("bv-slicer-explanation") << "\n Explanation : \n";
+ for (unsigned i = 0; i < explanation.size(); ++i) {
+ Debug("bv-slicer-explanation") << " " << explanation[i] << "\n";
+ }
+
}
- Debug("bv-slicer") << "]" << endl;
}
@@ -754,6 +848,10 @@ void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) {
ExtractTerm UnionFind::getExtractTerm(TermId id) const {
+ if (d_topLevelIds.find(id) != d_topLevelIds.end()) {
+ // if it's a top level term so we don't have an extract stored for it
+ return ExtractTerm(id, getBitwidth(id) - 1, 0);
+ }
Assert (isExtractTerm(id));
return (d_idToExtract.find(id))->second;
@@ -763,19 +861,21 @@ bool UnionFind::isExtractTerm(TermId id) const {
return d_idToExtract.find(id) != d_idToExtract.end();
}
-bool Slicer::hasNode(TermId id) const {
+bool Slicer::isTopLevelNode(TermId id) const {
return d_idToNode.find(id) != d_idToNode.end();
}
Node Slicer::getNode(TermId id) const {
- if (hasNode(id)) {
+ if (isTopLevelNode(id)) {
return d_idToNode.find(id)->second;
}
- // otherwise must be an extract
Assert (d_unionFind.isExtractTerm(id));
- ExtractTerm extract = d_unionFind.getExtractTerm(id);
- Assert (hasNode(extract.id));
+ const ExtractTerm& extract = d_unionFind.getExtractTerm(id);
+ Assert (isTopLevelNode(extract.id));
TNode node = d_idToNode.find(extract.id)->second;
+ if (extract.high == utils::getSize(node) -1 && extract.low == 0) {
+ return node;
+ }
Node ex = utils::mkExtract(node, extract.high, extract.low);
return ex;
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback