summaryrefslogtreecommitdiff
path: root/src/theory/bv/bv_inequality_graph.cpp
diff options
context:
space:
mode:
authorLiana Hadarean <lianahady@gmail.com>2013-03-19 21:54:22 -0400
committerLiana Hadarean <lianahady@gmail.com>2013-03-19 21:54:22 -0400
commit4cd63abf2ab901ad8d1b1c2cc2e84707736b5659 (patch)
treeb45789d51329bbfdf0043f9fcb577ea0fb2c38bc /src/theory/bv/bv_inequality_graph.cpp
parentd58d78b3ac3e5abfaa4e01d87bb351c0268239df (diff)
inequality reasoning works on small examples added to regressions (not incremental); currently disabled though
Diffstat (limited to 'src/theory/bv/bv_inequality_graph.cpp')
-rw-r--r--src/theory/bv/bv_inequality_graph.cpp323
1 files changed, 228 insertions, 95 deletions
diff --git a/src/theory/bv/bv_inequality_graph.cpp b/src/theory/bv/bv_inequality_graph.cpp
index 2821fe5e1..7351abe4d 100644
--- a/src/theory/bv/bv_inequality_graph.cpp
+++ b/src/theory/bv/bv_inequality_graph.cpp
@@ -24,37 +24,60 @@ using namespace CVC4::theory;
using namespace CVC4::theory::bv;
using namespace CVC4::theory::bv::utils;
+const TermId CVC4::theory::bv::UndefinedTermId = -1;
+const ReasonId CVC4::theory::bv::UndefinedReasonId = -1;
+
+
bool InequalityGraph::addInequality(TNode a, TNode b, TNode reason) {
+ Debug("bv-inequality") << "InequlityGraph::addInequality " << a << " " << b << "\n";
TermId id_a = registerTerm(a);
TermId id_b = registerTerm(b);
ReasonId id_reason = registerReason(reason);
- return addInequalityInternal(id_a, id_b, id_reason);
-}
-
-bool InequalityGraph::addInequalityInternal(TermId a, TermId b, TermId reason) {
- if (getValue(a) < getValue(b)) {
- // the inequality is true in the current partial model
- return true;
- }
- if (getValue(b) < getValue(a)) {
- // the inequality is false in the current partial model
- std::vector<ReasonId> conflict;
- computeExplanation(b, a, conflict);
- return false;
+ if (hasValue(id_a) && hasValue(id_b)) {
+ if (getValue(id_a) < getValue(id_b)) {
+ // the inequality is true in the current partial model
+ // we still add the edge because it may not be true later (cardinality)
+ addEdge(id_a, id_b, id_reason);
+ return true;
+ }
+ if (canReach(id_b, id_a)) {
+ // we are introducing a cycle; make sure the model reflects this
+ Assert (getValue(id_a) >= getValue(id_b));
+
+ std::vector<ReasonId> conflict;
+ conflict.push_back(id_reason);
+ computeExplanation(id_b, id_a, conflict);
+ setConflict(conflict);
+ return false;
+ }
+ } else {
+ bool ok = initializeValues(a, b, id_reason);
+ if (!ok) {
+ return false;
+ }
}
// the inequality edge does not exist
- addEdge(a, b, reason);
+ addEdge(id_a, id_b, id_reason);
BFSQueue queue;
- queue.push(a);
- return computeValuesBFS(queue);
+ queue.push(PQueueElement(id_a, getValue(id_a)));
+ TermIdSet seen;
+ return computeValuesBFS(queue, seen);
}
-void InequalityGraph::computeConflict(TermId from, TermId to, std::vector<ReasonId>& explanation) {
- if (to == from)
+void InequalityGraph::computeExplanation(TermId from, TermId to, std::vector<ReasonId>& explanation) {
+ if (to == from || (from == UndefinedTermId && getInequalityNode(to).isConstant())) {
+ // we have explained the whole path or reached a constant that forced the value of to
return;
+ }
+
const Edges& edges = getInEdges(to);
- BitVector max(getBitwidth(a), 0);
+ if (edges.size() == 0) {
+ // this can happen when from is Undefined
+ Assert (from == UndefinedTermId);
+ return;
+ }
+ BitVector max(getBitwidth(to), 0u);
TermId to_visit = UndefinedTermId;
ReasonId reason = UndefinedReasonId;
@@ -65,61 +88,114 @@ void InequalityGraph::computeConflict(TermId from, TermId to, std::vector<Reason
return;
}
if (getValue(next) >= max) {
- max = it->value;
+ max = getValue(next);
to_visit = it->next;
reason = it->reason;
}
}
Assert(reason != UndefinedReasonId && to_visit != UndefinedTermId);
explanation.push_back(reason);
- computeConflict(from, to_visit, explanation);
+ computeExplanation(from, to_visit, explanation);
}
void InequalityGraph::addEdge(TermId a, TermId b, TermId reason) {
- Edges& out_edges = getEdges(a);
- edges.push_back(InequalityEdge(b, reason));
- Edges& in_edges = getParentEdges(b);
- edges.push_back(InequalityEdge(a, reason));
+ Edges& out_edges = getOutEdges(a);
+ out_edges.push_back(InequalityEdge(b, reason));
+ Edges& in_edges = getInEdges(b);
+ in_edges.push_back(InequalityEdge(a, reason));
+}
+
+TermId InequalityGraph::getMaxParent(TermId id) {
+ const Edges& in_edges = getInEdges(id);
+ Assert (in_edges.size() != 0);
+
+ BitVector max(getBitwidth(0), 0u);
+ TermId max_id = UndefinedTermId;
+ for (Edges::const_iterator it = in_edges.begin(); it!= in_edges.end(); ++it) {
+ // Assert (seen.count(it->next));
+ const BitVector& value = getInequalityNode(it->next).getValue();
+ if (value >= max) {
+ max = value;
+ max_id = it->next;
+ }
+ }
+ Assert (max_id != UndefinedTermId);
+ return max_id;
}
-bool InequalityGraph::computeValuesBFS(BitVector& min_val, BFSQueue& queue, TermIdSet& seen) {
+bool InequalityGraph::hasParents(TermId id) {
+ return getInEdges(id).size() != 0;
+}
+
+TermId InequalityGraph::getReasonId(TermId a, TermId b) {
+ const Edges& edges = getOutEdges(a);
+ for (Edges::const_iterator it = edges.begin(); it!= edges.end(); ++it) {
+ if (it->next == b) {
+ return it->reason;
+ }
+ }
+ Unreachable();
+}
+
+bool InequalityGraph::computeValuesBFS(BFSQueue& queue, TermIdSet& seen) {
if (queue.empty())
return true;
-
+
TermId current = queue.top().id;
seen.insert(current);
queue.pop();
InequalityNode& ineqNode = getInequalityNode(current);
+ Debug("bv-inequality-internal") << "InequalityGraph::computeValueBFS \n";
+ Debug("bv-inequality-internal") << " processing " << getTerm(current) << "\n"
+ << " old value " << ineqNode.getValue() << "\n";
+ BitVector zero(getBitwidth(current), 0u);
+ BitVector one(getBitwidth(current), 1u);
+ const BitVector min_val = hasParents(current) ? getInequalityNode(getMaxParent(current)).getValue() + one : zero;
+ Debug("bv-inequality-internal") << " min value " << min_val << "\n";
+
if (ineqNode.isConstant()) {
if (ineqNode.getValue() < min_val) {
- // we have a conflict
+ Debug("bv-inequality") << "Conflict: constant " << ineqNode.getValue() << "\n";
+ std::vector<ReasonId> conflict;
+ TermId max_parent = getMaxParent(current);
+ ReasonId reason_id = getReasonId(max_parent, current);
+ conflict.push_back(reason_id);
+ computeExplanation(UndefinedTermId, max_parent, conflict);
+ setConflict(conflict);
return false;
}
} else {
// if not constant we can update the value
if (ineqNode.getValue() < min_val) {
+ Debug("bv-inequality-internal") << "Updating " << getTerm(current) <<
+ " from " << ineqNode.getValue() << "\n" <<
+ " to " << min_val << "\n";
ineqNode.setValue(min_val);
}
}
- BitVector next_min = ineqNode.getValue() + 1;
+ unsigned bitwidth = min_val.getSize();
+ BitVector next_min = ineqNode.getValue() + BitVector(bitwidth, 1u);
bool overflow = next_min < min_val;
- const Edges& edges = getEdges(current);
+ const Edges& edges = getOutEdges(current);
if (edges.size() > 0 && overflow) {
// we have reached the maximum value
- computeConflict();
+ Debug("bv-inequality") << "Conflict: overflow: " << getTerm(current) << "\n";
+ std::vector<ReasonId> conflict;
+ computeExplanation(UndefinedTermId, current, conflict);
+ setConflict(conflict);
return false;
}
- // TODO: update key, maybe
+
for (Edges::const_iterator it = edges.begin(); it!= edges.end(); ++it) {
TermId next = it->next;
- if (!seen.contains(next)) {
- BitVector& value = getInequalityNode(next).getValue();
+ if (!seen.count(next)) {
+ const BitVector& value = getInequalityNode(next).getValue();
queue.push(PQueueElement(next, value));
}
}
- return computeValuesBFS(next_min, queue, seen);
+ return computeValuesBFS(queue, seen);
}
@@ -127,7 +203,7 @@ bool InequalityGraph::areLessThanInternal(TermId a, TermId b) {
return getValue(a) < getValue(b);
}
-TermId InequalitySolver::registerTerm(TNode term) {
+TermId InequalityGraph::registerTerm(TNode term) {
if (d_termNodeToIdMap.find(term) != d_termNodeToIdMap.end()) {
return d_termNodeToIdMap[term];
}
@@ -139,7 +215,7 @@ TermId InequalitySolver::registerTerm(TNode term) {
// create InequalityNode
bool isConst = term.getKind() == kind::CONST_BITVECTOR;
- BitVector value = isConst? term.getConst<BitVector>() : BitVector(utils::getSize(term),0);
+ BitVector value(0,0u); // leaves the value unintialized at this time
InequalityNode ineq = InequalityNode(id, utils::getSize(term), isConst, value);
Assert (d_ineqNodes.size() == id);
d_ineqNodes.push_back(ineq);
@@ -147,10 +223,12 @@ TermId InequalitySolver::registerTerm(TNode term) {
d_ineqEdges.push_back(Edges());
Assert(d_parentEdges.size() == id);
d_parentEdges.push_back(Edges());
+ Debug("bv-inequality-internal") << "InequalityGraph::registerTerm " << term << "\n"
+ << "with id " << id << "\n";
return id;
}
-ReasonId InequalitySolver::registerReason(TNode reason) {
+ReasonId InequalityGraph::registerReason(TNode reason) {
if (d_reasonToIdMap.find(reason) != d_reasonToIdMap.end()) {
return d_reasonToIdMap[reason];
}
@@ -160,58 +238,68 @@ ReasonId InequalitySolver::registerReason(TNode reason) {
return id;
}
-TNode InequalitySolver::getReason(ReasonId id) const {
+TNode InequalityGraph::getReason(ReasonId id) const {
Assert (d_reasonNodes.size() > id);
return d_reasonNodes[id];
}
-TNode InequalitySolver::getTerm(TermId id) const {
+TNode InequalityGraph::getTerm(TermId id) const {
Assert (d_termNodes.size() > id);
return d_termNodes[id];
}
-void InequalitySolver::setConflict(const std::vector<ReasonId>& conflict) {
+void InequalityGraph::setConflict(const std::vector<ReasonId>& conflict) {
Assert (!d_inConflict);
d_inConflict = true;
d_conflict.clear();
for (unsigned i = 0; i < conflict.size(); ++i) {
d_conflict.push_back(getReason(conflict[i]));
}
+ if (Debug.isOn("bv-inequality")) {
+ Debug("bv-inequality") << "InequalityGraph::setConflict \n";
+ for (unsigned i = 0; i < d_conflict.size(); ++i) {
+ Debug("bv-inequality") << " " << d_conflict[i] <<"\n";
+ }
+ }
}
-void InequalitySolver::getConflict(std::vector<TNode>& conflict) {
- for (unsigned i = 0; i < d_conflict.size(); ++it) {
+void InequalityGraph::getConflict(std::vector<TNode>& conflict) {
+ for (unsigned i = 0; i < d_conflict.size(); ++i) {
conflict.push_back(d_conflict[i]);
}
}
-// bool InequalityGraph::canReach(TermId from, TermId to) {
-// TermIdSet visited;
-// bfs(start, seen);
-// if (seen.constains(to)) {
-// return true;
-// }
-// }
+bool InequalityGraph::canReach(TermId from, TermId to) {
+ if (from == to )
+ return true;
+
+ TermIdSet seen;
+ TermIdQueue queue;
+ queue.push(from);
+ bfs(seen, queue);
+ if (seen.count(to)) {
+ return true;
+ }
+ return false;
+}
-// bool InequalityGraph::bfs(TermId to, TermIdSet& seen, TermIdQueue& queue) {
-// if (queue.empty())
-// return;
+void InequalityGraph::bfs(TermIdSet& seen, TermIdQueue& queue) {
+ if (queue.empty())
+ return;
-// TermId current = queue.front();
-// queue.pop();
-// if (current = to) {
-// return true;
-// }
-// const Edges& edges = getEdges(current);
-// for (Edges::const_iterator it = edges.begin(); it!= edges.end(); ++it) {
-// TermId next = it->next;
-// if(!seen.contains(next)) {
-// seen.insert(next);
-// queue.push(next);
-// }
-// }
-// return bfs(seen, queue);
-// }
+ TermId current = queue.front();
+ queue.pop();
+
+ const Edges& edges = getOutEdges(current);
+ for (Edges::const_iterator it = edges.begin(); it!= edges.end(); ++it) {
+ TermId next = it->next;
+ if(seen.count(next) == 0) {
+ seen.insert(next);
+ queue.push(next);
+ }
+ }
+ bfs(seen, queue);
+}
// void InequalityGraph::getPath(TermId to, TermId from, const TermIdSet& seen, std::vector<ReasonId> explanation) {
// // traverse parent edges
@@ -225,30 +313,75 @@ void InequalitySolver::getConflict(std::vector<TNode>& conflict) {
// }
// }
-// bool InequalityGraph::initializeValues(TNode a, TNode b) {
-// TermId id_a = registerTerm(a);
-// TermId id_b = registerTerm(b);
-// if (!hasValue(id_a) && !hasValue(id_b)) {
-// InequalityNode& ineq_a = getInequalityNode(id_a);
-// ineq_a.setValue(BiVector(utils::getSize(a), 0));
-// InequalityNode& ineq_b = getInequalityNode(id_b);
-// ineq_a.setValue(BiVector(utils::getSize(a), 1));
-// }
-// if (!hasValue(id_a) && hasValue(id_b)) {
-// BitVector& b_value = getValue(id_b);
-// if (b_value == 0) {
-// return false;
-// }
-// InequalityNode& ineq_a = getInequalityNode(id_a);
-// ineq_a.setValue(b_value - 1);
-// }
-// if (hasValue(id_a) && !hasValue(id_b)) {
-// BitVector& a_value = getValue(id_a);
-// if (a_value + 1 < a_value) {
-// return false;
-// }
-// InequalityNode& ineq_b = getInequalityNode(id_b);
-// ineq_b.setValue(a_value + 1);
-// }
-// return true;
-// }
+bool InequalityGraph::hasValue(TermId id) const {
+ return getInequalityNode(id).getValue() != BitVector(0, 0u);
+}
+
+bool InequalityGraph::initializeValues(TNode a, TNode b, TermId reason_id) {
+ TermId id_a = registerTerm(a);
+ TermId id_b = registerTerm(b);
+
+ InequalityNode& ineq_a = getInequalityNode(id_a);
+ InequalityNode& ineq_b = getInequalityNode(id_b);
+ // FIXME: dumb case splitting
+ if (ineq_a.isConstant() && ineq_b.isConstant()) {
+ Assert (a.getConst<BitVector>() < b.getConst<BitVector>());
+ ineq_a.setValue(a.getConst<BitVector>());
+ ineq_b.setValue(b.getConst<BitVector>());
+ return true;
+ }
+
+ if (ineq_a.isConstant()) {
+ ineq_a.setValue(a.getConst<BitVector>());
+ }
+ if (ineq_b.isConstant()) {
+ const BitVector& const_val = b.getConst<BitVector>();
+ ineq_b.setValue(const_val);
+ // check for potential underflow
+ if (hasValue(id_a) && ineq_a.getValue() > const_val) {
+ // must be a conflict because we have as an invariant that a will have the min
+ // possible value for a.
+ std::vector<ReasonId> conflict;
+ conflict.push_back(reason_id);
+ // FIXME: this will not compute the most precise conflict
+ // could be fixed by giving computeExplanation a bound (i.e. the size of const_val)
+ computeExplanation(UndefinedTermId, id_a, conflict);
+ setConflict(conflict);
+ return false;
+ }
+ }
+
+ BitVector one(getBitwidth(id_a), 1u);
+ BitVector zero(getBitwidth(id_a), 0u);
+
+ if (!hasValue(id_a) && !hasValue(id_b)) {
+ // initialize to the minimum possible values
+ ineq_a.setValue(zero);
+ ineq_b.setValue(one);
+ }
+ else if (!hasValue(id_a) && hasValue(id_b)) {
+ const BitVector& b_value = ineq_b.getValue();
+ if (b_value == zero) {
+ if (ineq_b.isConstant()) {
+ Debug("bv-inequality") << "Conflict: underflow " << getTerm(id_a) <<"\n";
+ std::vector<ReasonId> conflict;
+ conflict.push_back(reason_id);
+ setConflict(conflict);
+ return false;
+ }
+ // otherwise we attempt to increment b
+ ineq_b.setValue(one);
+ }
+ // if a has no value then we can assign it to whatever we want
+ // to maintain the invariant that each value has the lowest value
+ // we assign it to zero
+ ineq_a.setValue(zero);
+ } else if (hasValue(id_a) && !hasValue(id_b)) {
+ const BitVector& a_value = ineq_a.getValue();
+ if (a_value + one < a_value) {
+ return false;
+ }
+ ineq_b.setValue(a_value + one);
+ }
+ return true;
+}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback