summaryrefslogtreecommitdiff
path: root/src/theory/bv/slicer.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/bv/slicer.h')
-rw-r--r--src/theory/bv/slicer.h236
1 files changed, 66 insertions, 170 deletions
diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h
index c46ef99ed..88254b983 100644
--- a/src/theory/bv/slicer.h
+++ b/src/theory/bv/slicer.h
@@ -29,10 +29,6 @@
#include "util/index.h"
#include "expr/node.h"
#include "theory/bv/theory_bv_utils.h"
-#include "context/context.h"
-#include "context/cdhashset.h"
-#include "context/cdo.h"
-#include "context/cdqueue.h"
#ifndef __CVC4__THEORY__BV__SLICER_BV_H
#define __CVC4__THEORY__BV__SLICER_BV_H
@@ -45,10 +41,8 @@ namespace bv {
typedef Index TermId;
-typedef TermId ExplanationId;
extern const TermId UndefinedId;
-class CDBase;
/**
* Base
@@ -57,11 +51,9 @@ class CDBase;
class Base {
Index d_size;
std::vector<uint32_t> d_repr;
- void undoSliceAt(Index index);
public:
- Base (Index size);
- void sliceAt(Index index);
-
+ Base(Index size);
+ void sliceAt(Index index);
void sliceWith(const Base& other);
bool isCutPoint(Index index) const;
void diffCutPoints(const Base& other, Base& res) const;
@@ -84,23 +76,17 @@ public:
}
};
-
/**
* UnionFind
*
*/
-typedef context::CDHashSet<uint32_t, std::hash<uint32_t> > CDTermSet;
+typedef __gnu_cxx::hash_set<TermId> TermSet;
typedef std::vector<TermId> Decomposition;
struct ExtractTerm {
TermId id;
Index high;
Index low;
- ExtractTerm()
- : id (UndefinedId),
- high(UndefinedId),
- low(UndefinedId)
- {}
ExtractTerm(TermId i, Index h, Index l)
: id (i),
high(h),
@@ -108,24 +94,10 @@ struct ExtractTerm {
{
Assert (h >= l && id != UndefinedId);
}
- bool operator== (const ExtractTerm& other) const {
- return id == other.id && high == other.high && low == other.low;
- }
Index getBitwidth() const { return high - low + 1; }
std::string debugPrint() const;
- friend class ExtractTermHashFunction;
};
-struct ExtractTermHashFunction {
- ::std::size_t operator() (const ExtractTerm& t) const {
- __gnu_cxx::hash<unsigned> h;
- unsigned id = t.id;
- unsigned high = t.high;
- unsigned low = t.low;
- return (h(id) * 7919 + h(high))* 4391 + h(low);
- }
-};
-
class UnionFind;
struct NormalForm {
@@ -148,34 +120,21 @@ struct NormalForm {
void clear() { base.clear(); decomp.clear(); }
};
-class Slicer;
-
-class UnionFind : public context::ContextNotifyObj {
- struct ReprEdge {
- TermId repr;
- ExplanationId reason;
- ReprEdge()
- : repr(UndefinedId),
- reason(UndefinedId)
- {}
- };
-
- class EqualityNode {
- Index d_bitwidth;
- TermId d_ch1, d_ch0; // the ids of the two children if they exist
- ReprEdge d_edge; // points to the representative and stores the explanation
-
+class UnionFind {
+ class Node {
+ Index d_bitwidth;
+ TermId d_ch1, d_ch0;
+ TermId d_repr;
public:
- EqualityNode(Index b)
+ Node(Index b)
: d_bitwidth(b),
d_ch1(UndefinedId),
d_ch0(UndefinedId),
- d_edge()
+ d_repr(UndefinedId)
{}
-
- TermId getRepr() const { return d_edge.repr; }
- ExplanationId getReason() const { return d_edge.reason; }
+
+ TermId getRepr() const { return d_repr; }
Index getBitwidth() const { return d_bitwidth; }
bool hasChildren() const { return d_ch1 != UndefinedId && d_ch0 != UndefinedId; }
@@ -183,64 +142,51 @@ class UnionFind : public context::ContextNotifyObj {
Assert (i < 2);
return i == 0? d_ch0 : d_ch1;
}
- void setRepr(TermId repr, ExplanationId reason) {
+ void setRepr(TermId id) {
Assert (! hasChildren());
- d_edge.repr = repr;
- d_edge.reason = reason;
+ d_repr = id;
}
void setChildren(TermId ch1, TermId ch0) {
+ Assert (d_repr == UndefinedId && !hasChildren());
d_ch1 = ch1;
d_ch0 = ch0;
}
std::string debugPrint() const;
};
-
- // the equality nodes in the union find
- std::vector<EqualityNode> d_equalityNodes;
-
- /// getter methods for the internal nodes
- TermId getRepr(TermId id) const;
- ExplanationId getReason(TermId id) const;
- TermId getChild(TermId id, Index i) const;
- Index getCutPoint(TermId id) const;
- bool hasChildren(TermId id) const;
- /// setter methods for the internal nodes
- void setRepr(TermId id, TermId new_repr, ExplanationId reason);
- void setChildren(TermId id, TermId ch1, TermId ch0);
-
- // the mappings between ExtractTerms and ids
- __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> > d_idToExtract;
- __gnu_cxx::hash_map<ExtractTerm, TermId, ExtractTermHashFunction > d_extractToId;
-
- __gnu_cxx::hash_set<TermId> d_topLevelIds;
+ /// map from TermId to the nodes that represent them
+ std::vector<Node> d_nodes;
+ /// a term is in this set if it is its own representative
+ TermSet d_representatives;
void getDecomposition(const ExtractTerm& term, Decomposition& decomp);
- void getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, std::vector<ExplanationId>& explanation);
void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common);
-
- /* Backtracking mechanisms */
-
- enum OperationKind {
- MERGE,
- SPLIT
- };
-
- struct Operation {
- OperationKind op;
- TermId id;
- Operation(OperationKind o, TermId i)
- : op(o), id(i) {}
- };
-
- std::vector<Operation> d_undoStack;
- context::CDO<unsigned> d_undoStackIndex;
+ /// getter methods for the internal nodes
+ TermId getRepr(TermId id) const {
+ Assert (id < d_nodes.size());
+ return d_nodes[id].getRepr();
+ }
+ TermId getChild(TermId id, Index i) const {
+ Assert (id < d_nodes.size());
+ return d_nodes[id].getChild(i);
+ }
+ Index getCutPoint(TermId id) const {
+ return getBitwidth(getChild(id, 0));
+ }
+ bool hasChildren(TermId id) const {
+ Assert (id < d_nodes.size());
+ return d_nodes[id].hasChildren();
+ }
+ /// setter methods for the internal nodes
+ void setRepr(TermId id, TermId new_repr) {
+ Assert (id < d_nodes.size());
+ d_nodes[id].setRepr(new_repr);
+ }
+ void setChildren(TermId id, TermId ch1, TermId ch0) {
+ Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0));
+ d_nodes[id].setChildren(ch1, ch0);
+ }
- void backtrack();
- void undoMerge(TermId id);
- void undoSplit(TermId id);
- void recordOperation(OperationKind op, TermId term);
- virtual ~UnionFind() throw(AssertionException) {}
class Statistics {
public:
IntStat d_numNodes;
@@ -249,106 +195,56 @@ class UnionFind : public context::ContextNotifyObj {
IntStat d_numMerges;
AverageStat d_avgFindDepth;
ReferenceStat<unsigned> d_numAddedEqualities;
+ //IntStat d_numAddedEqualities;
Statistics();
~Statistics();
};
- Statistics d_statistics;
- Slicer* d_slicer;
- TermId d_termIdCount;
- TermId mkEqualityNode(Index bitwidth);
- ExtractTerm mkExtractTerm(TermId id, Index high, Index low);
- void storeExtractTerm(Index id, const ExtractTerm& term);
- ExtractTerm getExtractTerm(TermId id) const;
- bool extractHasId(const ExtractTerm& ex) const { return d_extractToId.find(ex) != d_extractToId.end(); }
- TermId getExtractId(const ExtractTerm& ex) const {Assert (extractHasId(ex)); return d_extractToId.find(ex)->second; }
- bool isExtractTerm(TermId id) const;
+ Statistics d_statistics
+;
+
public:
- UnionFind(context::Context* ctx, Slicer* slicer)
- : ContextNotifyObj(ctx),
- d_equalityNodes(),
- d_idToExtract(),
- d_extractToId(),
- d_topLevelIds(),
- d_undoStack(),
- d_undoStackIndex(ctx),
- d_statistics(),
- d_slicer(slicer),
- d_termIdCount(0)
+ UnionFind()
+ : d_nodes(),
+ d_representatives()
{}
- TermId addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low);
- TermId registerTopLevelTerm(Index bitwidth);
- void unionTerms(TermId id1, TermId id2, TermId reason);
- void merge(TermId t1, TermId t2, TermId reason);
- TermId find(TermId t1);
- TermId findWithExplanation(TermId id, std::vector<ExplanationId>& explanation);
+ TermId addTerm(Index bitwidth);
+ void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2);
+ void merge(TermId t1, TermId t2);
+ TermId find(TermId t1);
void split(TermId term, Index i);
+
void getNormalForm(const ExtractTerm& term, NormalForm& nf);
- void getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, std::vector<ExplanationId>& explanation);
- void alignSlicings(TermId id1, TermId id2);
- void ensureSlicing(TermId id);
+ void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2);
+ void ensureSlicing(const ExtractTerm& term);
Index getBitwidth(TermId id) const {
- Assert (id < d_equalityNodes.size());
- return d_equalityNodes[id].getBitwidth();
+ Assert (id < d_nodes.size());
+ return d_nodes[id].getBitwidth();
}
- void getBase(TermId id, Base& base, Index offset);
std::string debugPrint(TermId id);
-
- void contextNotifyPop() {
- backtrack();
- }
friend class Slicer;
};
-class CoreSolver;
-
class Slicer {
- __gnu_cxx::hash_map<TermId, TNode> d_idToNode;
+ __gnu_cxx::hash_map<TermId, TNode> d_idToNode;
__gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction> d_nodeToId;
__gnu_cxx::hash_map<TNode, bool, TNodeHashFunction> d_coreTermCache;
- __gnu_cxx::hash_map<TNode, ExplanationId, NodeHashFunction> d_explanationToId;
- std::vector<TNode> d_explanations;
UnionFind d_unionFind;
-
- context::CDQueue<Node> d_newSplits;
- context::CDO<unsigned> d_newSplitsIndex;
- CoreSolver* d_coreSolver;
- TermId registerTopLevelTerm(TNode node);
- bool isTopLevelNode(TermId id) const;
- TermId registerTerm(TNode node);
+ ExtractTerm registerTerm(TNode node);
public:
- Slicer(context::Context* ctx, CoreSolver* coreSolver)
+ Slicer()
: d_idToNode(),
d_nodeToId(),
d_coreTermCache(),
- d_explanationToId(),
- d_explanations(),
- d_unionFind(ctx, this),
- d_newSplits(ctx),
- d_newSplitsIndex(ctx, 0),
- d_coreSolver(coreSolver)
+ d_unionFind()
{}
- void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation);
- void registerEquality(TNode eq);
-
+ void getBaseDecomposition(TNode node, std::vector<Node>& decomp);
void processEquality(TNode eq);
- void assertEquality(TNode eq);
bool isCoreTerm (TNode node);
-
- bool hasNode(TermId id) const;
- Node getNode(TermId id) const;
-
- bool hasExplanation(ExplanationId id) const;
- TNode getExplanation(ExplanationId id) const;
- ExplanationId getExplanationId(TNode reason) const;
-
- bool termInEqualityEngine(TermId id);
- void enqueueSplit(TermId id, Index i, TermId top, TermId bottom);
- void getNewSplits(std::vector<Node>& splits);
static void splitEqualities(TNode node, std::vector<Node>& equalities);
- static unsigned d_numAddedEqualities;
+ static unsigned d_numAddedEqualities;
};
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback