summaryrefslogtreecommitdiff
path: root/src/theory
diff options
context:
space:
mode:
authorlianah <lianahady@gmail.com>2013-03-23 13:40:29 -0400
committerlianah <lianahady@gmail.com>2013-03-23 13:40:29 -0400
commit8882aef2dd4f1f629b0de99fc3a7f390fab2f83e (patch)
treed27049c6fd5be1332f5b7ae9c854985ffee683e4 /src/theory
parent73bc28dd03f68c2c1b8510f3200c3950622e0295 (diff)
fixed some explanation problems for the core theory; still slow
Diffstat (limited to 'src/theory')
-rw-r--r--src/theory/bv/bv_subtheory_core.cpp40
-rw-r--r--src/theory/bv/bv_subtheory_core.h2
-rw-r--r--src/theory/bv/slicer.cpp250
-rw-r--r--src/theory/bv/slicer.h111
-rw-r--r--src/theory/bv/theory_bv_utils.h128
5 files changed, 330 insertions, 201 deletions
diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp
index 2af0e47b8..6f5fd4119 100644
--- a/src/theory/bv/bv_subtheory_core.cpp
+++ b/src/theory/bv/bv_subtheory_core.cpp
@@ -102,7 +102,7 @@ void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) {
}
}
-Node CoreSolver::getBaseDecomposition(TNode a, std::vector<Node>& explanation) {
+Node CoreSolver::getBaseDecomposition(TNode a, std::vector<TNode>& explanation) {
std::vector<Node> a_decomp;
d_slicer->getBaseDecomposition(a, a_decomp, explanation);
Node new_a = utils::mkConcat(a_decomp);
@@ -122,28 +122,35 @@ bool CoreSolver::decomposeFact(TNode fact) {
TNode b = fact[1];
d_slicer->processEquality(fact);
- std::vector<Node> explanation;
- Node new_a = getBaseDecomposition(a, explanation);
- Node new_b = getBaseDecomposition(b, explanation);
+ std::vector<TNode> explanation_a;
+ Node new_a = getBaseDecomposition(a, explanation_a);
+ Node reason_a = mkAnd(explanation_a);
+ d_reasons.insert(reason_a);
+
+ std::vector<TNode> explanation_b;
+ Node new_b = getBaseDecomposition(b, explanation_b);
+ Node reason_b = mkAnd(explanation_b);
+ d_reasons.insert(reason_b);
+ std::vector<Node> explanation;
explanation.push_back(fact);
+ explanation.insert(explanation.end(), explanation_a.begin(), explanation_a.end());
+ explanation.insert(explanation.end(), explanation_b.begin(), explanation_b.end());
+
Node reason = utils::mkAnd(explanation);
d_reasons.insert(reason);
Assert (utils::getSize(new_a) == utils::getSize(new_b) &&
utils::getSize(new_a) == utils::getSize(a));
- // FIXME: do we still need to assert these?
+
NodeManager* nm = NodeManager::currentNM();
Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a);
Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b);
- d_reasons.insert(a_eq_new_a);
- d_reasons.insert(b_eq_new_b);
-
bool ok = true;
- ok = assertFactToEqualityEngine(a_eq_new_a, utils::mkTrue());
+ ok = assertFactToEqualityEngine(a_eq_new_a, reason_a);
if (!ok) return false;
- ok = assertFactToEqualityEngine(b_eq_new_b, utils::mkTrue());
+ ok = assertFactToEqualityEngine(b_eq_new_b, reason_a);
if (!ok) return false;
// assert the individual equalities as well
// a_i == b_i
@@ -152,6 +159,7 @@ bool CoreSolver::decomposeFact(TNode fact) {
Assert (new_a.getNumChildren() == new_b.getNumChildren());
for (unsigned i = 0; i < new_a.getNumChildren(); ++i) {
Node eq_i = nm->mkNode(kind::EQUAL, new_a[i], new_b[i]);
+ // this reason is not very precise!!
ok = assertFactToEqualityEngine(eq_i, reason);
d_reasons.insert(eq_i);
if (!ok) return false;
@@ -164,15 +172,16 @@ bool CoreSolver::decomposeFact(TNode fact) {
d_slicer->processEquality(fact[0]);
TNode a = fact[0][0];
TNode b = fact[0][1];
- std::vector<Node> explanation_a;
+ std::vector<TNode> explanation_a;
Node new_a = getBaseDecomposition(a, explanation_a);
Node reason_a = explanation_a.empty()? mkTrue() : mkAnd(explanation_a);
assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, a, new_a), reason_a);
- std::vector<Node> explanation_b;
+ std::vector<TNode> explanation_b;
Node new_b = getBaseDecomposition(b, explanation_b);
Node reason_b = explanation_b.empty()? mkTrue() : mkAnd(explanation_b);
assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, b, new_b), reason_b);
+
d_reasons.insert(reason_a);
d_reasons.insert(reason_b);
}
@@ -279,13 +288,16 @@ void CoreSolver::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) {
bool CoreSolver::storePropagation(TNode literal) {
return d_bv->storePropagation(literal, SUB_CORE);
}
-
+
void CoreSolver::conflict(TNode a, TNode b) {
std::vector<TNode> assumptions;
d_equalityEngine.explainEquality(a, b, true, assumptions);
- d_bv->setConflict(mkAnd(assumptions));
+ Node conflict = flattenAnd(assumptions);
+ d_bv->setConflict(conflict);
}
+
+
void CoreSolver::collectModelInfo(TheoryModel* m) {
if (Debug.isOn("bitvector-model")) {
context::CDQueue<Node>::const_iterator it = d_assertionQueue.begin();
diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h
index 4f2d7a279..868f3754f 100644
--- a/src/theory/bv/bv_subtheory_core.h
+++ b/src/theory/bv/bv_subtheory_core.h
@@ -67,7 +67,7 @@ class CoreSolver : public SubtheorySolver {
context::CDHashSet<Node, NodeHashFunction> d_reasons;
bool assertFactToEqualityEngine(TNode fact, TNode reason);
bool decomposeFact(TNode fact);
- Node getBaseDecomposition(TNode a, std::vector<Node>& explanation);
+ Node getBaseDecomposition(TNode a, std::vector<TNode>& explanation);
public:
CoreSolver(context::Context* c, TheoryBV* bv);
~CoreSolver();
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp
index 5d376ea50..b24702635 100644
--- a/src/theory/bv/slicer.cpp
+++ b/src/theory/bv/slicer.cpp
@@ -156,11 +156,11 @@ std::string NormalForm::debugPrint(const UnionFind& uf) const {
return os.str();
}
/**
- * UnionFind::Node
+ * UnionFind::EqualityNode
*
*/
-std::string UnionFind::Node::debugPrint() const {
+std::string UnionFind::EqualityNode::debugPrint() const {
ostringstream os;
os << "Repr " << d_edge.repr << " ["<< d_bitwidth << "] ";
os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl;
@@ -172,41 +172,80 @@ std::string UnionFind::Node::debugPrint() const {
* UnionFind
*
*/
-TermId UnionFind::addNode(Index bitwidth) {
+
+TermId UnionFind::registerTopLevelTerm(Index bitwidth) {
+ TermId id = mkEqualityNode(bitwidth);
+ d_topLevelIds.insert(id);
+ return id;
+}
+
+TermId UnionFind::mkEqualityNode(Index bitwidth) {
Assert (bitwidth > 0);
- Node node(bitwidth);
- d_nodes.push_back(node);
+ EqualityNode node(bitwidth);
+ d_equalityNodes.push_back(node);
++(d_statistics.d_numNodes);
- TermId id = d_nodes.size() - 1;
+ TermId id = d_equalityNodes.size() - 1;
// d_representatives.insert(id);
++(d_statistics.d_numRepresentatives);
Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl;
return id;
}
-
-TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) {
- if (isExtractTerm(topLevel)) {
- ExtractTerm top = getExtractTerm(topLevel);
- Index top_high = top.high;
- Index top_low = top.low;
- Assert (top_high - top_low + 1 > high);
- high += top_low;
- low += top_low;
- topLevel = top.id;
+/**
+ * Create an extract term making sure there are no nested extracts.
+ *
+ * @param id
+ * @param high
+ * @param low
+ *
+ * @return
+ */
+ExtractTerm UnionFind::mkExtractTerm(TermId id, Index high, Index low) {
+ if (d_topLevelIds.find(id) != d_topLevelIds.end()) {
+ return ExtractTerm(id, high, low);
}
- ExtractTerm extract(topLevel, high, low);
+ Assert (isExtractTerm(id));
+ ExtractTerm top = getExtractTerm(id);
+ Assert (d_topLevelIds.find(top.id) != d_topLevelIds.end());
+
+ Index top_high = top.high;
+ Index top_low = top.low;
+ Assert (top_high - top_low + 1 > high);
+ high += top_low;
+ low += top_low;
+ id = top.id;
+ return ExtractTerm(id, high, low);
+}
+
+/**
+ * Associate the given extract term with the given id.
+ *
+ * @param id
+ * @param extract
+ */
+void UnionFind::storeExtractTerm(TermId id, const ExtractTerm& extract) {
if (d_extractToId.find(extract) != d_extractToId.end()) {
- return d_extractToId[extract];
+ Assert (d_extractToId[extract] == id);
+ return;
}
-
- Assert (high >= low);
-
- TermId id = addNode(high - low + 1);
+ Debug("bv-slicer") << "UnionFind::storeExtract " << extract.debugPrint() << " => id" << id << "\n";
d_idToExtract[id] = extract;
d_extractToId[extract] = id;
- return id;
+ }
+
+TermId UnionFind::addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low) {
+ ExtractTerm extract(id, high, low);
+ if (d_extractToId.find(extract) != d_extractToId.end()) {
+ // if the extract already exists we don't need to make a new node
+ TermId extract_id = d_extractToId[extract];
+ Assert (extract_id < d_equalityNodes.size());
+ return extract_id;
+ }
+ // otherwise make an equality node for it and store the extract
+ TermId node_id = mkEqualityNode(bitwidth);
+ storeExtractTerm(node_id, extract);
+ return node_id;
}
/**
@@ -215,7 +254,10 @@ TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) {
* @param t1
* @param t2
*/
-void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason) {
+void UnionFind::unionTerms(TermId id1, TermId id2, TermId reason) {
+ const ExtractTerm& t1 = getExtractTerm(id1);
+ const ExtractTerm& t2 = getExtractTerm(id2);
+
Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n"
<< " " << t2.debugPrint() << "\n"
<< " with reason " << reason << endl;
@@ -294,7 +336,7 @@ TermId UnionFind::findWithExplanation(TermId id, std::vector<ExplanationId>& exp
void UnionFind::split(TermId id, Index i) {
Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl;
id = find(id);
- Debug("bv-slicer-uf") << " node: " << d_nodes[id].debugPrint() << endl;
+ Debug("bv-slicer-uf") << " node: " << d_equalityNodes[id].debugPrint() << endl;
if (i == 0 || i == getBitwidth(id)) {
// nothing to do
@@ -303,9 +345,15 @@ void UnionFind::split(TermId id, Index i) {
Assert (i < getBitwidth(id));
if (!hasChildren(id)) {
- // first time we split this term
- TermId bottom_id = addExtract(id, i - 1, 0);
- TermId top_id = addExtract(id, getBitwidth(id) - 1, i);
+ // first time we split this term
+ ExtractTerm bottom_extract = mkExtractTerm(id, i-1, 0);
+ ExtractTerm top_extract = mkExtractTerm(id, getBitwidth(id) - 1, i);
+
+ TermId bottom_id = extractHasId(bottom_extract)? getExtractId(bottom_extract) : mkEqualityNode(i);
+ TermId top_id = extractHasId(top_extract)? getExtractId(top_extract) : mkEqualityNode(getBitwidth(id) - i);
+ storeExtractTerm(bottom_id, bottom_extract);
+ storeExtractTerm(top_id, top_extract);
+
setChildren(id, top_id, bottom_id);
recordOperation(UnionFind::SPLIT, id);
@@ -471,7 +519,10 @@ void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposit
}
-void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) {
+void UnionFind::alignSlicings(TermId id1, TermId id2) {
+ const ExtractTerm& term1 = getExtractTerm(id1);
+ const ExtractTerm& term2 = getExtractTerm(id2);
+
Debug("bv-slicer") << "UnionFind::alignSlicings " << term1.debugPrint() << endl;
Debug("bv-slicer") << " " << term2.debugPrint() << endl;
NormalForm nf1(term1.getBitwidth());
@@ -519,15 +570,18 @@ void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2
}
} while (changed);
}
+
+
/**
* Given an extract term a[i:j] makes sure a is sliced
* at indices i and j.
*
* @param term
*/
-void UnionFind::ensureSlicing(const ExtractTerm& term) {
+void UnionFind::ensureSlicing(TermId t) {
+ ExtractTerm term = getExtractTerm(t);
//Debug("bv-slicer") << "Slicer::ensureSlicing " << term.debugPrint() << endl;
- TermId id = find(term.id);
+ TermId id = term.id;
split(id, term.high + 1);
split(id, term.low);
}
@@ -576,30 +630,69 @@ void UnionFind::getBase(TermId id, Base& base, Index offset) {
getBase(id0, base, offset);
}
+/// getter methods for the internal nodes
+TermId UnionFind::getRepr(TermId id) const {
+ Assert (id < d_equalityNodes.size());
+ return d_equalityNodes[id].getRepr();
+}
+ExplanationId UnionFind::getReason(TermId id) const {
+ Assert (id < d_equalityNodes.size());
+ return d_equalityNodes[id].getReason();
+}
+TermId UnionFind::getChild(TermId id, Index i) const {
+ Assert (id < d_equalityNodes.size());
+ return d_equalityNodes[id].getChild(i);
+}
+Index UnionFind::getCutPoint(TermId id) const {
+ return getBitwidth(getChild(id, 0));
+}
+bool UnionFind::hasChildren(TermId id) const {
+ Assert (id < d_equalityNodes.size());
+ return d_equalityNodes[id].hasChildren();
+}
+
+/// setter methods for the internal nodes
+void UnionFind::setRepr(TermId id, TermId new_repr, ExplanationId reason) {
+ Assert (id < d_equalityNodes.size());
+ d_equalityNodes[id].setRepr(new_repr, reason);
+}
+void UnionFind::setChildren(TermId id, TermId ch1, TermId ch0) {
+ Assert ((ch1 == UndefinedId && ch0 == UndefinedId) ||
+ (id < d_equalityNodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0)));
+ d_equalityNodes[id].setChildren(ch1, ch0);
+}
+
/**
* Slicer
*
*/
-ExtractTerm Slicer::registerTerm(TNode node) {
- Index low = 0, high = utils::getSize(node) - 1;
- TNode n = node;
+TermId Slicer::registerTerm(TNode node) {
if (node.getKind() == kind::BITVECTOR_EXTRACT) {
- n = node[0];
- high = utils::getExtractHigh(node);
- low = utils::getExtractLow(node);
- }
- if (d_nodeToId.find(n) == d_nodeToId.end()) {
- TermId id = d_unionFind.addNode(utils::getSize(n));
- d_nodeToId[n] = id;
- d_idToNode[id] = n;
+ TNode n = node[0];
+ TermId top_id = registerTopLevelTerm(n);
+ Index high = utils::getExtractHigh(node);
+ Index low = utils::getExtractLow(node);
+ TermId id = d_unionFind.addEqualityNode(utils::getSize(node), top_id, high, low);
+ return id;
+ }
+ TermId id = registerTopLevelTerm(node);
+ return id;
+}
+
+TermId Slicer::registerTopLevelTerm(TNode node) {
+ Assert (node.getKind() != kind::BITVECTOR_EXTRACT ||
+ node.getKind() != kind::BITVECTOR_CONCAT);
+
+ if (d_nodeToId.find(node) == d_nodeToId.end()) {
+ TermId id = d_unionFind.registerTopLevelTerm(utils::getSize(node));
+ d_idToNode[id] = node;
+ d_nodeToId[node] = id;
+ Debug("bv-slicer") << "Slicer::registerTopLevelTerm " << node << " => id" << id << endl;
+ return id;
}
- TermId id = d_nodeToId[n];
- d_unionFind.addExtract(id, high, low);
- ExtractTerm res(id, high, low);
- Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl;
- return res;
+ return d_nodeToId[node];
}
void Slicer::processEquality(TNode eq) {
@@ -609,42 +702,38 @@ void Slicer::processEquality(TNode eq) {
Assert (eq.getKind() == kind::EQUAL);
TNode a = eq[0];
TNode b = eq[1];
- ExtractTerm a_ex= registerTerm(a);
- ExtractTerm b_ex= registerTerm(b);
+ TermId a_id = registerTerm(a);
+ TermId b_id = registerTerm(b);
- d_unionFind.ensureSlicing(a_ex);
- d_unionFind.ensureSlicing(b_ex);
+ d_unionFind.ensureSlicing(a_id);
+ d_unionFind.ensureSlicing(b_id);
- d_unionFind.alignSlicings(a_ex, b_ex);
+ d_unionFind.alignSlicings(a_id, b_id);
- Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl;
- Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl;
- Debug("bv-slicer") << "Slicer::processEquality done. " << endl;
+ // Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl;
+ // Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl;
+ // Debug("bv-slicer") << "Slicer::processEquality done. " << endl;
}
void Slicer::assertEquality(TNode eq) {
Assert (eq.getKind() == kind::EQUAL);
- ExtractTerm a = registerTerm(eq[0]);
- ExtractTerm b = registerTerm(eq[1]);
+ TermId a = registerTerm(eq[0]);
+ TermId b = registerTerm(eq[1]);
ExplanationId reason = getExplanationId(eq);
d_unionFind.unionTerms(a, b, reason);
}
-TermId Slicer::getId(TNode node) const {
- __gnu_cxx::hash_map<Node, TermId, NodeHashFunction >::const_iterator it = d_nodeToId.find(node);
- Assert (it != d_nodeToId.end());
- return it->second;
-}
void Slicer::registerEquality(TNode eq) {
if (d_explanationToId.find(eq) == d_explanationToId.end()) {
ExplanationId id = d_explanations.size();
d_explanations.push_back(eq);
- d_explanationToId[eq] = id;
+ d_explanationToId[eq] = id;
+ Debug("bv-slicer-explanation") << "Slicer::registerEquality " << eq << " => id"<< id << "\n";
}
}
-void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation) {
+void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation) {
Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl;
Index high = utils::getSize(node) - 1;
@@ -672,13 +761,18 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::ve
Node current = getNode(nf.decomp[i]);
decomp.push_back(current);
}
-
-
- Debug("bv-slicer") << "as [";
- for (unsigned i = 0; i < decomp.size(); ++i) {
- Debug("bv-slicer") << decomp[i] <<" ";
+ if (Debug.isOn("bv-slicer-explanation")) {
+ Debug("bv-slicer-explanation") << "Slicer::getBaseDecomposition for " << node << "\n"
+ << "as ";
+ for (unsigned i = 0; i < decomp.size(); ++i) {
+ Debug("bv-slicer-explanation") << decomp[i] <<" " ;
+ }
+ Debug("bv-slicer-explanation") << "\n Explanation : \n";
+ for (unsigned i = 0; i < explanation.size(); ++i) {
+ Debug("bv-slicer-explanation") << " " << explanation[i] << "\n";
+ }
+
}
- Debug("bv-slicer") << "]" << endl;
}
@@ -754,6 +848,10 @@ void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) {
ExtractTerm UnionFind::getExtractTerm(TermId id) const {
+ if (d_topLevelIds.find(id) != d_topLevelIds.end()) {
+ // if it's a top level term so we don't have an extract stored for it
+ return ExtractTerm(id, getBitwidth(id) - 1, 0);
+ }
Assert (isExtractTerm(id));
return (d_idToExtract.find(id))->second;
@@ -763,19 +861,21 @@ bool UnionFind::isExtractTerm(TermId id) const {
return d_idToExtract.find(id) != d_idToExtract.end();
}
-bool Slicer::hasNode(TermId id) const {
+bool Slicer::isTopLevelNode(TermId id) const {
return d_idToNode.find(id) != d_idToNode.end();
}
Node Slicer::getNode(TermId id) const {
- if (hasNode(id)) {
+ if (isTopLevelNode(id)) {
return d_idToNode.find(id)->second;
}
- // otherwise must be an extract
Assert (d_unionFind.isExtractTerm(id));
- ExtractTerm extract = d_unionFind.getExtractTerm(id);
- Assert (hasNode(extract.id));
+ const ExtractTerm& extract = d_unionFind.getExtractTerm(id);
+ Assert (isTopLevelNode(extract.id));
TNode node = d_idToNode.find(extract.id)->second;
+ if (extract.high == utils::getSize(node) -1 && extract.low == 0) {
+ return node;
+ }
Node ex = utils::mkExtract(node, extract.high, extract.low);
return ex;
}
diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h
index ab2d5e88f..c46ef99ed 100644
--- a/src/theory/bv/slicer.h
+++ b/src/theory/bv/slicer.h
@@ -161,13 +161,13 @@ class UnionFind : public context::ContextNotifyObj {
{}
};
- class Node {
+ class EqualityNode {
Index d_bitwidth;
TermId d_ch1, d_ch0; // the ids of the two children if they exist
ReprEdge d_edge; // points to the representative and stores the explanation
public:
- Node(Index b)
+ EqualityNode(Index b)
: d_bitwidth(b),
d_ch1(UndefinedId),
d_ch0(UndefinedId),
@@ -189,54 +189,36 @@ class UnionFind : public context::ContextNotifyObj {
d_edge.reason = reason;
}
void setChildren(TermId ch1, TermId ch0) {
- // Assert (d_repr == UndefinedId && !hasChildren());
d_ch1 = ch1;
d_ch0 = ch0;
}
std::string debugPrint() const;
};
+
+ // the equality nodes in the union find
+ std::vector<EqualityNode> d_equalityNodes;
+
+ /// getter methods for the internal nodes
+ TermId getRepr(TermId id) const;
+ ExplanationId getReason(TermId id) const;
+ TermId getChild(TermId id, Index i) const;
+ Index getCutPoint(TermId id) const;
+ bool hasChildren(TermId id) const;
- /// map from TermId to the nodes that represent them
- std::vector<Node> d_nodes;
+ /// setter methods for the internal nodes
+ void setRepr(TermId id, TermId new_repr, ExplanationId reason);
+ void setChildren(TermId id, TermId ch1, TermId ch0);
+
+ // the mappings between ExtractTerms and ids
__gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> > d_idToExtract;
__gnu_cxx::hash_map<ExtractTerm, TermId, ExtractTermHashFunction > d_extractToId;
+
+ __gnu_cxx::hash_set<TermId> d_topLevelIds;
void getDecomposition(const ExtractTerm& term, Decomposition& decomp);
void getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, std::vector<ExplanationId>& explanation);
void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common);
- /// getter methods for the internal nodes
- TermId getRepr(TermId id) const {
- Assert (id < d_nodes.size());
- return d_nodes[id].getRepr();
- }
- ExplanationId getReason(TermId id) const {
- Assert (id < d_nodes.size());
- return d_nodes[id].getReason();
- }
- TermId getChild(TermId id, Index i) const {
- Assert (id < d_nodes.size());
- return d_nodes[id].getChild(i);
- }
- Index getCutPoint(TermId id) const {
- return getBitwidth(getChild(id, 0));
- }
- bool hasChildren(TermId id) const {
- Assert (id < d_nodes.size());
- return d_nodes[id].hasChildren();
- }
- // TermId getTopLevel(TermId id) const;
- /// setter methods for the internal nodes
- void setRepr(TermId id, TermId new_repr, ExplanationId reason) {
- Assert (id < d_nodes.size());
- d_nodes[id].setRepr(new_repr, reason);
- }
- void setChildren(TermId id, TermId ch1, TermId ch0) {
- Assert ((ch1 == UndefinedId && ch0 == UndefinedId) ||
- (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0)));
- d_nodes[id].setChildren(ch1, ch0);
- }
-
/* Backtracking mechanisms */
enum OperationKind {
@@ -271,36 +253,44 @@ class UnionFind : public context::ContextNotifyObj {
~Statistics();
};
Statistics d_statistics;
- Slicer* d_slicer;
+ Slicer* d_slicer;
+ TermId d_termIdCount;
+
+ TermId mkEqualityNode(Index bitwidth);
+ ExtractTerm mkExtractTerm(TermId id, Index high, Index low);
+ void storeExtractTerm(Index id, const ExtractTerm& term);
+ ExtractTerm getExtractTerm(TermId id) const;
+ bool extractHasId(const ExtractTerm& ex) const { return d_extractToId.find(ex) != d_extractToId.end(); }
+ TermId getExtractId(const ExtractTerm& ex) const {Assert (extractHasId(ex)); return d_extractToId.find(ex)->second; }
+ bool isExtractTerm(TermId id) const;
public:
UnionFind(context::Context* ctx, Slicer* slicer)
: ContextNotifyObj(ctx),
- d_nodes(),
+ d_equalityNodes(),
d_idToExtract(),
- d_extractToId(),
+ d_extractToId(),
+ d_topLevelIds(),
d_undoStack(),
d_undoStackIndex(ctx),
d_statistics(),
- d_slicer(slicer)
+ d_slicer(slicer),
+ d_termIdCount(0)
{}
- TermId addNode(Index bitwidth);
- TermId addExtract(Index topLevel, Index high, Index low);
- ExtractTerm getExtractTerm(TermId id) const;
- bool isExtractTerm(TermId id) const;
-
- void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason);
+ TermId addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low);
+ TermId registerTopLevelTerm(Index bitwidth);
+ void unionTerms(TermId id1, TermId id2, TermId reason);
void merge(TermId t1, TermId t2, TermId reason);
TermId find(TermId t1);
TermId findWithExplanation(TermId id, std::vector<ExplanationId>& explanation);
void split(TermId term, Index i);
void getNormalForm(const ExtractTerm& term, NormalForm& nf);
void getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, std::vector<ExplanationId>& explanation);
- void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2);
- void ensureSlicing(const ExtractTerm& term);
+ void alignSlicings(TermId id1, TermId id2);
+ void ensureSlicing(TermId id);
Index getBitwidth(TermId id) const {
- Assert (id < d_nodes.size());
- return d_nodes[id].getBitwidth();
+ Assert (id < d_equalityNodes.size());
+ return d_equalityNodes[id].getBitwidth();
}
void getBase(TermId id, Base& base, Index offset);
std::string debugPrint(TermId id);
@@ -314,17 +304,19 @@ public:
class CoreSolver;
class Slicer {
- __gnu_cxx::hash_map<TermId, TNode, __gnu_cxx::hash<TermId> > d_idToNode;
- __gnu_cxx::hash_map<Node, TermId, NodeHashFunction> d_nodeToId;
- __gnu_cxx::hash_map<Node, bool, NodeHashFunction> d_coreTermCache;
- __gnu_cxx::hash_map<Node, ExplanationId, NodeHashFunction> d_explanationToId;
- std::vector<Node> d_explanations;
+ __gnu_cxx::hash_map<TermId, TNode> d_idToNode;
+ __gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction> d_nodeToId;
+ __gnu_cxx::hash_map<TNode, bool, TNodeHashFunction> d_coreTermCache;
+ __gnu_cxx::hash_map<TNode, ExplanationId, NodeHashFunction> d_explanationToId;
+ std::vector<TNode> d_explanations;
UnionFind d_unionFind;
context::CDQueue<Node> d_newSplits;
context::CDO<unsigned> d_newSplitsIndex;
CoreSolver* d_coreSolver;
- TermId d_termIdCount;
+ TermId registerTopLevelTerm(TNode node);
+ bool isTopLevelNode(TermId id) const;
+ TermId registerTerm(TNode node);
public:
Slicer(context::Context* ctx, CoreSolver* coreSolver)
: d_idToNode(),
@@ -338,16 +330,15 @@ public:
d_coreSolver(coreSolver)
{}
- void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation);
+ void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation);
void registerEquality(TNode eq);
- ExtractTerm registerTerm(TNode node);
+
void processEquality(TNode eq);
void assertEquality(TNode eq);
bool isCoreTerm (TNode node);
bool hasNode(TermId id) const;
Node getNode(TermId id) const;
- TermId getId(TNode node) const;
bool hasExplanation(ExplanationId id) const;
TNode getExplanation(ExplanationId id) const;
diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h
index e5a7bbb84..98bc8041d 100644
--- a/src/theory/bv/theory_bv_utils.h
+++ b/src/theory/bv/theory_bv_utils.h
@@ -69,28 +69,6 @@ inline Node mkVar(unsigned size) {
return nm->mkSkolem("bv_$$", nm->mkBitVectorType(size), "is a variable created by the theory of bitvectors");
}
-inline Node mkAnd(std::vector<TNode>& children) {
- std::set<TNode> distinctChildren;
- distinctChildren.insert(children.begin(), children.end());
-
- if (distinctChildren.size() == 0) {
- return mkTrue();
- }
-
- if (distinctChildren.size() == 1) {
- return *children.begin();
- }
-
- NodeBuilder<> conjunction(kind::AND);
- std::set<TNode>::const_iterator it = distinctChildren.begin();
- std::set<TNode>::const_iterator it_end = distinctChildren.end();
- while (it != it_end) {
- conjunction << *it;
- ++ it;
- }
-
- return conjunction;
-}
inline Node mkSortedNode(Kind kind, std::vector<Node>& children) {
Assert (kind == kind::BITVECTOR_AND ||
@@ -155,14 +133,6 @@ inline Node mkXor(TNode node1, TNode node2) {
}
-inline Node mkAnd(std::vector<Node>& children) {
- if(children.size() > 1) {
- return NodeManager::currentNM()->mkNode(kind::AND, children);
- } else {
- return children[0];
- }
-}
-
inline Node mkExtract(TNode node, unsigned high, unsigned low) {
Node extractOp = NodeManager::currentNM()->mkConst<BitVectorExtract>(BitVectorExtract(high, low));
std::vector<Node> children;
@@ -268,7 +238,6 @@ inline Node mkConjunction(const std::set<TNode> nodes) {
return conjunction;
}
-
inline unsigned isPow2Const(TNode node) {
if (node.getKind() != kind::CONST_BITVECTOR) {
return false;
@@ -278,6 +247,83 @@ inline unsigned isPow2Const(TNode node) {
return bv.isPow2();
}
+typedef __gnu_cxx::hash_set<TNode, TNodeHashFunction> TNodeSet;
+
+inline Node mkAnd(const std::vector<TNode>& conjunctions) {
+ std::set<TNode> all;
+ all.insert(conjunctions.begin(), conjunctions.end());
+
+ if (all.size() == 0) {
+ return mkTrue();
+ }
+
+ if (all.size() == 1) {
+ // All the same, or just one
+ return conjunctions[0];
+ }
+
+
+ NodeBuilder<> conjunction(kind::AND);
+ std::set<TNode>::const_iterator it = all.begin();
+ std::set<TNode>::const_iterator it_end = all.end();
+ while (it != it_end) {
+ conjunction << *it;
+ ++ it;
+ }
+
+ return conjunction;
+}/* mkAnd() */
+
+inline Node mkAnd(const std::vector<Node>& conjunctions) {
+ std::set<TNode> all;
+ all.insert(conjunctions.begin(), conjunctions.end());
+
+ if (all.size() == 0) {
+ return mkTrue();
+ }
+
+ if (all.size() == 1) {
+ // All the same, or just one
+ return conjunctions[0];
+ }
+
+
+ NodeBuilder<> conjunction(kind::AND);
+ std::set<TNode>::const_iterator it = all.begin();
+ std::set<TNode>::const_iterator it_end = all.end();
+ while (it != it_end) {
+ conjunction << *it;
+ ++ it;
+ }
+
+ return conjunction;
+}/* mkAnd() */
+
+
+
+inline Node flattenAnd(std::vector<TNode>& queue) {
+ TNodeSet nodes;
+ while(!queue.empty()) {
+ TNode current = queue.back();
+ queue.pop_back();
+ if (current.getKind() == kind::AND) {
+ for (unsigned i = 0; i < current.getNumChildren(); ++i) {
+ if (nodes.count(current[i]) == 0) {
+ queue.push_back(current[i]);
+ }
+ }
+ } else {
+ nodes.insert(current);
+ }
+ }
+ std::vector<TNode> children;
+ for (TNodeSet::const_iterator it = nodes.begin(); it!= nodes.end(); ++it) {
+ children.push_back(*it);
+ }
+ return mkAnd(children);
+}
+
+
// neeed a better name, this is not technically a ground term
inline bool isBVGroundTerm(TNode node) {
if (node.getNumChildren() == 0) {
@@ -356,27 +402,7 @@ inline Node mkConjunction(const std::vector<TNode>& nodes) {
}
-inline Node mkAnd(const std::vector<TNode>& conjunctions) {
- Assert(conjunctions.size() > 0);
-
- std::set<TNode> all;
- all.insert(conjunctions.begin(), conjunctions.end());
- if (all.size() == 1) {
- // All the same, or just one
- return conjunctions[0];
- }
-
- NodeBuilder<> conjunction(kind::AND);
- std::set<TNode>::const_iterator it = all.begin();
- std::set<TNode>::const_iterator it_end = all.end();
- while (it != it_end) {
- conjunction << *it;
- ++ it;
- }
-
- return conjunction;
-}/* mkAnd() */
// Turn a set into a string
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback