summaryrefslogtreecommitdiff
path: root/src/theory/bv/slicer.cpp
diff options
context:
space:
mode:
authorlianah <lianahady@gmail.com>2013-03-16 15:48:51 -0400
committerlianah <lianahady@gmail.com>2013-03-16 15:48:51 -0400
commit25ac2c8f4b45e2b299895e97a30790fbf46cf79f (patch)
treed7b52003d7157073be554bd9818230f1c3b439d3 /src/theory/bv/slicer.cpp
parent3fcdb18fe92e5213aa708285c0d7d5e55633492b (diff)
started work on the inequality bv subtheory
Diffstat (limited to 'src/theory/bv/slicer.cpp')
-rw-r--r--src/theory/bv/slicer.cpp250
1 files changed, 208 insertions, 42 deletions
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp
index ac668ab20..92166224b 100644
--- a/src/theory/bv/slicer.cpp
+++ b/src/theory/bv/slicer.cpp
@@ -19,7 +19,7 @@
#include "theory/bv/slicer.h"
#include "theory/bv/theory_bv_utils.h"
#include "theory/rewriter.h"
-
+#include "theory/bv/bv_subtheory_core.h"
using namespace CVC4;
using namespace CVC4::theory;
using namespace CVC4::theory::bv;
@@ -159,7 +159,7 @@ std::string NormalForm::debugPrint(const UnionFind& uf) const {
std::string UnionFind::Node::debugPrint() const {
ostringstream os;
- os << "Repr " << d_repr << " ["<< d_bitwidth << "] ";
+ os << "Repr " << d_edge.repr << " ["<< d_bitwidth << "] ";
os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl;
return os.str();
}
@@ -169,27 +169,44 @@ std::string UnionFind::Node::debugPrint() const {
* UnionFind
*
*/
-TermId UnionFind::addTerm(Index bitwidth) {
+TermId UnionFind::addNode(Index bitwidth) {
+ Assert (bitwidth > 0);
Node node(bitwidth);
d_nodes.push_back(node);
+
++(d_statistics.d_numNodes);
TermId id = d_nodes.size() - 1;
// d_representatives.insert(id);
++(d_statistics.d_numRepresentatives);
-
Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl;
return id;
}
+
+TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) {
+ ExtractTerm extract(topLevel, high, low);
+ if (d_extractToId.find(extract) != d_extractToId.end()) {
+ return d_extractToId[extract];
+ }
+
+ Assert (high >= low);
+
+ TermId id = addNode(high - low + 1);
+ d_idToExtract[id] = extract;
+ d_extractToId[extract] = id;
+ return id;
+}
+
/**
* At this point we assume the slicings of the two terms are properly aligned.
*
* @param t1
* @param t2
*/
-void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2) {
+void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason) {
Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n"
- << " " << t2.debugPrint() << endl;
+ << " " << t2.debugPrint() << "\n"
+ << " with reason " << reason << endl;
Assert (t1.getBitwidth() == t2.getBitwidth());
NormalForm nf1(t1.getBitwidth());
@@ -202,10 +219,11 @@ void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2) {
Assert (nf1.base == nf2.base);
for (unsigned i = 0; i < nf1.decomp.size(); ++i) {
- merge (nf1.decomp[i], nf2.decomp[i]);
+ merge (nf1.decomp[i], nf2.decomp[i], reason);
}
}
+
/**
* Merge the two terms in the union find. Both t1 and t2
* should be root terms.
@@ -213,7 +231,7 @@ void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2) {
* @param t1
* @param t2
*/
-void UnionFind::merge(TermId t1, TermId t2) {
+void UnionFind::merge(TermId t1, TermId t2, TermId reason) {
Debug("bv-slicer-uf") << "UnionFind::merge (" << t1 <<", " << t2 << ")" << endl;
++(d_statistics.d_numMerges);
t1 = find(t1);
@@ -223,7 +241,7 @@ void UnionFind::merge(TermId t1, TermId t2) {
return;
Assert (! hasChildren(t1) && ! hasChildren(t2));
- setRepr(t1, t2);
+ setRepr(t1, t2, reason);
recordOperation(UnionFind::MERGE, t1);
//d_representatives.erase(t1);
d_statistics.d_numRepresentatives += -1;
@@ -233,11 +251,26 @@ TermId UnionFind::find(TermId id) {
TermId repr = getRepr(id);
if (repr != UndefinedId) {
TermId find_id = find(repr);
- // setRepr(id, find_id);
return find_id;
}
return id;
}
+
+TermId UnionFind::findWithExplanation(TermId id, std::vector<ExplanationId>& explanation) {
+ TermId repr = getRepr(id);
+
+ if (repr != UndefinedId) {
+ TermId reason = getReason(id);
+ Assert (reason != UndefinedId);
+ explanation.push_back(reason);
+
+ TermId find_id = findWithExplanation(repr, explanation);
+ return find_id;
+ }
+ return id;
+}
+
+
/**
* Splits the representative of the term between i-1 and i
*
@@ -259,10 +292,14 @@ void UnionFind::split(TermId id, Index i) {
Assert (i < getBitwidth(id));
if (!hasChildren(id)) {
// first time we split this term
- TermId bottom_id = addTerm(i);
- TermId top_id = addTerm(getBitwidth(id) - i);
+ TermId bottom_id = addExtract(getTopLevel(id), i - 1, 0);
+ TermId top_id = addExtract(getTopLevel(id), getBitwidth(id) - 1, i);
setChildren(id, top_id, bottom_id);
- recordOperation(UnionFind::SPLIT, id);
+ recordOperation(UnionFind::SPLIT, id);
+
+ if (d_slicer->termInEqualityEngine(id)) {
+ d_slicer->enqueueSplit(id, i);
+ }
} else {
Index cut = getCutPoint(id);
if (i < cut )
@@ -273,6 +310,14 @@ void UnionFind::split(TermId id, Index i) {
++(d_statistics.d_numSplits);
}
+TermId UnionFind::getTopLevel(TermId id) const {
+ __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> >::const_iterator it = d_idToExtract.find(id);
+ if (it != d_idToExtract.end()) {
+ return (*it).second.id;
+ }
+ return id;
+}
+
void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) {
nf.clear();
getDecomposition(term, nf.decomp);
@@ -319,6 +364,56 @@ void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp)
getDecomposition(high_child, decomp);
}
}
+
+void UnionFind::getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf,
+ std::vector<ExplanationId>& explanation) {
+ nf.clear();
+ getDecompositionWithExplanation(term, nf.decomp, explanation);
+ // update nf base
+ Index count = 0;
+ for (unsigned i = 0; i < nf.decomp.size(); ++i) {
+ count += getBitwidth(nf.decomp[i]);
+ nf.base.sliceAt(count);
+ }
+ Debug("bv-slicer-uf") << "UnionFind::getNormalFrom term: " << term.debugPrint() << endl;
+ Debug("bv-slicer-uf") << " nf: " << nf.debugPrint(*this) << endl;
+}
+
+void UnionFind::getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp,
+ std::vector<ExplanationId>& explanation) {
+ // making sure the term is aligned
+ TermId id = findWithExplanation(term.id, explanation);
+
+ Assert (term.high < getBitwidth(id));
+ // because we split the node, this must be the whole extract
+ if (!hasChildren(id)) {
+ Assert (term.high == getBitwidth(id) - 1 &&
+ term.low == 0);
+ decomp.push_back(id);
+ return;
+ }
+
+ Index cut = getCutPoint(id);
+
+ if (term.low < cut && term.high < cut) {
+ // the extract falls entirely on the low child
+ ExtractTerm child_ex(getChild(id, 0), term.high, term.low);
+ getDecompositionWithExplanation(child_ex, decomp, explanation);
+ }
+ else if (term.low >= cut && term.high >= cut){
+ // the extract falls entirely on the high child
+ ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut);
+ getDecompositionWithExplanation(child_ex, decomp, explanation);
+ }
+ else {
+ // the extract is split over the two children
+ ExtractTerm low_child(getChild(id, 0), cut - 1, term.low);
+ getDecompositionWithExplanation(low_child, decomp, explanation);
+ ExtractTerm high_child(getChild(id, 1), term.high - cut, 0);
+ getDecompositionWithExplanation(high_child, decomp, explanation);
+ }
+}
+
/**
* May cause reslicings of the decompositions. Must not assume the decompositons
* are the current normal form.
@@ -428,7 +523,7 @@ void UnionFind::ensureSlicing(const ExtractTerm& term) {
void UnionFind::backtrack() {
return;
int size = d_undoStack.size();
- for (int i = size; i > d_undoStackIndex.get(); --i) {
+ for (int i = size; i > (int)d_undoStackIndex.get(); --i) {
Operation op = d_undoStack.back();
Assert (!d_undoStack.empty());
d_undoStack.pop_back();
@@ -443,8 +538,8 @@ void UnionFind::backtrack() {
void UnionFind::undoMerge(TermId id) {
TermId repr = getRepr(id);
- Assert (repr != id);
- setRepr(id, UndefinedId);
+ Assert (repr != UndefinedId);
+ setRepr(id, UndefinedId, UndefinedId);
}
void UnionFind::undoSplit(TermId id) {
@@ -453,9 +548,6 @@ void UnionFind::undoSplit(TermId id) {
}
void UnionFind::recordOperation(OperationKind op, TermId term) {
- if (op == SPLIT) {
- d_newSplit = true;
- }
d_undoStackIndex.set(d_undoStackIndex.get() + 1);
d_undoStack.push_back(Operation(op, term));
Assert (d_undoStack.size() == d_undoStackIndex);
@@ -488,7 +580,7 @@ ExtractTerm Slicer::registerTerm(TNode node) {
low = utils::getExtractLow(node);
}
if (d_nodeToId.find(n) == d_nodeToId.end()) {
- TermId id = d_unionFind.addTerm(utils::getSize(n));
+ TermId id = d_unionFind.addNode(utils::getSize(n));
d_nodeToId[n] = id;
d_idToNode[id] = n;
}
@@ -500,7 +592,8 @@ ExtractTerm Slicer::registerTerm(TNode node) {
void Slicer::processEquality(TNode eq) {
Debug("bv-slicer") << "Slicer::processEquality: " << eq << endl;
-
+
+ registerEquality(eq);
Assert (eq.getKind() == kind::EQUAL);
TNode a = eq[0];
TNode b = eq[1];
@@ -511,13 +604,35 @@ void Slicer::processEquality(TNode eq) {
d_unionFind.ensureSlicing(b_ex);
d_unionFind.alignSlicings(a_ex, b_ex);
- d_unionFind.unionTerms(a_ex, b_ex);
+
Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl;
Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl;
Debug("bv-slicer") << "Slicer::processEquality done. " << endl;
}
-void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) {
+void Slicer::assertEquality(TNode eq) {
+ Assert (eq.getKind() == kind::EQUAL);
+ ExtractTerm a = registerTerm(eq[0]);
+ ExtractTerm b = registerTerm(eq[1]);
+ ExplanationId reason = getExplanationId(eq);
+ d_unionFind.unionTerms(a, b, reason);
+}
+
+TermId Slicer::getId(TNode node) const {
+ __gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction >::const_iterator it = d_nodeToId.find(node);
+ Assert (it != d_nodeToId.end());
+ return it->second;
+}
+
+void Slicer::registerEquality(TNode eq) {
+ if (d_explanationToId.find(eq) == d_explanationToId.end()) {
+ ExplanationId id = d_explanations.size();
+ d_explanations.push_back(eq);
+ d_explanationToId[eq] = id;
+ }
+}
+
+void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation) {
Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl;
Index high = utils::getSize(node) - 1;
@@ -528,10 +643,18 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) {
low = utils::getExtractLow(node);
top = node[0];
}
+
Assert (d_nodeToId.find(top) != d_nodeToId.end());
TermId id = d_nodeToId[top];
- NormalForm nf(high-low+1);
- d_unionFind.getNormalForm(ExtractTerm(id, high, low), nf);
+ NormalForm nf(high-low+1);
+ std::vector<ExplanationId> explanation_ids;
+ d_unionFind.getNormalFormWithExplanation(ExtractTerm(id, high, low), nf, explanation_ids);
+
+ for (unsigned i = 0; i < explanation_ids.size(); ++i) {
+ Assert (hasExplanation(explanation_ids[i]));
+ TNode exp = getExplanation(explanation_ids[i]);
+ explanation.push_back(exp);
+ }
// construct actual extract nodes
Index current_low = 0;
@@ -622,25 +745,68 @@ void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) {
d_numAddedEqualities += equalities.size() - 1;
}
-/**
- * Returns the base decomposition of the current term.
- *
- * @param id
- *
- * @return
- */
-Base Slicer::getTopLevelBase(TNode node) {
- if (node.getKind() == kind::BITVECTOR_EXTRACT) {
- node = node[0];
+
+ExtractTerm UnionFind::getExtractTerm(TermId id) const {
+ Assert (isExtractTerm(id));
+
+ return (d_idToExtract.find(id))->second;
+}
+
+bool UnionFind::isExtractTerm(TermId id) const {
+ return d_idToExtract.find(id) != d_idToExtract.end();
+}
+
+bool Slicer::hasNode(TermId id) const {
+ return d_idToNode.find(id) != d_idToNode.end();
+}
+
+Node Slicer::getNode(TermId id) const {
+ // if it was an extract
+ if (d_unionFind.isExtractTerm(id)) {
+ ExtractTerm extract = d_unionFind.getExtractTerm(id);
+ Assert (hasNode(extract.id));
+ TNode node = d_idToNode.find(extract.id)->second;
+ Node ex = utils::mkExtract(node, extract.high, extract.low);
+ return ex;
}
- // if we haven't seen this node before it must not be sliced yet
- if (d_nodeToId.find(node) == d_nodeToId.end()) {
- return Base(utils::getSize(node));
+ // otherwise must be a top-level term
+ Assert (hasNode(id));
+ return (d_idToNode.find(id))->second;
+}
+
+bool Slicer::termInEqualityEngine(TermId id) {
+ Node node = getNode(id);
+ return d_coreSolver->hasTerm(node);
+}
+
+void Slicer::enqueueSplit(TermId id, Index i) {
+ Node node = getNode(id);
+ Node bottom = Rewriter::rewrite(utils::mkExtract(node, i -1 , 0));
+ Node top = Rewriter::rewrite(utils::mkExtract(node, utils::getSize(node) - 1, i));
+ Node eq = utils::mkNode(kind::EQUAL, node, utils::mkConcat(top, bottom));
+ d_newSplits.push_back(eq);
+ Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl;
+}
+
+void Slicer::getNewSplits(std::vector<Node>& splits) {
+ for (unsigned i = d_newSplitsIndex; i < d_newSplits.size(); ++i) {
+ splits.push_back(d_newSplits[i]);
}
- TermId id = d_nodeToId[node];
- Base base(d_unionFind.getBitwidth(id));
- d_unionFind.getBase(id, base, 0);
- return base;
+ d_newSplitsIndex = d_newSplits.size();
+}
+
+bool Slicer::hasExplanation(ExplanationId id) const {
+ return id < d_explanations.size();
+}
+
+TNode Slicer::getExplanation(ExplanationId id) const {
+ Assert(hasExplanation(id));
+ return d_explanations[id];
+}
+
+ExplanationId Slicer::getExplanationId(TNode reason) const {
+ Assert (d_explanationToId.find(reason) != d_explanationToId.end());
+ return d_explanationToId.find(reason)->second;
}
std::string UnionFind::debugPrint(TermId id) {
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback