summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorlianah <lianahady@gmail.com>2013-03-21 19:25:33 -0400
committerlianah <lianahady@gmail.com>2013-03-21 19:25:33 -0400
commitff8572914d73449b26edba214ad134c596196e32 (patch)
tree8a3d70b2d1b4c703edc9757b2d4417ff6c49e393
parent43ed2d4e9575232655db7df249ba9be1fc9eba61 (diff)
fixed more equality stuff
-rw-r--r--src/theory/bv/bv_subtheory_core.cpp22
-rw-r--r--src/theory/bv/bv_subtheory_core.h4
-rw-r--r--src/theory/bv/slicer.cpp77
-rw-r--r--src/theory/bv/slicer.h6
4 files changed, 67 insertions, 42 deletions
diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp
index d7dab10f9..2af0e47b8 100644
--- a/src/theory/bv/bv_subtheory_core.cpp
+++ b/src/theory/bv/bv_subtheory_core.cpp
@@ -72,6 +72,9 @@ CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv)
}
}
+CoreSolver::~CoreSolver() {
+ delete d_slicer;
+}
void CoreSolver::setMasterEqualityEngine(eq::EqualityEngine* eq) {
d_equalityEngine.setMasterEqualityEngine(eq);
}
@@ -99,10 +102,11 @@ void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) {
}
}
-Node CoreSolver::getBaseDecomposition(TNode a, std::vector<TNode>& explanation) {
+Node CoreSolver::getBaseDecomposition(TNode a, std::vector<Node>& explanation) {
std::vector<Node> a_decomp;
d_slicer->getBaseDecomposition(a, a_decomp, explanation);
Node new_a = utils::mkConcat(a_decomp);
+ Debug("bv-slicer") << "CoreSolver::getBaseDecomposition " << a <<" => " << new_a << "\n";
return new_a;
}
@@ -118,7 +122,7 @@ bool CoreSolver::decomposeFact(TNode fact) {
TNode b = fact[1];
d_slicer->processEquality(fact);
- std::vector<TNode> explanation;
+ std::vector<Node> explanation;
Node new_a = getBaseDecomposition(a, explanation);
Node new_b = getBaseDecomposition(b, explanation);
@@ -157,10 +161,20 @@ bool CoreSolver::decomposeFact(TNode fact) {
d_slicer->assertEquality(fact);
} else {
// still need to register the terms
+ d_slicer->processEquality(fact[0]);
TNode a = fact[0][0];
TNode b = fact[0][1];
- d_slicer->registerTerm(a);
- d_slicer->registerTerm(b);
+ std::vector<Node> explanation_a;
+ Node new_a = getBaseDecomposition(a, explanation_a);
+ Node reason_a = explanation_a.empty()? mkTrue() : mkAnd(explanation_a);
+ assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, a, new_a), reason_a);
+
+ std::vector<Node> explanation_b;
+ Node new_b = getBaseDecomposition(b, explanation_b);
+ Node reason_b = explanation_b.empty()? mkTrue() : mkAnd(explanation_b);
+ assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, b, new_b), reason_b);
+ d_reasons.insert(reason_a);
+ d_reasons.insert(reason_b);
}
// finally assert the actual fact to the equality engine
return assertFactToEqualityEngine(fact, fact);
diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h
index f37cf5bf3..4f2d7a279 100644
--- a/src/theory/bv/bv_subtheory_core.h
+++ b/src/theory/bv/bv_subtheory_core.h
@@ -67,9 +67,10 @@ class CoreSolver : public SubtheorySolver {
context::CDHashSet<Node, NodeHashFunction> d_reasons;
bool assertFactToEqualityEngine(TNode fact, TNode reason);
bool decomposeFact(TNode fact);
- Node getBaseDecomposition(TNode a, std::vector<TNode>& explanation);
+ Node getBaseDecomposition(TNode a, std::vector<Node>& explanation);
public:
CoreSolver(context::Context* c, TheoryBV* bv);
+ ~CoreSolver();
bool isCoreTheory() { return d_isCoreTheory; }
void setMasterEqualityEngine(eq::EqualityEngine* eq);
void preRegister(TNode node);
@@ -91,6 +92,7 @@ public:
return EQUALITY_UNKNOWN;
}
bool hasTerm(TNode node) const { return d_equalityEngine.hasTerm(node); }
+ void addTermToEqualityEngine(TNode node) { d_equalityEngine.addTerm(node); }
};
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp
index 437be9bf4..5d376ea50 100644
--- a/src/theory/bv/slicer.cpp
+++ b/src/theory/bv/slicer.cpp
@@ -41,8 +41,11 @@ Base::Base(uint32_t size)
void Base::sliceAt(Index index) {
+ if (index == d_size)
+ return;
+ Assert(index < d_size);
Index vector_index = index / 32;
- Assert (vector_index < d_size);
+ Assert (vector_index < d_repr.size());
Index int_index = index % 32;
uint32_t bit_mask = utils::pow2(int_index);
d_repr[vector_index] = d_repr[vector_index] | bit_mask;
@@ -184,6 +187,15 @@ TermId UnionFind::addNode(Index bitwidth) {
}
TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) {
+ if (isExtractTerm(topLevel)) {
+ ExtractTerm top = getExtractTerm(topLevel);
+ Index top_high = top.high;
+ Index top_low = top.low;
+ Assert (top_high - top_low + 1 > high);
+ high += top_low;
+ low += top_low;
+ topLevel = top.id;
+ }
ExtractTerm extract(topLevel, high, low);
if (d_extractToId.find(extract) != d_extractToId.end()) {
return d_extractToId[extract];
@@ -292,13 +304,13 @@ void UnionFind::split(TermId id, Index i) {
Assert (i < getBitwidth(id));
if (!hasChildren(id)) {
// first time we split this term
- TermId bottom_id = addExtract(getTopLevel(id), i - 1, 0);
- TermId top_id = addExtract(getTopLevel(id), getBitwidth(id) - 1, i);
+ TermId bottom_id = addExtract(id, i - 1, 0);
+ TermId top_id = addExtract(id, getBitwidth(id) - 1, i);
setChildren(id, top_id, bottom_id);
recordOperation(UnionFind::SPLIT, id);
if (d_slicer->termInEqualityEngine(id)) {
- d_slicer->enqueueSplit(id, i);
+ d_slicer->enqueueSplit(id, i, top_id, bottom_id);
}
} else {
Index cut = getCutPoint(id);
@@ -310,13 +322,13 @@ void UnionFind::split(TermId id, Index i) {
++(d_statistics.d_numSplits);
}
-TermId UnionFind::getTopLevel(TermId id) const {
- __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> >::const_iterator it = d_idToExtract.find(id);
- if (it != d_idToExtract.end()) {
- return (*it).second.id;
- }
- return id;
-}
+// TermId UnionFind::getTopLevel(TermId id) const {
+// __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> >::const_iterator it = d_idToExtract.find(id);
+// if (it != d_idToExtract.end()) {
+// return (*it).second.id;
+// }
+// return id;
+// }
void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) {
nf.clear();
@@ -576,7 +588,7 @@ ExtractTerm Slicer::registerTerm(TNode node) {
if (node.getKind() == kind::BITVECTOR_EXTRACT) {
n = node[0];
high = utils::getExtractHigh(node);
- low = utils::getExtractLow(node);
+ low = utils::getExtractLow(node);
}
if (d_nodeToId.find(n) == d_nodeToId.end()) {
TermId id = d_unionFind.addNode(utils::getSize(n));
@@ -584,6 +596,7 @@ ExtractTerm Slicer::registerTerm(TNode node) {
d_idToNode[id] = n;
}
TermId id = d_nodeToId[n];
+ d_unionFind.addExtract(id, high, low);
ExtractTerm res(id, high, low);
Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl;
return res;
@@ -631,7 +644,7 @@ void Slicer::registerEquality(TNode eq) {
}
}
-void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation) {
+void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation) {
Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl;
Index high = utils::getSize(node) - 1;
@@ -655,16 +668,8 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::ve
explanation.push_back(exp);
}
- // construct actual extract nodes
- Index size = utils::getSize(node);
- Index current_low = size - 1;
- Index current_high = size - 1;
-
for (int i = nf.decomp.size() - 1; i>=0 ; --i) {
- Index current_size = d_unionFind.getBitwidth(nf.decomp[i]);
- current_low = current_low - current_size;
- Node current = Rewriter::rewrite(utils::mkExtract(node, current_high, current_low+1));
- current_high -= current_size;
+ Node current = getNode(nf.decomp[i]);
decomp.push_back(current);
}
@@ -763,17 +768,16 @@ bool Slicer::hasNode(TermId id) const {
}
Node Slicer::getNode(TermId id) const {
- // if it was an extract
- if (d_unionFind.isExtractTerm(id)) {
- ExtractTerm extract = d_unionFind.getExtractTerm(id);
- Assert (hasNode(extract.id));
- TNode node = d_idToNode.find(extract.id)->second;
- Node ex = utils::mkExtract(node, extract.high, extract.low);
- return ex;
+ if (hasNode(id)) {
+ return d_idToNode.find(id)->second;
}
- // otherwise must be a top-level term
- Assert (hasNode(id));
- return (d_idToNode.find(id))->second;
+ // otherwise must be an extract
+ Assert (d_unionFind.isExtractTerm(id));
+ ExtractTerm extract = d_unionFind.getExtractTerm(id);
+ Assert (hasNode(extract.id));
+ TNode node = d_idToNode.find(extract.id)->second;
+ Node ex = utils::mkExtract(node, extract.high, extract.low);
+ return ex;
}
bool Slicer::termInEqualityEngine(TermId id) {
@@ -781,13 +785,18 @@ bool Slicer::termInEqualityEngine(TermId id) {
return d_coreSolver->hasTerm(node);
}
-void Slicer::enqueueSplit(TermId id, Index i) {
+void Slicer::enqueueSplit(TermId id, Index i, TermId top_id, TermId bottom_id) {
Node node = getNode(id);
Node bottom = Rewriter::rewrite(utils::mkExtract(node, i -1 , 0));
Node top = Rewriter::rewrite(utils::mkExtract(node, utils::getSize(node) - 1, i));
+ // must add terms to equality engine so we get notified when they get split more
+ d_coreSolver->addTermToEqualityEngine(bottom);
+ d_coreSolver->addTermToEqualityEngine(top);
+
Node eq = utils::mkNode(kind::EQUAL, node, utils::mkConcat(top, bottom));
d_newSplits.push_back(eq);
- Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl;
+ Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl;
+ Debug("bv-slicer") << " " << id << "=" << top_id << " " << bottom_id << endl;
}
void Slicer::getNewSplits(std::vector<Node>& splits) {
diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h
index f63cf7284..ab2d5e88f 100644
--- a/src/theory/bv/slicer.h
+++ b/src/theory/bv/slicer.h
@@ -224,7 +224,7 @@ class UnionFind : public context::ContextNotifyObj {
Assert (id < d_nodes.size());
return d_nodes[id].hasChildren();
}
- TermId getTopLevel(TermId id) const;
+ // TermId getTopLevel(TermId id) const;
/// setter methods for the internal nodes
void setRepr(TermId id, TermId new_repr, ExplanationId reason) {
@@ -338,7 +338,7 @@ public:
d_coreSolver(coreSolver)
{}
- void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation);
+ void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation);
void registerEquality(TNode eq);
ExtractTerm registerTerm(TNode node);
void processEquality(TNode eq);
@@ -354,7 +354,7 @@ public:
ExplanationId getExplanationId(TNode reason) const;
bool termInEqualityEngine(TermId id);
- void enqueueSplit(TermId id, Index i);
+ void enqueueSplit(TermId id, Index i, TermId top, TermId bottom);
void getNewSplits(std::vector<Node>& splits);
static void splitEqualities(TNode node, std::vector<Node>& equalities);
static unsigned d_numAddedEqualities;
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback