diff options
Diffstat (limited to 'src/theory/bv/bv_inequality_graph.h')
-rw-r--r-- | src/theory/bv/bv_inequality_graph.h | 82 |
1 files changed, 48 insertions, 34 deletions
diff --git a/src/theory/bv/bv_inequality_graph.h b/src/theory/bv/bv_inequality_graph.h index 6eb88ec79..57e59f6f5 100644 --- a/src/theory/bv/bv_inequality_graph.h +++ b/src/theory/bv/bv_inequality_graph.h @@ -24,7 +24,7 @@ #include "theory/uf/equality_engine.h" #include "theory/theory.h" #include <queue> - +#include <list> namespace CVC4 { namespace theory { @@ -37,7 +37,7 @@ extern const TermId UndefinedTermId; extern const ReasonId UndefinedReasonId; extern const ReasonId AxiomReasonId; -class InequalityGraph { +class InequalityGraph : public context::ContextNotifyObj{ context::Context* d_context; @@ -51,55 +51,55 @@ class InequalityGraph { reason(r), strict(s) {} + bool operator==(const InequalityEdge& other) const { + return next == other.next && reason == other.reason && strict == other.strict; + } }; class InequalityNode { TermId d_id; unsigned d_bitwidth; bool d_isConstant; - BitVector d_value; public: - InequalityNode(TermId id, unsigned bitwidth, bool isConst, const BitVector val) + InequalityNode(TermId id, unsigned bitwidth, bool isConst) : d_id(id), d_bitwidth(bitwidth), - d_isConstant(isConst), - d_value(val) {} + d_isConstant(isConst) + {} TermId getId() const { return d_id; } unsigned getBitwidth() const { return d_bitwidth; } bool isConstant() const { return d_isConstant; } - const BitVector& getValue() const { return d_value; } - void setValue(const BitVector& val) { Assert (val.getSize() == d_bitwidth); d_value = val; } }; - struct Explanation { + struct ModelValue { TermId parent; ReasonId reason; - - Explanation() + BitVector value; + ModelValue() : parent(UndefinedTermId), - reason(UndefinedReasonId) + reason(UndefinedReasonId), + value(0, 0u) {} - Explanation(TermId p, ReasonId r) + ModelValue(const BitVector& val, TermId p, ReasonId r) : parent(p), - reason(r) + reason(r), + value(val) {} }; struct PQueueElement { TermId id; - BitVector value; BitVector lower_bound; - Explanation explanation; - PQueueElement(TermId id, const BitVector v, const BitVector& lb, Explanation exp) + ModelValue model_value; + PQueueElement(TermId id, const BitVector& lb, const ModelValue& mv) : id(id), - value(v), lower_bound(lb), - explanation(exp) + model_value(mv) {} bool operator< (const PQueueElement& other) const { - return value > other.value; + return model_value.value > other.model_value.value; } std::string toString() const; }; @@ -111,7 +111,6 @@ class InequalityGraph { typedef __gnu_cxx::hash_set<TermId> TermIdSet; typedef std::priority_queue<PQueueElement> BFSQueue; - typedef __gnu_cxx::hash_map<TermId, Explanation> TermIdToExplanationMap; std::vector<InequalityNode> d_ineqNodes; std::vector< Edges > d_ineqEdges; @@ -122,13 +121,16 @@ class InequalityGraph { std::vector<Node> d_termNodes; TermNodeToIdMap d_termNodeToIdMap; - TermIdToExplanationMap d_termToExplanation; - context::CDO<bool> d_inConflict; std::vector<TNode> d_conflict; bool d_signed; - + context::CDHashMap<TermId, ModelValue> d_modelValues; + void setModelValue(TermId term, const ModelValue& mv); + ModelValue getModelValue(TermId term) const; + bool hasModelValue(TermId id) const; + bool hasReason(TermId id) const; + /** * Registers the term by creating its corresponding InequalityNode * and adding the min <= term <= max default edges. @@ -144,21 +146,18 @@ class InequalityGraph { ReasonId registerReason(TNode reason); TNode getReasonNode(ReasonId id) const; - bool hasExplanation(TermId id) const { return d_termToExplanation.find(id) != d_termToExplanation.end(); } - Explanation getExplanation(TermId id) const { Assert (hasExplanation(id)); return d_termToExplanation.find(id)->second; } - void setExplanation(TermId id, Explanation exp) { d_termToExplanation[id] = exp; } Edges& getEdges(TermId id) { Assert (id < d_ineqEdges.size()); return d_ineqEdges[id]; } InequalityNode& getInequalityNode(TermId id) { Assert (id < d_ineqNodes.size()); return d_ineqNodes[id]; } const InequalityNode& getInequalityNode(TermId id) const { Assert (id < d_ineqNodes.size()); return d_ineqNodes[id]; } unsigned getBitwidth(TermId id) const { return getInequalityNode(id).getBitwidth(); } bool isConst(TermId id) const { return getInequalityNode(id).isConstant(); } - BitVector maxValue(unsigned bitwidth); - BitVector minValue(unsigned bitwidth); - TermId getMaxValueId(unsigned bitwidth); - TermId getMinValueId(unsigned bitwidth); + // BitVector maxValue(unsigned bitwidth); + // BitVector minValue(unsigned bitwidth); + // TermId getMaxValueId(unsigned bitwidth); + // TermId getMinValueId(unsigned bitwidth); - const BitVector& getValue(TermId id) const { return getInequalityNode(id).getValue(); } + BitVector getValue(TermId id) const; void addEdge(TermId a, TermId b, bool strict, TermId reason); @@ -192,16 +191,30 @@ class InequalityGraph { * @param explanation */ void computeExplanation(TermId from, TermId to, std::vector<ReasonId>& explanation); + + /** Backtracking mechanisms **/ + std::vector<std::pair<TermId, InequalityEdge> > d_undoStack; + context::CDO<unsigned> d_undoStackIndex; + void contextNotifyPop() { + backtrack(); + } + + void backtrack(); + public: InequalityGraph(context::Context* c, bool s = false) - : d_context(c), + : ContextNotifyObj(c), + d_context(c), d_ineqNodes(), d_ineqEdges(), d_inConflict(c, false), d_conflict(), - d_signed(s) + d_signed(s), + d_modelValues(c), + d_undoStack(), + d_undoStackIndex(c) {} /** * @@ -216,6 +229,7 @@ public: bool addInequality(TNode a, TNode b, bool strict, TNode reason); bool areLessThan(TNode a, TNode b); void getConflict(std::vector<TNode>& conflict); + virtual ~InequalityGraph() throw(AssertionException) {} }; } |