diff options
Diffstat (limited to 'src/theory/bv/bv_inequality_graph.h')
-rw-r--r-- | src/theory/bv/bv_inequality_graph.h | 39 |
1 files changed, 30 insertions, 9 deletions
diff --git a/src/theory/bv/bv_inequality_graph.h b/src/theory/bv/bv_inequality_graph.h index 2c7d3f8a3..1a4b14ace 100644 --- a/src/theory/bv/bv_inequality_graph.h +++ b/src/theory/bv/bv_inequality_graph.h @@ -23,6 +23,7 @@ #include "context/cdqueue.h" #include "theory/uf/equality_engine.h" #include "theory/theory.h" +#include <queue> namespace CVC4 { namespace theory { @@ -32,7 +33,8 @@ namespace bv { typedef unsigned TermId; typedef unsigned ReasonId; - +extern const TermId UndefinedTermId; +extern const ReasonId UndefinedReasonId; class InequalityGraph { @@ -54,16 +56,17 @@ class InequalityGraph { bool d_isConstant; BitVector d_value; public: - InequalityNode(TermId id, unsigned bitwidth, bool isConst = false) + InequalityNode(TermId id, unsigned bitwidth, bool isConst, BitVector val) : d_id(id), d_bitwidth(bitwidth), d_isConstant(isConst), - d_value(BitVector(bitwidth, 0u)) + d_value(val) {} TermId getId() const { return d_id; } unsigned getBitwidth() const { return d_bitwidth; } bool isConstant() const { return d_isConstant; } const BitVector& getValue() const { return d_value; } + void setValue(const BitVector& val) { Assert (val.getSize() == d_bitwidth); d_value = val; } }; struct PQueueElement { @@ -84,7 +87,8 @@ class InequalityGraph { typedef std::vector<InequalityEdge> Edges; typedef __gnu_cxx::hash_set<TermId> TermIdSet; - typedef std::queue<PQueueElement> BFSQueue; + typedef std::queue<TermId> TermIdQueue; + typedef std::priority_queue<PQueueElement> BFSQueue; std::vector<InequalityNode> d_ineqNodes; @@ -100,19 +104,36 @@ class InequalityGraph { TermId registerTerm(TNode term); ReasonId registerReason(TNode reason); TNode getReason(ReasonId id) const; + TermId getReasonId(TermId a, TermId b); TNode getTerm(TermId id) const; Edges& getOutEdges(TermId id) { Assert (id < d_ineqEdges.size()); return d_ineqEdges[id]; } Edges& getInEdges(TermId id) { Assert (id < d_parentEdges.size()); return d_parentEdges[id]; } InequalityNode& getInequalityNode(TermId id) { Assert (id < d_ineqNodes.size()); return d_ineqNodes[id]; } - const BitVector& getValue(TermId id) const { return getInequalityNode().getValue(); } + const InequalityNode& getInequalityNode(TermId id) const { Assert (id < d_ineqNodes.size()); return d_ineqNodes[id]; } + + const BitVector& getValue(TermId id) const { return getInequalityNode(id).getValue(); } + unsigned getBitwidth(TermId id) const { return getInequalityNode(id).getBitwidth(); } + + bool hasValue(TermId id) const; + bool initializeValues(TNode a, TNode b, TermId reason_id); + TermId getMaxParent(TermId id); + bool hasParents(TermId id); + + bool canReach(TermId from, TermId to); + void bfs(TermIdSet& seen, TermIdQueue& queue); + + void addEdge(TermId a, TermId b, TermId reason); bool addInequalityInternal(TermId a, TermId b, ReasonId reason); bool areLessThanInternal(TermId a, TermId b); void getConflictInternal(std::vector<ReasonId>& conflict); - + bool computeValuesBFS(BFSQueue& queue, TermIdSet& seen); + void computeExplanation(TermId from, TermId to, std::vector<ReasonId>& explanation); + context::CDO<bool> d_inConflict; - context::CDList<TNode> d_conflict; - void setConflict(const std::vector<ReasonId>& conflict); + std::vector<TNode> d_conflict; + void setConflict(const std::vector<ReasonId>& conflict); + public: InequalityGraph(context::Context* c) @@ -121,7 +142,7 @@ public: d_ineqEdges(), d_parentEdges(), d_inConflict(c, false), - d_conflict(c) + d_conflict() {} bool addInequality(TNode a, TNode b, TNode reason); bool areLessThan(TNode a, TNode b); |