diff options
Diffstat (limited to 'src/theory/bv/slicer.h')
-rw-r--r-- | src/theory/bv/slicer.h | 236 |
1 files changed, 66 insertions, 170 deletions
diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h index c46ef99ed..88254b983 100644 --- a/src/theory/bv/slicer.h +++ b/src/theory/bv/slicer.h @@ -29,10 +29,6 @@ #include "util/index.h" #include "expr/node.h" #include "theory/bv/theory_bv_utils.h" -#include "context/context.h" -#include "context/cdhashset.h" -#include "context/cdo.h" -#include "context/cdqueue.h" #ifndef __CVC4__THEORY__BV__SLICER_BV_H #define __CVC4__THEORY__BV__SLICER_BV_H @@ -45,10 +41,8 @@ namespace bv { typedef Index TermId; -typedef TermId ExplanationId; extern const TermId UndefinedId; -class CDBase; /** * Base @@ -57,11 +51,9 @@ class CDBase; class Base { Index d_size; std::vector<uint32_t> d_repr; - void undoSliceAt(Index index); public: - Base (Index size); - void sliceAt(Index index); - + Base(Index size); + void sliceAt(Index index); void sliceWith(const Base& other); bool isCutPoint(Index index) const; void diffCutPoints(const Base& other, Base& res) const; @@ -84,23 +76,17 @@ public: } }; - /** * UnionFind * */ -typedef context::CDHashSet<uint32_t, std::hash<uint32_t> > CDTermSet; +typedef __gnu_cxx::hash_set<TermId> TermSet; typedef std::vector<TermId> Decomposition; struct ExtractTerm { TermId id; Index high; Index low; - ExtractTerm() - : id (UndefinedId), - high(UndefinedId), - low(UndefinedId) - {} ExtractTerm(TermId i, Index h, Index l) : id (i), high(h), @@ -108,24 +94,10 @@ struct ExtractTerm { { Assert (h >= l && id != UndefinedId); } - bool operator== (const ExtractTerm& other) const { - return id == other.id && high == other.high && low == other.low; - } Index getBitwidth() const { return high - low + 1; } std::string debugPrint() const; - friend class ExtractTermHashFunction; }; -struct ExtractTermHashFunction { - ::std::size_t operator() (const ExtractTerm& t) const { - __gnu_cxx::hash<unsigned> h; - unsigned id = t.id; - unsigned high = t.high; - unsigned low = t.low; - return (h(id) * 7919 + h(high))* 4391 + h(low); - } -}; - class UnionFind; struct NormalForm { @@ -148,34 +120,21 @@ struct NormalForm { void clear() { base.clear(); decomp.clear(); } }; -class Slicer; - -class UnionFind : public context::ContextNotifyObj { - struct ReprEdge { - TermId repr; - ExplanationId reason; - ReprEdge() - : repr(UndefinedId), - reason(UndefinedId) - {} - }; - - class EqualityNode { - Index d_bitwidth; - TermId d_ch1, d_ch0; // the ids of the two children if they exist - ReprEdge d_edge; // points to the representative and stores the explanation - +class UnionFind { + class Node { + Index d_bitwidth; + TermId d_ch1, d_ch0; + TermId d_repr; public: - EqualityNode(Index b) + Node(Index b) : d_bitwidth(b), d_ch1(UndefinedId), d_ch0(UndefinedId), - d_edge() + d_repr(UndefinedId) {} - - TermId getRepr() const { return d_edge.repr; } - ExplanationId getReason() const { return d_edge.reason; } + + TermId getRepr() const { return d_repr; } Index getBitwidth() const { return d_bitwidth; } bool hasChildren() const { return d_ch1 != UndefinedId && d_ch0 != UndefinedId; } @@ -183,64 +142,51 @@ class UnionFind : public context::ContextNotifyObj { Assert (i < 2); return i == 0? d_ch0 : d_ch1; } - void setRepr(TermId repr, ExplanationId reason) { + void setRepr(TermId id) { Assert (! hasChildren()); - d_edge.repr = repr; - d_edge.reason = reason; + d_repr = id; } void setChildren(TermId ch1, TermId ch0) { + Assert (d_repr == UndefinedId && !hasChildren()); d_ch1 = ch1; d_ch0 = ch0; } std::string debugPrint() const; }; - - // the equality nodes in the union find - std::vector<EqualityNode> d_equalityNodes; - - /// getter methods for the internal nodes - TermId getRepr(TermId id) const; - ExplanationId getReason(TermId id) const; - TermId getChild(TermId id, Index i) const; - Index getCutPoint(TermId id) const; - bool hasChildren(TermId id) const; - /// setter methods for the internal nodes - void setRepr(TermId id, TermId new_repr, ExplanationId reason); - void setChildren(TermId id, TermId ch1, TermId ch0); - - // the mappings between ExtractTerms and ids - __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> > d_idToExtract; - __gnu_cxx::hash_map<ExtractTerm, TermId, ExtractTermHashFunction > d_extractToId; - - __gnu_cxx::hash_set<TermId> d_topLevelIds; + /// map from TermId to the nodes that represent them + std::vector<Node> d_nodes; + /// a term is in this set if it is its own representative + TermSet d_representatives; void getDecomposition(const ExtractTerm& term, Decomposition& decomp); - void getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, std::vector<ExplanationId>& explanation); void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common); - - /* Backtracking mechanisms */ - - enum OperationKind { - MERGE, - SPLIT - }; - - struct Operation { - OperationKind op; - TermId id; - Operation(OperationKind o, TermId i) - : op(o), id(i) {} - }; - - std::vector<Operation> d_undoStack; - context::CDO<unsigned> d_undoStackIndex; + /// getter methods for the internal nodes + TermId getRepr(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getRepr(); + } + TermId getChild(TermId id, Index i) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getChild(i); + } + Index getCutPoint(TermId id) const { + return getBitwidth(getChild(id, 0)); + } + bool hasChildren(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].hasChildren(); + } + /// setter methods for the internal nodes + void setRepr(TermId id, TermId new_repr) { + Assert (id < d_nodes.size()); + d_nodes[id].setRepr(new_repr); + } + void setChildren(TermId id, TermId ch1, TermId ch0) { + Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0)); + d_nodes[id].setChildren(ch1, ch0); + } - void backtrack(); - void undoMerge(TermId id); - void undoSplit(TermId id); - void recordOperation(OperationKind op, TermId term); - virtual ~UnionFind() throw(AssertionException) {} class Statistics { public: IntStat d_numNodes; @@ -249,106 +195,56 @@ class UnionFind : public context::ContextNotifyObj { IntStat d_numMerges; AverageStat d_avgFindDepth; ReferenceStat<unsigned> d_numAddedEqualities; + //IntStat d_numAddedEqualities; Statistics(); ~Statistics(); }; - Statistics d_statistics; - Slicer* d_slicer; - TermId d_termIdCount; - TermId mkEqualityNode(Index bitwidth); - ExtractTerm mkExtractTerm(TermId id, Index high, Index low); - void storeExtractTerm(Index id, const ExtractTerm& term); - ExtractTerm getExtractTerm(TermId id) const; - bool extractHasId(const ExtractTerm& ex) const { return d_extractToId.find(ex) != d_extractToId.end(); } - TermId getExtractId(const ExtractTerm& ex) const {Assert (extractHasId(ex)); return d_extractToId.find(ex)->second; } - bool isExtractTerm(TermId id) const; + Statistics d_statistics +; + public: - UnionFind(context::Context* ctx, Slicer* slicer) - : ContextNotifyObj(ctx), - d_equalityNodes(), - d_idToExtract(), - d_extractToId(), - d_topLevelIds(), - d_undoStack(), - d_undoStackIndex(ctx), - d_statistics(), - d_slicer(slicer), - d_termIdCount(0) + UnionFind() + : d_nodes(), + d_representatives() {} - TermId addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low); - TermId registerTopLevelTerm(Index bitwidth); - void unionTerms(TermId id1, TermId id2, TermId reason); - void merge(TermId t1, TermId t2, TermId reason); - TermId find(TermId t1); - TermId findWithExplanation(TermId id, std::vector<ExplanationId>& explanation); + TermId addTerm(Index bitwidth); + void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2); + void merge(TermId t1, TermId t2); + TermId find(TermId t1); void split(TermId term, Index i); + void getNormalForm(const ExtractTerm& term, NormalForm& nf); - void getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, std::vector<ExplanationId>& explanation); - void alignSlicings(TermId id1, TermId id2); - void ensureSlicing(TermId id); + void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2); + void ensureSlicing(const ExtractTerm& term); Index getBitwidth(TermId id) const { - Assert (id < d_equalityNodes.size()); - return d_equalityNodes[id].getBitwidth(); + Assert (id < d_nodes.size()); + return d_nodes[id].getBitwidth(); } - void getBase(TermId id, Base& base, Index offset); std::string debugPrint(TermId id); - - void contextNotifyPop() { - backtrack(); - } friend class Slicer; }; -class CoreSolver; - class Slicer { - __gnu_cxx::hash_map<TermId, TNode> d_idToNode; + __gnu_cxx::hash_map<TermId, TNode> d_idToNode; __gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction> d_nodeToId; __gnu_cxx::hash_map<TNode, bool, TNodeHashFunction> d_coreTermCache; - __gnu_cxx::hash_map<TNode, ExplanationId, NodeHashFunction> d_explanationToId; - std::vector<TNode> d_explanations; UnionFind d_unionFind; - - context::CDQueue<Node> d_newSplits; - context::CDO<unsigned> d_newSplitsIndex; - CoreSolver* d_coreSolver; - TermId registerTopLevelTerm(TNode node); - bool isTopLevelNode(TermId id) const; - TermId registerTerm(TNode node); + ExtractTerm registerTerm(TNode node); public: - Slicer(context::Context* ctx, CoreSolver* coreSolver) + Slicer() : d_idToNode(), d_nodeToId(), d_coreTermCache(), - d_explanationToId(), - d_explanations(), - d_unionFind(ctx, this), - d_newSplits(ctx), - d_newSplitsIndex(ctx, 0), - d_coreSolver(coreSolver) + d_unionFind() {} - void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation); - void registerEquality(TNode eq); - + void getBaseDecomposition(TNode node, std::vector<Node>& decomp); void processEquality(TNode eq); - void assertEquality(TNode eq); bool isCoreTerm (TNode node); - - bool hasNode(TermId id) const; - Node getNode(TermId id) const; - - bool hasExplanation(ExplanationId id) const; - TNode getExplanation(ExplanationId id) const; - ExplanationId getExplanationId(TNode reason) const; - - bool termInEqualityEngine(TermId id); - void enqueueSplit(TermId id, Index i, TermId top, TermId bottom); - void getNewSplits(std::vector<Node>& splits); static void splitEqualities(TNode node, std::vector<Node>& equalities); - static unsigned d_numAddedEqualities; + static unsigned d_numAddedEqualities; }; |