summaryrefslogtreecommitdiff
path: root/src/theory/bv/slicer.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/bv/slicer.h')
-rw-r--r--src/theory/bv/slicer.h176
1 files changed, 103 insertions, 73 deletions
diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h
index 6e09d971b..6bbe2f467 100644
--- a/src/theory/bv/slicer.h
+++ b/src/theory/bv/slicer.h
@@ -32,7 +32,7 @@
#include "context/context.h"
#include "context/cdhashset.h"
#include "context/cdo.h"
-
+#include "context/cdqueue.h"
#ifndef __CVC4__THEORY__BV__SLICER_BV_H
#define __CVC4__THEORY__BV__SLICER_BV_H
@@ -45,6 +45,7 @@ namespace bv {
typedef Index TermId;
+typedef TermId ExplanationId;
extern const TermId UndefinedId;
class CDBase;
@@ -81,49 +82,9 @@ public:
}
return true;
}
- friend class CDBase;
};
-class CDBase : public context::ContextNotifyObj {
- context::Context* d_ctx;
- context::CDO<unsigned> d_undoIndex;
-
- std::vector<unsigned> d_undoStack;
- Base d_base;
- CDBase(context::Context* ctx, Index bitwidth)
- : ContextNotifyObj(ctx),
- d_ctx(ctx),
- d_undoIndex(d_ctx),
- d_undoStack(),
- d_base(bitwidth)
- {}
- void sliceAt(Index i) {
- Assert (!d_base.isCutPoint(i));
- d_undoStack.push_back(i);
- d_undoIndex.set(d_undoIndex.get() + 1);
- d_base.sliceAt(i);
- }
- bool isCutPoint(Index i) {
- return d_base.isCutPoint(i);
- }
- Index getBitwidth() const {return d_base.getBitwidth(); }
- virtual ~CDBase() throw(AssertionException) {}
- void contextNotifyPop() {
- backtrack();
- }
-
- void backtrack() {
- for (unsigned i = d_undoIndex.get(); i < d_undoStack.size(); ++i) {
- Index i = d_undoStack.back();
- d_undoStack.pop_back();
- d_base.undoSliceAt(i);
- }
- Assert(d_undoIndex.get() == d_undoStack.size());
- }
-
-};
-
/**
* UnionFind
*
@@ -135,6 +96,11 @@ struct ExtractTerm {
TermId id;
Index high;
Index low;
+ ExtractTerm()
+ : id (UndefinedId),
+ high(UndefinedId),
+ low(UndefinedId)
+ {}
ExtractTerm(TermId i, Index h, Index l)
: id (i),
high(h),
@@ -142,10 +108,24 @@ struct ExtractTerm {
{
Assert (h >= l && id != UndefinedId);
}
+ bool operator== (const ExtractTerm& other) const {
+ return id == other.id && high == other.high && low == other.low;
+ }
Index getBitwidth() const { return high - low + 1; }
std::string debugPrint() const;
+ friend class ExtractTermHashFunction;
};
+struct ExtractTermHashFunction {
+ ::std::size_t operator() (const ExtractTerm& t) const {
+ __gnu_cxx::hash<unsigned> h;
+ unsigned id = t.id;
+ unsigned high = t.high;
+ unsigned low = t.low;
+ return (h(id) * 7919 + h(high))* 4391 + h(low);
+ }
+};
+
class UnionFind;
struct NormalForm {
@@ -168,21 +148,34 @@ struct NormalForm {
void clear() { base.clear(); decomp.clear(); }
};
+class Slicer;
class UnionFind : public context::ContextNotifyObj {
+
+ struct ReprEdge {
+ TermId repr;
+ ExplanationId reason;
+ ReprEdge()
+ : repr(UndefinedId),
+ reason(UndefinedId)
+ {}
+ };
+
class Node {
- Index d_bitwidth;
- TermId d_ch1, d_ch0;
- TermId d_repr;
+ Index d_bitwidth;
+ TermId d_ch1, d_ch0; // the ids of the two children if they exist
+ ReprEdge d_edge; // points to the representative and stores the explanation
+
public:
Node(Index b)
: d_bitwidth(b),
d_ch1(UndefinedId),
d_ch0(UndefinedId),
- d_repr(UndefinedId)
+ d_edge()
{}
-
- TermId getRepr() const { return d_repr; }
+
+ TermId getRepr() const { return d_edge.repr; }
+ ExplanationId getReason() const { return d_edge.reason; }
Index getBitwidth() const { return d_bitwidth; }
bool hasChildren() const { return d_ch1 != UndefinedId && d_ch0 != UndefinedId; }
@@ -190,9 +183,10 @@ class UnionFind : public context::ContextNotifyObj {
Assert (i < 2);
return i == 0? d_ch0 : d_ch1;
}
- void setRepr(TermId id) {
+ void setRepr(TermId repr, ExplanationId reason) {
Assert (! hasChildren());
- d_repr = id;
+ d_edge.repr = repr;
+ d_edge.reason = reason;
}
void setChildren(TermId ch1, TermId ch0) {
// Assert (d_repr == UndefinedId && !hasChildren());
@@ -204,16 +198,21 @@ class UnionFind : public context::ContextNotifyObj {
/// map from TermId to the nodes that represent them
std::vector<Node> d_nodes;
- /// a term is in this set if it is its own representative
- //CDTermSet d_representatives;
+ __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> > d_idToExtract;
+ __gnu_cxx::hash_map<ExtractTerm, TermId, ExtractTermHashFunction > d_extractToId;
void getDecomposition(const ExtractTerm& term, Decomposition& decomp);
+ void getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, std::vector<ExplanationId>& explanation);
void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common);
/// getter methods for the internal nodes
TermId getRepr(TermId id) const {
Assert (id < d_nodes.size());
return d_nodes[id].getRepr();
}
+ ExplanationId getReason(TermId id) const {
+ Assert (id < d_nodes.size());
+ return d_nodes[id].getReason();
+ }
TermId getChild(TermId id, Index i) const {
Assert (id < d_nodes.size());
return d_nodes[id].getChild(i);
@@ -225,10 +224,12 @@ class UnionFind : public context::ContextNotifyObj {
Assert (id < d_nodes.size());
return d_nodes[id].hasChildren();
}
+ TermId getTopLevel(TermId id) const;
+
/// setter methods for the internal nodes
- void setRepr(TermId id, TermId new_repr) {
+ void setRepr(TermId id, TermId new_repr, ExplanationId reason) {
Assert (id < d_nodes.size());
- d_nodes[id].setRepr(new_repr);
+ d_nodes[id].setRepr(new_repr, reason);
}
void setChildren(TermId id, TermId ch1, TermId ch0) {
Assert ((ch1 == UndefinedId && ch0 == UndefinedId) ||
@@ -269,25 +270,32 @@ class UnionFind : public context::ContextNotifyObj {
Statistics();
~Statistics();
};
-
Statistics d_statistics;
- bool d_newSplit;
+ Slicer* d_slicer;
public:
- UnionFind(context::Context* ctx)
+ UnionFind(context::Context* ctx, Slicer* slicer)
: ContextNotifyObj(ctx),
d_nodes(),
+ d_idToExtract(),
+ d_extractToId(),
d_undoStack(),
d_undoStackIndex(ctx),
- d_statistics()
+ d_statistics(),
+ d_slicer(slicer)
{}
- TermId addTerm(Index bitwidth);
- void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2);
- void merge(TermId t1, TermId t2);
- TermId find(TermId t1);
+ TermId addNode(Index bitwidth);
+ TermId addExtract(Index topLevel, Index high, Index low);
+ ExtractTerm getExtractTerm(TermId id) const;
+ bool isExtractTerm(TermId id) const;
+
+ void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason);
+ void merge(TermId t1, TermId t2, TermId reason);
+ TermId find(TermId t1);
+ TermId findWithExplanation(TermId id, std::vector<ExplanationId>& explanation);
void split(TermId term, Index i);
-
void getNormalForm(const ExtractTerm& term, NormalForm& nf);
+ void getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, std::vector<ExplanationId>& explanation);
void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2);
void ensureSlicing(const ExtractTerm& term);
Index getBitwidth(TermId id) const {
@@ -300,34 +308,56 @@ public:
void contextNotifyPop() {
backtrack();
}
- bool hasNewSplit() { return d_newSplit; }
- void resetNewSplit() { d_newSplit = false; }
-
friend class Slicer;
};
+class CoreSolver;
+
class Slicer {
- __gnu_cxx::hash_map<TermId, TNode> d_idToNode;
+ __gnu_cxx::hash_map<TermId, TNode, __gnu_cxx::hash<TermId> > d_idToNode;
__gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction> d_nodeToId;
__gnu_cxx::hash_map<TNode, bool, TNodeHashFunction> d_coreTermCache;
+ __gnu_cxx::hash_map<TNode, ExplanationId, TNodeHashFunction> d_explanationToId;
+ std::vector<TNode> d_explanations;
UnionFind d_unionFind;
- ExtractTerm registerTerm(TNode node);
+
+ context::CDQueue<Node> d_newSplits;
+ context::CDO<unsigned> d_newSplitsIndex;
+ CoreSolver* d_coreSolver;
+ TermId d_termIdCount;
public:
- Slicer(context::Context* ctx)
+ Slicer(context::Context* ctx, CoreSolver* coreSolver)
: d_idToNode(),
d_nodeToId(),
d_coreTermCache(),
- d_unionFind(ctx)
+ d_explanationToId(),
+ d_explanations(),
+ d_unionFind(ctx, this),
+ d_newSplits(ctx),
+ d_newSplitsIndex(ctx, 0),
+ d_coreSolver(coreSolver)
{}
- void getBaseDecomposition(TNode node, std::vector<Node>& decomp);
+ void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation);
+ void registerEquality(TNode eq);
+ ExtractTerm registerTerm(TNode node);
void processEquality(TNode eq);
+ void assertEquality(TNode eq);
bool isCoreTerm (TNode node);
- Base getTopLevelBase(TNode node);
+
+ bool hasNode(TermId id) const;
+ Node getNode(TermId id) const;
+ TermId getId(TNode node) const;
+
+ bool hasExplanation(ExplanationId id) const;
+ TNode getExplanation(ExplanationId id) const;
+ ExplanationId getExplanationId(TNode reason) const;
+
+ bool termInEqualityEngine(TermId id);
+ void enqueueSplit(TermId id, Index i);
+ void getNewSplits(std::vector<Node>& splits);
static void splitEqualities(TNode node, std::vector<Node>& equalities);
static unsigned d_numAddedEqualities;
- inline bool hasNewSplit() { return d_unionFind.hasNewSplit(); }
- inline void resetNewSplit() { d_unionFind.resetNewSplit(); }
};
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback