From 25ac2c8f4b45e2b299895e97a30790fbf46cf79f Mon Sep 17 00:00:00 2001 From: lianah Date: Sat, 16 Mar 2013 15:48:51 -0400 Subject: started work on the inequality bv subtheory --- src/theory/bv/slicer.h | 176 +++++++++++++++++++++++++++++-------------------- 1 file changed, 103 insertions(+), 73 deletions(-) (limited to 'src/theory/bv/slicer.h') diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h index 6e09d971b..6bbe2f467 100644 --- a/src/theory/bv/slicer.h +++ b/src/theory/bv/slicer.h @@ -32,7 +32,7 @@ #include "context/context.h" #include "context/cdhashset.h" #include "context/cdo.h" - +#include "context/cdqueue.h" #ifndef __CVC4__THEORY__BV__SLICER_BV_H #define __CVC4__THEORY__BV__SLICER_BV_H @@ -45,6 +45,7 @@ namespace bv { typedef Index TermId; +typedef TermId ExplanationId; extern const TermId UndefinedId; class CDBase; @@ -81,49 +82,9 @@ public: } return true; } - friend class CDBase; }; -class CDBase : public context::ContextNotifyObj { - context::Context* d_ctx; - context::CDO d_undoIndex; - - std::vector d_undoStack; - Base d_base; - CDBase(context::Context* ctx, Index bitwidth) - : ContextNotifyObj(ctx), - d_ctx(ctx), - d_undoIndex(d_ctx), - d_undoStack(), - d_base(bitwidth) - {} - void sliceAt(Index i) { - Assert (!d_base.isCutPoint(i)); - d_undoStack.push_back(i); - d_undoIndex.set(d_undoIndex.get() + 1); - d_base.sliceAt(i); - } - bool isCutPoint(Index i) { - return d_base.isCutPoint(i); - } - Index getBitwidth() const {return d_base.getBitwidth(); } - virtual ~CDBase() throw(AssertionException) {} - void contextNotifyPop() { - backtrack(); - } - - void backtrack() { - for (unsigned i = d_undoIndex.get(); i < d_undoStack.size(); ++i) { - Index i = d_undoStack.back(); - d_undoStack.pop_back(); - d_base.undoSliceAt(i); - } - Assert(d_undoIndex.get() == d_undoStack.size()); - } - -}; - /** * UnionFind * @@ -135,6 +96,11 @@ struct ExtractTerm { TermId id; Index high; Index low; + ExtractTerm() + : id (UndefinedId), + high(UndefinedId), + low(UndefinedId) + {} ExtractTerm(TermId i, Index h, Index l) : id (i), high(h), @@ -142,10 +108,24 @@ struct ExtractTerm { { Assert (h >= l && id != UndefinedId); } + bool operator== (const ExtractTerm& other) const { + return id == other.id && high == other.high && low == other.low; + } Index getBitwidth() const { return high - low + 1; } std::string debugPrint() const; + friend class ExtractTermHashFunction; }; +struct ExtractTermHashFunction { + ::std::size_t operator() (const ExtractTerm& t) const { + __gnu_cxx::hash h; + unsigned id = t.id; + unsigned high = t.high; + unsigned low = t.low; + return (h(id) * 7919 + h(high))* 4391 + h(low); + } +}; + class UnionFind; struct NormalForm { @@ -168,21 +148,34 @@ struct NormalForm { void clear() { base.clear(); decomp.clear(); } }; +class Slicer; class UnionFind : public context::ContextNotifyObj { + + struct ReprEdge { + TermId repr; + ExplanationId reason; + ReprEdge() + : repr(UndefinedId), + reason(UndefinedId) + {} + }; + class Node { - Index d_bitwidth; - TermId d_ch1, d_ch0; - TermId d_repr; + Index d_bitwidth; + TermId d_ch1, d_ch0; // the ids of the two children if they exist + ReprEdge d_edge; // points to the representative and stores the explanation + public: Node(Index b) : d_bitwidth(b), d_ch1(UndefinedId), d_ch0(UndefinedId), - d_repr(UndefinedId) + d_edge() {} - - TermId getRepr() const { return d_repr; } + + TermId getRepr() const { return d_edge.repr; } + ExplanationId getReason() const { return d_edge.reason; } Index getBitwidth() const { return d_bitwidth; } bool hasChildren() const { return d_ch1 != UndefinedId && d_ch0 != UndefinedId; } @@ -190,9 +183,10 @@ class UnionFind : public context::ContextNotifyObj { Assert (i < 2); return i == 0? d_ch0 : d_ch1; } - void setRepr(TermId id) { + void setRepr(TermId repr, ExplanationId reason) { Assert (! hasChildren()); - d_repr = id; + d_edge.repr = repr; + d_edge.reason = reason; } void setChildren(TermId ch1, TermId ch0) { // Assert (d_repr == UndefinedId && !hasChildren()); @@ -204,16 +198,21 @@ class UnionFind : public context::ContextNotifyObj { /// map from TermId to the nodes that represent them std::vector d_nodes; - /// a term is in this set if it is its own representative - //CDTermSet d_representatives; + __gnu_cxx::hash_map > d_idToExtract; + __gnu_cxx::hash_map d_extractToId; void getDecomposition(const ExtractTerm& term, Decomposition& decomp); + void getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, std::vector& explanation); void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common); /// getter methods for the internal nodes TermId getRepr(TermId id) const { Assert (id < d_nodes.size()); return d_nodes[id].getRepr(); } + ExplanationId getReason(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getReason(); + } TermId getChild(TermId id, Index i) const { Assert (id < d_nodes.size()); return d_nodes[id].getChild(i); @@ -225,10 +224,12 @@ class UnionFind : public context::ContextNotifyObj { Assert (id < d_nodes.size()); return d_nodes[id].hasChildren(); } + TermId getTopLevel(TermId id) const; + /// setter methods for the internal nodes - void setRepr(TermId id, TermId new_repr) { + void setRepr(TermId id, TermId new_repr, ExplanationId reason) { Assert (id < d_nodes.size()); - d_nodes[id].setRepr(new_repr); + d_nodes[id].setRepr(new_repr, reason); } void setChildren(TermId id, TermId ch1, TermId ch0) { Assert ((ch1 == UndefinedId && ch0 == UndefinedId) || @@ -269,25 +270,32 @@ class UnionFind : public context::ContextNotifyObj { Statistics(); ~Statistics(); }; - Statistics d_statistics; - bool d_newSplit; + Slicer* d_slicer; public: - UnionFind(context::Context* ctx) + UnionFind(context::Context* ctx, Slicer* slicer) : ContextNotifyObj(ctx), d_nodes(), + d_idToExtract(), + d_extractToId(), d_undoStack(), d_undoStackIndex(ctx), - d_statistics() + d_statistics(), + d_slicer(slicer) {} - TermId addTerm(Index bitwidth); - void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2); - void merge(TermId t1, TermId t2); - TermId find(TermId t1); + TermId addNode(Index bitwidth); + TermId addExtract(Index topLevel, Index high, Index low); + ExtractTerm getExtractTerm(TermId id) const; + bool isExtractTerm(TermId id) const; + + void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason); + void merge(TermId t1, TermId t2, TermId reason); + TermId find(TermId t1); + TermId findWithExplanation(TermId id, std::vector& explanation); void split(TermId term, Index i); - void getNormalForm(const ExtractTerm& term, NormalForm& nf); + void getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, std::vector& explanation); void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2); void ensureSlicing(const ExtractTerm& term); Index getBitwidth(TermId id) const { @@ -300,34 +308,56 @@ public: void contextNotifyPop() { backtrack(); } - bool hasNewSplit() { return d_newSplit; } - void resetNewSplit() { d_newSplit = false; } - friend class Slicer; }; +class CoreSolver; + class Slicer { - __gnu_cxx::hash_map d_idToNode; + __gnu_cxx::hash_map > d_idToNode; __gnu_cxx::hash_map d_nodeToId; __gnu_cxx::hash_map d_coreTermCache; + __gnu_cxx::hash_map d_explanationToId; + std::vector d_explanations; UnionFind d_unionFind; - ExtractTerm registerTerm(TNode node); + + context::CDQueue d_newSplits; + context::CDO d_newSplitsIndex; + CoreSolver* d_coreSolver; + TermId d_termIdCount; public: - Slicer(context::Context* ctx) + Slicer(context::Context* ctx, CoreSolver* coreSolver) : d_idToNode(), d_nodeToId(), d_coreTermCache(), - d_unionFind(ctx) + d_explanationToId(), + d_explanations(), + d_unionFind(ctx, this), + d_newSplits(ctx), + d_newSplitsIndex(ctx, 0), + d_coreSolver(coreSolver) {} - void getBaseDecomposition(TNode node, std::vector& decomp); + void getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation); + void registerEquality(TNode eq); + ExtractTerm registerTerm(TNode node); void processEquality(TNode eq); + void assertEquality(TNode eq); bool isCoreTerm (TNode node); - Base getTopLevelBase(TNode node); + + bool hasNode(TermId id) const; + Node getNode(TermId id) const; + TermId getId(TNode node) const; + + bool hasExplanation(ExplanationId id) const; + TNode getExplanation(ExplanationId id) const; + ExplanationId getExplanationId(TNode reason) const; + + bool termInEqualityEngine(TermId id); + void enqueueSplit(TermId id, Index i); + void getNewSplits(std::vector& splits); static void splitEqualities(TNode node, std::vector& equalities); static unsigned d_numAddedEqualities; - inline bool hasNewSplit() { return d_unionFind.hasNewSplit(); } - inline void resetNewSplit() { d_unionFind.resetNewSplit(); } }; -- cgit v1.2.3