summaryrefslogtreecommitdiff
path: root/src/theory/bv
diff options
context:
space:
mode:
authorlianah <lianahady@gmail.com>2013-03-06 16:35:38 -0500
committerlianah <lianahady@gmail.com>2013-03-06 16:35:38 -0500
commit267ad0ceb6808bd4c05d7c4bb04a7886efc19eab (patch)
treed01a9ab188127d3277b098d8eb63d502ca4f3a10 /src/theory/bv
parente50fc148b0ae2d74e3b7b7bb86cf8a038a3d9ca4 (diff)
more slicer changes for incremental
Diffstat (limited to 'src/theory/bv')
-rw-r--r--src/theory/bv/bv_subtheory_core.cpp76
-rw-r--r--src/theory/bv/bv_subtheory_core.h5
-rw-r--r--src/theory/bv/slicer.cpp59
-rw-r--r--src/theory/bv/slicer.h21
-rw-r--r--src/theory/bv/theory_bv.cpp8
5 files changed, 130 insertions, 39 deletions
diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp
index 3f2ede9e2..91cf29ee9 100644
--- a/src/theory/bv/bv_subtheory_core.cpp
+++ b/src/theory/bv/bv_subtheory_core.cpp
@@ -100,21 +100,14 @@ void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) {
}
Node CoreSolver::getBaseDecomposition(TNode a) {
- // if (d_normalFormCache.find(a) != d_normalFormCache.end()) {
- // return d_normalFormCache[a];
- // }
-
- // otherwise we must compute the normal form
std::vector<Node> a_decomp;
d_slicer->getBaseDecomposition(a, a_decomp);
Node new_a = utils::mkConcat(a_decomp);
- // d_normalFormCache[a] = new_a;
return new_a;
}
bool CoreSolver::decomposeFact(TNode fact) {
Debug("bv-slicer") << "CoreSolver::decomposeFact fact=" << fact << endl;
- // FIXME: are this the right things to assert?
// assert decompositions since the equality engine does not know the semantics of
// concat:
// a == a_1 concat ... concat a_k
@@ -123,6 +116,12 @@ bool CoreSolver::decomposeFact(TNode fact) {
TNode a = eq[0];
TNode b = eq[1];
+ // we need to get the old decomposition to keep track of the cuts we added
+ Base a_old_base = d_slicer->getTopLevelBase(a);
+ Base b_old_base = d_slicer->getTopLevelBase(b);
+
+ d_slicer->processEquality(eq);
+
Node new_a = getBaseDecomposition(a);
Node new_b = getBaseDecomposition(b);
@@ -133,7 +132,15 @@ bool CoreSolver::decomposeFact(TNode fact) {
Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a);
Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b);
- bool ok = true;
+ Base a_new_base = d_slicer->getTopLevelBase(a);
+ Base b_new_base = d_slicer->getTopLevelBase(b);
+
+ bool ok = true;
+ ok = addNewSplits(a, a_old_base, a_new_base);
+ if (!ok) return false;
+ ok = addNewSplits(b, b_old_base, b_new_base);
+ if (!ok) return false;
+
ok = assertFact(a_eq_new_a, utils::mkTrue());
if (!ok) return false;
ok = assertFact(b_eq_new_b, utils::mkTrue());
@@ -158,6 +165,56 @@ bool CoreSolver::decomposeFact(TNode fact) {
return true;
}
+bool CoreSolver::addNewSplits(TNode n, Base& old_base, Base& new_base) {
+ if (n.getKind() == kind::BITVECTOR_EXTRACT) {
+ n = n[0];
+ }
+ Assert (old_base.getBitwidth() == new_base.getBitwidth() &&
+ utils::getSize(n) == old_base.getBitwidth());
+
+ Index high, low = 0;
+ std::vector<std::pair<Index, Index> > toSlice;
+ bool hasNewCut = false;
+ // collect the intervals that need to be sliced
+ for (unsigned i = 0; i <= old_base.getBitwidth(); ++i) {
+ Assert (! old_base.isCutPoint(i) || new_base.isCutPoint(i));
+ if (new_base.isCutPoint(i) && !old_base.isCutPoint(i)) {
+ hasNewCut = true;
+ }
+ if (new_base.isCutPoint(i) && old_base.isCutPoint(i)) {
+ high = i;
+ if (hasNewCut) {
+ toSlice.push_back(std::pair<Index, Index>(high, low));
+ }
+ low = i;
+ hasNewCut = false;
+ }
+ }
+ // for each interval, assert the proper equality
+ for (unsigned i = 0; i < toSlice.size(); ++i) {
+ int high = toSlice[i].first;
+ int low = toSlice[i].second;
+ int prev = high;
+ std::vector<Node> extracts;
+ for (int k = high -1; k >= low; --k) {
+ if (new_base.isCutPoint(k) && (!old_base.isCutPoint(k) || k == low)) {
+ // add a new extract
+ Node ex = utils::mkExtract(n, prev - 1, k);
+ prev = k;
+ extracts.push_back(ex);
+ }
+ }
+ Node concat = utils::mkConcat(extracts);
+ Node current = utils::mkExtract(n, high - 1, low);
+ Node eq = utils::mkNode(kind::EQUAL, concat, current);
+ bool ok = assertFact(eq, utils::mkTrue());
+ if (!ok)
+ return false;
+ }
+ return true;
+}
+
+
bool CoreSolver::addAssertions(const std::vector<TNode>& assertions, Theory::Effort e) {
Trace("bitvector::core") << "CoreSolver::addAssertions \n";
Assert (!d_bv->inConflict());
@@ -168,14 +225,13 @@ bool CoreSolver::addAssertions(const std::vector<TNode>& assertions, Theory::Eff
TNode fact = assertions[i];
// update whether we are in the core fragment
- // FIXME: move isCoreTerm into CoreSolver
if (d_isCoreTheory && !d_slicer->isCoreTerm(fact)) {
d_isCoreTheory = false;
}
// only reason about equalities
- // FIXME: should we slice when we have the terms in inequalities?
if (fact.getKind() == kind::EQUAL || (fact.getKind() == kind::NOT && fact[0].getKind() == kind::EQUAL)) {
+ TNode eq = fact.getKind() == kind::EQUAL ? fact : fact[0];
ok = decomposeFact(fact);
} else {
ok = assertFact(fact, fact);
diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h
index 38676bfa6..1adf813ff 100644
--- a/src/theory/bv/bv_subtheory_core.h
+++ b/src/theory/bv/bv_subtheory_core.h
@@ -25,7 +25,7 @@ namespace theory {
namespace bv {
class Slicer;
-
+class Base;
/**
* Bitvector equality solver
*/
@@ -75,7 +75,8 @@ class CoreSolver : public SubtheorySolver {
bool assertFact(TNode fact, TNode reason);
bool decomposeFact(TNode fact);
- Node getBaseDecomposition(TNode a);
+ Node getBaseDecomposition(TNode a);
+ bool addNewSplits(TNode n, Base& old_base, Base& new_base);
public:
bool isCoreTheory() {return d_isCoreTheory; }
CoreSolver(context::Context* c, TheoryBV* bv, Slicer* slicer);
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp
index 3a6ca8a2f..2334ed2b0 100644
--- a/src/theory/bv/slicer.cpp
+++ b/src/theory/bv/slicer.cpp
@@ -167,7 +167,7 @@ TermId UnionFind::addTerm(Index bitwidth) {
++(d_statistics.d_numNodes);
TermId id = d_nodes.size() - 1;
- d_representatives.insert(id);
+ // d_representatives.insert(id);
++(d_statistics.d_numRepresentatives);
Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl;
@@ -217,7 +217,7 @@ void UnionFind::merge(TermId t1, TermId t2) {
Assert (! hasChildren(t1) && ! hasChildren(t2));
setRepr(t1, t2);
recordOperation(UnionFind::MERGE, t1);
- d_representatives.erase(t1);
+ //d_representatives.erase(t1);
d_statistics.d_numRepresentatives += -1;
}
@@ -254,7 +254,6 @@ void UnionFind::split(TermId id, Index i) {
TermId top_id = addTerm(getBitwidth(id) - i);
setChildren(id, top_id, bottom_id);
recordOperation(UnionFind::SPLIT, id);
-
} else {
Index cut = getCutPoint(id);
if (i < cut )
@@ -418,8 +417,10 @@ void UnionFind::ensureSlicing(const ExtractTerm& term) {
}
void UnionFind::backtrack() {
- for (int i = d_undoStack.size() -1; i >= d_undoStackIndex; ++i) {
+ int size = d_undoStack.size();
+ for (int i = size; i > d_undoStackIndex.get(); --i) {
Operation op = d_undoStack.back();
+ Assert (!d_undoStack.empty());
d_undoStack.pop_back();
if (op.op == UnionFind::MERGE) {
undoMerge(op.id);
@@ -431,23 +432,35 @@ void UnionFind::backtrack() {
}
void UnionFind::undoMerge(TermId id) {
- Node& node = getNode(id);
- Assert (getRepr(id) != id);
- setRepr(id, id);
+ TermId repr = getRepr(id);
+ Assert (repr != id);
+ setRepr(id, UndefinedId);
}
void UnionFind::undoSplit(TermId id) {
- Node& node = getNode(id);
- Assert (hasChildren(node));
- setChildren(id, UndefindId, UndefinedId);
+ Assert (hasChildren(id));
+ setChildren(id, UndefinedId, UndefinedId);
}
void UnionFind::recordOperation(OperationKind op, TermId term) {
- ++d_undoStackIndex;
+ d_undoStackIndex.set(d_undoStackIndex.get() + 1);
d_undoStack.push_back(Operation(op, term));
Assert (d_undoStack.size() == d_undoStackIndex);
}
+void UnionFind::getBase(TermId id, Base& base, Index offset) {
+ id = find(id);
+ if (!hasChildren(id))
+ return;
+ TermId id1 = find(getChild(id, 1));
+ TermId id0 = find(getChild(id, 0));
+ Index cut = getCutPoint(id);
+ base.sliceAt(cut + offset);
+ getBase(id1, base, cut + offset);
+ getBase(id0, base, offset);
+}
+
+
/**
* Slicer
*
@@ -517,7 +530,6 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) {
current_low += current_size;
decomp.push_back(current);
}
- // cache the result
Debug("bv-slicer") << "as [";
for (unsigned i = 0; i < decomp.size(); ++i) {
@@ -595,7 +607,28 @@ void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) {
equalities.push_back(node);
}
d_numAddedEqualities += equalities.size() - 1;
-}
+}
+
+/**
+ * Returns the base decomposition of the current term.
+ *
+ * @param id
+ *
+ * @return
+ */
+Base Slicer::getTopLevelBase(TNode node) {
+ if (node.getKind() == kind::BITVECTOR_EXTRACT) {
+ node = node[0];
+ }
+ // if we haven't seen this node before it must not be sliced yet
+ if (d_nodeToId.find(node) == d_nodeToId.end()) {
+ return Base(utils::getSize(node));
+ }
+ TermId id = d_nodeToId[node];
+ Base base(d_unionFind.getBitwidth(id));
+ d_unionFind.getBase(id, base, 0);
+ return base;
+}
std::string UnionFind::debugPrint(TermId id) {
ostringstream os;
diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h
index 731141262..0508c67c1 100644
--- a/src/theory/bv/slicer.h
+++ b/src/theory/bv/slicer.h
@@ -56,7 +56,7 @@ class Base {
Index d_size;
std::vector<uint32_t> d_repr;
public:
- Base(Index size);
+ Base (Index size);
void sliceAt(Index index);
void sliceWith(const Base& other);
bool isCutPoint(Index index) const;
@@ -84,7 +84,7 @@ public:
* UnionFind
*
*/
-typedef context::CDHashSet<uint32_t> CDTermSet;
+typedef context::CDHashSet<uint32_t, std::hash<uint32_t> > CDTermSet;
typedef std::vector<TermId> Decomposition;
struct ExtractTerm {
@@ -151,7 +151,7 @@ class UnionFind : public context::ContextNotifyObj {
d_repr = id;
}
void setChildren(TermId ch1, TermId ch0) {
- Assert (d_repr == UndefinedId && !hasChildren());
+ // Assert (d_repr == UndefinedId && !hasChildren());
d_ch1 = ch1;
d_ch0 = ch0;
}
@@ -161,7 +161,7 @@ class UnionFind : public context::ContextNotifyObj {
/// map from TermId to the nodes that represent them
std::vector<Node> d_nodes;
/// a term is in this set if it is its own representative
- CDTermSet d_representatives;
+ //CDTermSet d_representatives;
void getDecomposition(const ExtractTerm& term, Decomposition& decomp);
void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common);
@@ -187,7 +187,8 @@ class UnionFind : public context::ContextNotifyObj {
d_nodes[id].setRepr(new_repr);
}
void setChildren(TermId id, TermId ch1, TermId ch0) {
- Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0));
+ Assert ((ch1 == UndefinedId && ch0 == UndefinedId) ||
+ (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0)));
d_nodes[id].setChildren(ch1, ch0);
}
@@ -212,7 +213,7 @@ class UnionFind : public context::ContextNotifyObj {
void undoMerge(TermId id);
void undoSplit(TermId id);
void recordOperation(OperationKind op, TermId term);
-
+ virtual ~UnionFind() throw(AssertionException) {}
class Statistics {
public:
IntStat d_numNodes;
@@ -228,8 +229,9 @@ class UnionFind : public context::ContextNotifyObj {
Statistics d_statistics;
public:
UnionFind(context::Context* ctx)
- : d_nodes(),
- d_representatives(ctx),
+ : ContextNotifyObj(ctx),
+ d_nodes(),
+ // d_representatives(ctx),
d_undoStack(),
d_undoStackIndex(ctx),
d_statistics()
@@ -248,6 +250,7 @@ public:
Assert (id < d_nodes.size());
return d_nodes[id].getBitwidth();
}
+ void getBase(TermId id, Base& base, Index offset);
std::string debugPrint(TermId id);
void contextNotifyPop() {
@@ -274,7 +277,7 @@ public:
void getBaseDecomposition(TNode node, std::vector<Node>& decomp);
void processEquality(TNode eq);
bool isCoreTerm (TNode node);
-
+ Base getTopLevelBase(TNode node);
static void splitEqualities(TNode node, std::vector<Node>& equalities);
static unsigned d_numAddedEqualities;
};
diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp
index bb4b480d6..6248782bd 100644
--- a/src/theory/bv/theory_bv.cpp
+++ b/src/theory/bv/theory_bv.cpp
@@ -40,7 +40,7 @@ TheoryBV::TheoryBV(context::Context* c, context::UserContext* u, OutputChannel&
d_context(c),
d_alreadyPropagatedSet(c),
d_sharedTermsSet(c),
- d_slicer(),
+ d_slicer(c),
d_bitblastAssertionsQueue(c),
d_bitblastSolver(c, this),
d_coreSolver(c, this, &d_slicer),
@@ -74,6 +74,8 @@ TheoryBV::Statistics::~Statistics() {
StatisticsRegistry::unregisterStat(&d_solveTimer);
}
+
+
void TheoryBV::preRegisterTerm(TNode node) {
Debug("bitvector-preregister") << "TheoryBV::preRegister(" << node << ")" << std::endl;
@@ -81,10 +83,6 @@ void TheoryBV::preRegisterTerm(TNode node) {
// don't use the equality engine in the eager bit-blasting
return;
}
-
- if (node.getKind() == kind::EQUAL) {
- d_slicer.processEquality(node);
- }
d_bitblastSolver.preRegister(node);
d_coreSolver.preRegister(node);
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback