summaryrefslogtreecommitdiff
path: root/src/theory
diff options
context:
space:
mode:
authorLiana Hadarean <lianahady@gmail.com>2013-03-24 23:38:33 -0400
committerLiana Hadarean <lianahady@gmail.com>2013-03-24 23:38:33 -0400
commit147f93cc140b1cf2a5957cbe95eccfc92e4d90b0 (patch)
tree985ec319875036d27079763865a4d15cc29018f0 /src/theory
parentab19f7ee3cd09d9e9bbf3a75f54989e132442ccf (diff)
added support for disequalities in the inequality solver
Diffstat (limited to 'src/theory')
-rw-r--r--src/theory/bv/bv_inequality_graph.cpp99
-rw-r--r--src/theory/bv/bv_inequality_graph.h29
-rw-r--r--src/theory/bv/bv_subtheory_inequality.cpp12
-rw-r--r--src/theory/bv/theory_bv.h2
4 files changed, 125 insertions, 17 deletions
diff --git a/src/theory/bv/bv_inequality_graph.cpp b/src/theory/bv/bv_inequality_graph.cpp
index 704f99039..4bd315872 100644
--- a/src/theory/bv/bv_inequality_graph.cpp
+++ b/src/theory/bv/bv_inequality_graph.cpp
@@ -261,9 +261,27 @@ void InequalityGraph::addEdge(TermId a, TermId b, bool strict, TermId reason) {
d_undoStackIndex = d_undoStackIndex + 1;
}
+void InequalityGraph::initializeModelValue(TNode node) {
+ TermId id = getTermId(node);
+ Assert (!hasModelValue(id));
+ bool isConst = node.getKind() == kind::CONST_BITVECTOR;
+ unsigned size = utils::getSize(node);
+ BitVector value = isConst? node.getConst<BitVector>() : BitVector(size, 0u);
+ setModelValue(id, ModelValue(value, UndefinedTermId, UndefinedReasonId));
+}
+
+bool InequalityGraph::isRegistered(TNode term) const {
+ return d_termNodeToIdMap.find(term) != d_termNodeToIdMap.end();
+}
+
TermId InequalityGraph::registerTerm(TNode term) {
if (d_termNodeToIdMap.find(term) != d_termNodeToIdMap.end()) {
- return d_termNodeToIdMap[term];
+ TermId id = d_termNodeToIdMap[term];
+ if (!hasModelValue(id)) {
+ // we could have backtracked and
+ initializeModelValue(term);
+ }
+ return id;
}
// store in node mapping
@@ -275,21 +293,17 @@ TermId InequalityGraph::registerTerm(TNode term) {
// create InequalityNode
unsigned size = utils::getSize(term);
+
bool isConst = term.getKind() == kind::CONST_BITVECTOR;
- BitVector value = isConst? term.getConst<BitVector>() : BitVector(size, 0u);
-
InequalityNode ineq = InequalityNode(id, size, isConst);
- setModelValue(id, ModelValue(value, UndefinedTermId, UndefinedReasonId));
-
+
Assert (d_ineqNodes.size() == id);
d_ineqNodes.push_back(ineq);
Assert (d_ineqEdges.size() == id);
d_ineqEdges.push_back(Edges());
- // add the default edges min <= term <= max
- // addEdge(getMinValueId(size), id, false, AxiomReasonId);
- // addEdge(id, getMaxValueId(size), false, AxiomReasonId);
+ initializeModelValue(term);
return id;
}
@@ -314,6 +328,11 @@ TNode InequalityGraph::getTermNode(TermId id) const {
return d_termNodes[id];
}
+TermId InequalityGraph::getTermId(TNode node) const {
+ Assert (d_termNodeToIdMap.find(node) != d_termNodeToIdMap.end());
+ return d_termNodeToIdMap.find(node)->second;
+}
+
void InequalityGraph::setConflict(const std::vector<ReasonId>& conflict) {
Assert (!d_inConflict);
d_inConflict = true;
@@ -351,8 +370,8 @@ bool InequalityGraph::hasModelValue(TermId id) const {
}
BitVector InequalityGraph::getValue(TermId id) const {
- Assert (hasModelValue(id));
- BitVector res = (*(d_modelValues.find(id))).second.value;
+ Assert (hasModelValue(id));
+ BitVector res = (*(d_modelValues.find(id))).second.value;
return res;
}
@@ -361,6 +380,66 @@ bool InequalityGraph::hasReason(TermId id) const {
return mv.reason != UndefinedReasonId;
}
+bool InequalityGraph::addDisequality(TNode a, TNode b, TNode reason) {
+ Debug("bv-inequality") << "InequalityGraph::addDisequality " << reason << "\n";
+ d_disequalities.push_back(reason);
+
+ if (!isRegistered(a) || !isRegistered(b)) {
+ splitDisequality(reason);
+ return true;
+ }
+ TermId id_a = getTermId(a);
+ TermId id_b = getTermId(b);
+ if (!hasModelValue(id_a)) {
+ initializeModelValue(a);
+ }
+ if (!hasModelValue(id_b)) {
+ initializeModelValue(b);
+ }
+ const BitVector& val_a = getValue(id_a);
+ const BitVector& val_b = getValue(id_b);
+ if (val_a == val_b) {
+ if (a.getKind() == kind::CONST_BITVECTOR) {
+ // then we know b cannot be smaller than the assigned value so we try to make it larger
+ return addInequality(a, b, true, reason);
+ }
+ if (b.getKind() == kind::CONST_BITVECTOR) {
+ return addInequality(b, a, true, reason);
+ }
+ // if none of the terms are constants just add the lemma
+ splitDisequality(reason);
+ } else {
+ Debug("bv-inequality-internal") << "Disequal: " << a << " => " << val_a.toString(10) << "\n"
+ << " " << b << " => " << val_b.toString(10) << "\n";
+ }
+ return true;
+}
+
+void InequalityGraph::splitDisequality(TNode diseq) {
+ Debug("bv-inequality-internal")<<"InequalityGraph::splitDisequality " << diseq <<"\n";
+ Assert (diseq.getKind() == kind::NOT && diseq[0].getKind() == kind::EQUAL);
+ TNode a = diseq[0][0];
+ TNode b = diseq[0][1];
+ Node a_lt_b = utils::mkNode(kind::BITVECTOR_ULT, a, b);
+ Node b_lt_a = utils::mkNode(kind::BITVECTOR_ULT, b, a);
+ Node split = utils::mkNode(kind::OR, a_lt_b, b_lt_a);
+ Node lemma = utils::mkNode(kind::IMPLIES, diseq, split);
+ if (d_lemmasAdded.find(lemma) == d_lemmasAdded.end()) {
+ d_lemmaQueue.push_back(lemma);
+ }
+}
+
+void InequalityGraph::getNewLemmas(std::vector<TNode>& new_lemmas) {
+ for (unsigned i = d_lemmaIndex; i < d_lemmaQueue.size(); ++i) {
+ TNode lemma = d_lemmaQueue[i];
+ if (d_lemmasAdded.find(lemma) == d_lemmasAdded.end()) {
+ new_lemmas.push_back(lemma);
+ d_lemmasAdded.insert(lemma);
+ }
+ d_lemmaIndex = d_lemmaIndex + 1;
+ }
+}
+
std::string InequalityGraph::PQueueElement::toString() const {
ostringstream os;
os << "(id: " << id << ", lower_bound: " << lower_bound.toString(10) <<", old_value: " << model_value.value.toString(10) << ")";
diff --git a/src/theory/bv/bv_inequality_graph.h b/src/theory/bv/bv_inequality_graph.h
index 57e59f6f5..1335eff93 100644
--- a/src/theory/bv/bv_inequality_graph.h
+++ b/src/theory/bv/bv_inequality_graph.h
@@ -111,7 +111,7 @@ class InequalityGraph : public context::ContextNotifyObj{
typedef __gnu_cxx::hash_set<TermId> TermIdSet;
typedef std::priority_queue<PQueueElement> BFSQueue;
-
+ typedef __gnu_cxx::hash_set<TNode, TNodeHashFunction> TNodeSet;
std::vector<InequalityNode> d_ineqNodes;
std::vector< Edges > d_ineqEdges;
@@ -125,7 +125,8 @@ class InequalityGraph : public context::ContextNotifyObj{
std::vector<TNode> d_conflict;
bool d_signed;
- context::CDHashMap<TermId, ModelValue> d_modelValues;
+ context::CDHashMap<TermId, ModelValue> d_modelValues;
+ void initializeModelValue(TNode node);
void setModelValue(TermId term, const ModelValue& mv);
ModelValue getModelValue(TermId term) const;
bool hasModelValue(TermId id) const;
@@ -142,7 +143,8 @@ class InequalityGraph : public context::ContextNotifyObj{
TermId registerTerm(TNode term);
TNode getTermNode(TermId id) const;
TermId getTermId(TNode node) const;
-
+ bool isRegistered(TNode term) const;
+
ReasonId registerReason(TNode reason);
TNode getReasonNode(ReasonId id) const;
@@ -152,10 +154,6 @@ class InequalityGraph : public context::ContextNotifyObj{
const InequalityNode& getInequalityNode(TermId id) const { Assert (id < d_ineqNodes.size()); return d_ineqNodes[id]; }
unsigned getBitwidth(TermId id) const { return getInequalityNode(id).getBitwidth(); }
bool isConst(TermId id) const { return getInequalityNode(id).isConstant(); }
- // BitVector maxValue(unsigned bitwidth);
- // BitVector minValue(unsigned bitwidth);
- // TermId getMaxValueId(unsigned bitwidth);
- // TermId getMinValueId(unsigned bitwidth);
BitVector getValue(TermId id) const;
@@ -191,7 +189,18 @@ class InequalityGraph : public context::ContextNotifyObj{
* @param explanation
*/
void computeExplanation(TermId from, TermId to, std::vector<ReasonId>& explanation);
+ void splitDisequality(TNode diseq);
+ /**
+ Disequality reasoning
+ */
+
+ /*** The currently asserted disequalities */
+ context::CDQueue<TNode> d_disequalities;
+ context::CDQueue<Node> d_lemmaQueue;
+ context::CDO<unsigned> d_lemmaIndex;
+ TNodeSet d_lemmasAdded;
+
/** Backtracking mechanisms **/
std::vector<std::pair<TermId, InequalityEdge> > d_undoStack;
context::CDO<unsigned> d_undoStackIndex;
@@ -213,6 +222,10 @@ public:
d_conflict(),
d_signed(s),
d_modelValues(c),
+ d_disequalities(c),
+ d_lemmaQueue(c),
+ d_lemmaIndex(c, 0),
+ d_lemmasAdded(),
d_undoStack(),
d_undoStackIndex(c)
{}
@@ -227,9 +240,11 @@ public:
* @return
*/
bool addInequality(TNode a, TNode b, bool strict, TNode reason);
+ bool addDisequality(TNode a, TNode b, TNode reason);
bool areLessThan(TNode a, TNode b);
void getConflict(std::vector<TNode>& conflict);
virtual ~InequalityGraph() throw(AssertionException) {}
+ void getNewLemmas(std::vector<TNode>& new_lemmas);
};
}
diff --git a/src/theory/bv/bv_subtheory_inequality.cpp b/src/theory/bv/bv_subtheory_inequality.cpp
index f856c9410..6b9842e8f 100644
--- a/src/theory/bv/bv_subtheory_inequality.cpp
+++ b/src/theory/bv/bv_subtheory_inequality.cpp
@@ -27,15 +27,21 @@ using namespace CVC4::theory::bv;
using namespace CVC4::theory::bv::utils;
bool InequalitySolver::check(Theory::Effort e) {
+ Debug("bv-subtheory-inequality") << "InequalitySolveR::check("<< e <<")\n";
bool ok = true;
while (!done() && ok) {
TNode fact = get();
+ Debug("bv-subtheory-inequality") << " "<< fact <<"\n";
if (fact.getKind() == kind::EQUAL) {
TNode a = fact[0];
TNode b = fact[1];
ok = d_inequalityGraph.addInequality(a, b, false, fact);
if (ok)
ok = d_inequalityGraph.addInequality(b, a, false, fact);
+ } else if (fact.getKind() == kind::NOT && fact[0].getKind() == kind::EQUAL) {
+ TNode a = fact[0][0];
+ TNode b = fact[0][1];
+ ok = d_inequalityGraph.addDisequality(a, b, fact);
}
if (fact.getKind() == kind::NOT && fact[0].getKind() == kind::BITVECTOR_ULE) {
TNode a = fact[0][1];
@@ -61,6 +67,12 @@ bool InequalitySolver::check(Theory::Effort e) {
d_bv->setConflict(utils::mkConjunction(conflict));
return false;
}
+ // send out any lemmas
+ std::vector<TNode> lemmas;
+ d_inequalityGraph.getNewLemmas(lemmas);
+ for(unsigned i = 0; i < lemmas.size(); ++i) {
+ d_bv->lemma(lemmas[i]);
+ }
return true;
}
diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h
index 13a475d3d..54260deb9 100644
--- a/src/theory/bv/theory_bv.h
+++ b/src/theory/bv/theory_bv.h
@@ -137,6 +137,8 @@ private:
void sendConflict();
+ void lemma(TNode node) { d_out->lemma(node); }
+
friend class Bitblaster;
friend class BitblastSolver;
friend class EqualitySolver;
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback