summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDejan Jovanović <dejan.jovanovic@gmail.com>2011-03-20 01:12:31 +0000
committerDejan Jovanović <dejan.jovanovic@gmail.com>2011-03-20 01:12:31 +0000
commitaf6ac1f03a841a0261190cb7caa15ff1fa1f798c (patch)
tree56351c49de0cd299548becb15bf5810d6e0dac54 /src
parent649c50afb9e35ef467828567d4b1d24a107d6d20 (diff)
commit for the version of bitvectors that passes all the unit tests
Diffstat (limited to 'src')
-rw-r--r--src/theory/bv/cd_set_collection.h109
-rw-r--r--src/theory/bv/equality_engine.h299
-rw-r--r--src/theory/bv/slice_manager.h384
-rw-r--r--src/theory/bv/theory_bv.cpp40
-rw-r--r--src/theory/bv/theory_bv.h29
-rw-r--r--src/theory/bv/theory_bv_rewriter.cpp6
-rw-r--r--src/theory/bv/theory_bv_utils.h88
-rw-r--r--src/util/bitvector.h2
8 files changed, 645 insertions, 312 deletions
diff --git a/src/theory/bv/cd_set_collection.h b/src/theory/bv/cd_set_collection.h
index 33648660b..aeb28ab7b 100644
--- a/src/theory/bv/cd_set_collection.h
+++ b/src/theory/bv/cd_set_collection.h
@@ -123,28 +123,30 @@ public:
return newElement(value, null, null, null, false);
}
- void insert(memory_type& memory, reference_type root, const value_type& value) {
+ void insert(reference_type& root, const value_type& value) {
backtrack();
if (root == null) {
- return newSet(value);
+ root = newSet(value);
+ return;
}
// We already have a set, find the spot
reference_type parent = null;
+ reference_type current = root;
while (true) {
- parent = root;
- if (value < d_memory[root].value) {
- root = d_memory[root].left;
- if (root == null) {
- root = newElement(value, null, null, parent, true);
- d_memory[parent].left = root;
+ parent = current;
+ if (value < d_memory[current].getValue()) {
+ if (d_memory[current].hasLeft()) {
+ current = d_memory[current].getLeft();
+ } else {
+ d_memory[current].setLeft(newElement(value, null, null, parent, true));
return;
}
} else {
- Assert(value != d_memory[root].value);
- root = d_memory[root].right;
- if (root == null) {
- root = newElement(value, null, null, parent, false);
- d_memory[parent].right = root;
+ Assert(value != d_memory[root].getValue());
+ if (d_memory[current].hasRight()) {
+ current = d_memory[current].getRight();
+ } else {
+ d_memory[parent].setRight(newElement(value, null, null, parent, false));
return;
}
}
@@ -174,49 +176,55 @@ public:
*/
const_value_reference prev(reference_type set, const_value_reference value) {
backtrack();
- // Get the node of this value
- reference_type node_ref = find(set, value);
- Assert(node_ref != null);
- const tree_entry_type& node = d_memory[node_ref];
- // For a left node, we know that it is smaller than all the parents and the parents other children
- // The smaller node must then be the max of the left subtree
- if (!node.hasParent() || node.isLeft()) {
- return maxElement(node.getLeft());
- }
- // For a right node, we know that it is bigger than the parent. But, we also know that the left subtree
- // is also bigger than the parent
- else {
- if (node.hasLeft()) {
- return maxElement(node.getLeft());
+
+ const_value_reference candidate_value;
+ bool candidate_found = false;
+
+ // Find the biggest node smaleer than value (it must exist)
+ while (set != null) {
+ Debug("set_collection") << "BacktrackableSetCollection::getPrev(" << toString(set) << "," << value << ")" << std::endl;
+ const tree_entry_type& node = d_memory[set];
+ if (node.getValue() >= value) {
+ // If the node is bigger than the value, we need a smaller one
+ set = node.getLeft();
} else {
- Assert(node.hasParent());
- return d_memory[node.getParent()].getValue();
+ // The node is smaller than the value
+ candidate_found = true;
+ candidate_value = node.getValue();
+ // There might be a bigger one
+ set = node.getRight();
}
}
+
+ Assert(candidate_found);
+ return candidate_value;
}
const_value_reference next(reference_type set, const_value_reference value) {
backtrack();
- // Get the node of this value
- reference_type node_ref = find(set, value);
- Assert(node_ref != null);
- const tree_entry_type& node = d_memory[node_ref];
- // For a right node, we know that it is bigger than all the parents and the parents other children
- // The bigger node must then be the min of the right subtree
- if (!node.hasParent() || node.isRight()) {
- return minElement(node.getRight());
- }
- // For a left node, we know that it is smaller than the parent. But, we also know that the right subtree
- // is also smaller than the parent
- else {
- if (node.hasRight()) {
- return minElement(node.getRight());
+
+ const_value_reference candidate_value;
+ bool candidate_found = false;
+
+ // Find the smallest node bigger than value (it must exist)
+ while (set != null) {
+ Debug("set_collection") << "BacktrackableSetCollection::getNext(" << toString(set) << "," << value << ")" << std::endl;
+ const tree_entry_type& node = d_memory[set];
+ if (node.getValue() <= value) {
+ // If the node is smaller than the value, we need a bigger one
+ set = node.getRight();
} else {
- Assert(node.hasParent());
- return d_memory[node.getParent()].getValue();
+ // The node is bigger than the value
+ candidate_found = true;
+ candidate_value = node.getValue();
+ // There might be a smaller one
+ set = node.getLeft();
}
}
- }
+
+ Assert(candidate_found);
+ return candidate_value;
+}
/**
* Count the number of elements in the given bounds.
@@ -262,6 +270,9 @@ public:
void getElements(reference_type set, const_value_reference lowerBound, const_value_reference upperBound, std::vector<value_type>& output) const {
Assert(lowerBound <= upperBound);
backtrack();
+
+ Debug("set_collection") << "BacktrackableSetCollection::getElements(" << toString(set) << "," << lowerBound << "," << upperBound << ")" << std::endl;
+
// Empty set no elements
if (set == null) {
return;
@@ -277,7 +288,7 @@ public:
output.push_back(current.getValue());
}
// Right child (bigger elements)
- if (current.getValue() <= upperBound) {
+ if (current.getValue() < upperBound) {
getElements(current.getRight(), lowerBound, upperBound, output);
}
}
@@ -285,7 +296,7 @@ public:
/**
* Print the list of elements to the output.
*/
- void print(std::ostream& out, reference_type set) {
+ void print(std::ostream& out, reference_type set) const {
backtrack();
if (set == null) {
return;
@@ -305,7 +316,7 @@ public:
/**
* String representation of a set.
*/
- std::string toString(reference_type set) {
+ std::string toString(reference_type set) const {
std::stringstream out;
print(out, set);
return out.str();
diff --git a/src/theory/bv/equality_engine.h b/src/theory/bv/equality_engine.h
index 9880539ed..53c44bed0 100644
--- a/src/theory/bv/equality_engine.h
+++ b/src/theory/bv/equality_engine.h
@@ -28,13 +28,13 @@
#include "context/cdo.h"
#include "util/output.h"
#include "util/stats.h"
+#include "theory/rewriter.h"
namespace CVC4 {
namespace theory {
namespace bv {
struct BitSizeTraits {
-
/** The null id */
static const size_t id_null; // Defined in the cpp file (GCC bug)
/** The null trigger id */
@@ -46,13 +46,6 @@ struct BitSizeTraits {
static const size_t size_bits = 16;
/** Number of bits we use for the trigger id */
static const size_t trigger_id_bits = 24;
-
- /** Number of bits we use for the function ids */
- static const size_t function_id_bits = 8;
- /** Number of bits we use for the function arguments count */
- static const size_t function_arguments_count_bits = 16;
- /** Number of bits we use for the index into the arguments memory */
- static const size_t function_arguments_index_bits = 24;
};
class EqualityNode {
@@ -68,22 +61,18 @@ public:
/** The next equality node in this class */
size_t d_nextId : BitSizeTraits::id_bits;
- /** Is this node a function application */
- size_t d_isFunction : 1;
-
public:
/**
* Creates a new node, which is in a list of it's own.
*/
EqualityNode(size_t nodeId = BitSizeTraits::id_null)
- : d_size(1), d_findId(nodeId), d_nextId(nodeId), d_isFunction(0) {}
+ : d_size(1), d_findId(nodeId), d_nextId(nodeId) {}
/** Initialize the equality node */
- inline void init(size_t nodeId, bool isFunction) {
+ inline void init(size_t nodeId) {
d_size = 1;
d_findId = d_nextId = nodeId;
- d_isFunction = isFunction;
}
/**
@@ -125,66 +114,11 @@ public:
inline void setFind(size_t findId) { d_findId = findId; }
};
-/**
- * FunctionNode class represents the information related to a function node. It has an id, number of children
- * and the
- */
-class FunctionNode {
-
- /** Is the function associative */
- size_t d_isAssociative : 1;
- /** The id of the function */
- size_t d_functionId : BitSizeTraits::function_id_bits;
- /** Number of children */
- size_t d_argumentsCount : BitSizeTraits::function_arguments_count_bits;
- /** Index of the start of the arguments in the children array */
- size_t d_argumentsIndex : BitSizeTraits::function_arguments_index_bits;
-
-public:
-
- FunctionNode(size_t functionId = 0, size_t argumentsCount = 0, size_t argumentsIndex = 0, bool associative = false)
- : d_isAssociative(associative), d_functionId(functionId), d_argumentsCount(argumentsCount), d_argumentsIndex(argumentsIndex)
- {}
-
- void init(size_t functionId, size_t argumentsCount, size_t argumentsIndex, bool associative) {
- d_functionId = functionId;
- d_argumentsCount = argumentsCount;
- d_argumentsIndex = argumentsIndex;
- d_isAssociative = associative;
- }
-
- /** Check if the function is associative */
- bool isAssociative() const { return d_isAssociative; }
-
- /** Get the function id */
- size_t getFunctionId() const { return d_functionId; }
-
- /** Get the number of arguments */
- size_t getArgumentsCount() const { return d_argumentsCount; }
-
- /** Get the infex of the first argument in the arguments memory */
- size_t getArgumentsIndex() const { return d_argumentsIndex; }
-
-};
-
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
class EqualityEngine {
public:
- /**
- * Basic information about a function.
- */
- struct FunctionInfo {
- /** Name of the function */
- std::string name;
- /** Is the function associative */
- bool isAssociative;
-
- FunctionInfo(std::string name, bool isAssociative)
- : name(name), isAssociative(isAssociative) {}
- };
-
/** Statistics about the equality engine instance */
struct Statistics {
/** Total number of merges */
@@ -193,8 +127,6 @@ public:
IntStat termsCount;
/** Number of function terms managed by the system */
IntStat functionTermsCount;
- /** Number of distince functions managed by the system */
- IntStat functionsCount;
/** Number of times we performed a backtrack */
IntStat backtracksCount;
@@ -202,13 +134,11 @@ public:
: mergesCount(name + "::mergesCount", 0),
termsCount(name + "::termsCount", 0),
functionTermsCount(name + "functionTermsCoutn", 0),
- functionsCount(name + "::functionsCount", 0),
backtracksCount(name + "::backtracksCount", 0)
{
StatisticsRegistry::registerStat(&mergesCount);
StatisticsRegistry::registerStat(&termsCount);
StatisticsRegistry::registerStat(&functionTermsCount);
- StatisticsRegistry::registerStat(&functionsCount);
StatisticsRegistry::registerStat(&backtracksCount);
}
@@ -216,7 +146,6 @@ public:
StatisticsRegistry::unregisterStat(&mergesCount);
StatisticsRegistry::unregisterStat(&termsCount);
StatisticsRegistry::unregisterStat(&functionTermsCount);
- StatisticsRegistry::unregisterStat(&functionsCount);
StatisticsRegistry::unregisterStat(&backtracksCount);
}
};
@@ -238,12 +167,6 @@ private:
/** Number of asserted equalities we have so far */
context::CDO<size_t> d_assertedEqualitiesCount;
- /** Map from ids to functional representations */
- std::vector<FunctionNode> d_functionNodes;
-
- /** Functions in the system */
- std::vector<FunctionInfo> d_functions;
-
/**
* We keep a list of asserted equalities. Not among original terms, but
* among the class representatives.
@@ -261,6 +184,8 @@ private:
/** The ids of the classes we have merged */
std::vector<Equality> d_assertedEqualities;
+ /** The reasons for the equalities */
+
/**
* An edge in the equality graph. This graph is an undirected graph (both edges added)
* containing the actual asserted equalities.
@@ -292,13 +217,18 @@ private:
std::vector<EqualityEdge> d_equalityEdges;
/**
+ * Reasons for equalities.
+ */
+ std::vector<Node> d_equalityReasons;
+
+ /**
* Map from a node to it's first edge in the equality graph. Edges are added to the front of the
* list which makes the insertion/backtracking easy.
*/
std::vector<size_t> d_equalityGraph;
/** Add an edge to the equality graph */
- inline void addGraphEdge(size_t t1, size_t t2);
+ inline void addGraphEdge(size_t t1, size_t t2, Node reason);
/** Returns the equality node of the given node */
inline EqualityNode& getEqualityNode(TNode node);
@@ -386,11 +316,6 @@ public:
size_t addTerm(TNode t);
/**
- * Adds a term that is an application of a function symbol to the databas. Returns the internal id of the term.
- */
- size_t addFunctionApplication(size_t funcionId, const std::vector<TNode>& arguments);
-
- /**
* Check whether the node is already in the database.
*/
inline bool hasTerm(TNode t) const;
@@ -398,7 +323,7 @@ public:
/**
* Adds an equality t1 = t2 to the database. Returns false if any of the triggers failed.
*/
- bool addEquality(TNode t1, TNode t2);
+ bool addEquality(TNode t1, TNode t2, Node reason);
/**
* Returns the representative of the term t.
@@ -424,23 +349,27 @@ public:
size_t addTrigger(TNode t1, TNode t2);
/**
- * Adds a new function to the equality engine. The funcions are not of fixed arity and no typechecking is performed!
- * Associative functions allow for normalization, i.e. f(f(x, y), z) = f(x, f(y, z)) = f(x, y, z).
- * @associative should be true if the function is associative and you want this to be handled by the engine
+ * Normalizes a term by finding the representative. If the representative can be decomposed (using
+ * UnionFindPreferences) it will try and recursively find the representatives, and substitute.
+ * Assumptions used in normalization are retruned in the set.
*/
- inline size_t newFunction(std::string name, bool associative) {
- Assert(use_functions);
- Assert(!associative || enable_associative);
- ++ d_stats.functionsCount;
- size_t id = d_functions.size();
- d_functions.push_back(FunctionInfo(name, associative));
- return id;
- }
+ Node normalize(TNode node, std::set<TNode>& assumptions);
+
+private:
+
+ /** Hash of normalizations to avioid cycles */
+ typedef __gnu_cxx::hash_map<TNode, Node, TNodeHashFunction> normalization_cache;
+ normalization_cache d_normalizationCache;
+
+ /**
+ * Same as above, but does cahcing to avoid loops.
+ */
+ Node normalizeWithCache(TNode node, std::set<TNode>& assumptions);
};
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::addTerm(TNode t) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+size_t EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::addTerm(TNode t) {
Debug("equality") << "EqualityEngine::addTerm(" << t << ")" << std::endl;
@@ -462,69 +391,35 @@ size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative
if (d_equalityNodes.size() <= newId) {
d_equalityNodes.resize(newId + 100);
}
- d_equalityNodes[newId].init(newId, false);
- // Return the id of the term
- return newId;
-}
-
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::addFunctionApplication(size_t functionId, const std::vector<TNode>& arguments) {
-
- Debug("equality") << "EqualityEngine::addFunctionApplication(" << d_functions[functionId].name << ":" << arguments.size() << ")" << std::endl;
-
- ++ d_stats.functionTermsCount;
- ++ d_stats.termsCount;
-
- // Register the new id of the term
- size_t newId = d_nodes.size();
- // Add the node to it's position
- d_nodes.push_back(Node());
- // Add the trigger list for this node
- d_nodeTriggers.push_back(BitSizeTraits::trigger_id_null);
- // Add it to the equality graph
- d_equalityGraph.push_back(BitSizeTraits::id_null);
- // Add the equality node to the nodes
- if (d_equalityNodes.size() <= newId) {
- d_equalityNodes.resize(newId + 100);
- }
- d_equalityNodes[newId].init(newId, true);
- // Add the function application to the function nodes
- if (d_functionNodes.size() <= newId) {
- d_functionNodes.resize(newId + 100);
- }
- // Initialize the function node
- size_t argumentsIndex;
- d_functionNodes[newId].init(functionId, arguments.size(), argumentsIndex, d_functions[functionId].isAssociative);
-
+ d_equalityNodes[newId].init(newId);
// Return the id of the term
return newId;
-
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::hasTerm(TNode t) const {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+bool EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::hasTerm(TNode t) const {
return d_nodeIds.find(t) != d_nodeIds.end();
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::getNodeId(TNode node) const {
- Assert(hasTerm(node));
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+size_t EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::getNodeId(TNode node) const {
+ Assert(hasTerm(node), node.toString().c_str());
return (*d_nodeIds.find(node)).second;
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-EqualityNode& EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::getEqualityNode(TNode t) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+EqualityNode& EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::getEqualityNode(TNode t) {
return getEqualityNode(getNodeId(t));
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-EqualityNode& EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::getEqualityNode(size_t nodeId) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+EqualityNode& EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::getEqualityNode(size_t nodeId) {
Assert(nodeId < d_equalityNodes.size());
return d_equalityNodes[nodeId];
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::addEquality(TNode t1, TNode t2) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+bool EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::addEquality(TNode t1, TNode t2, Node reason) {
Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << ")" << std::endl;
@@ -549,18 +444,20 @@ bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
Assert(node1.getFind() == t1classId);
Assert(node2.getFind() == t2classId);
- // Depending on the size, merge them
+ // Depending on the merge preference (such as size), merge them
std::vector<size_t> triggers;
- if (node1.getSize() < node2.getSize()) {
+ if (UnionFindPreferences::mergePreference(d_nodes[t2classId], node2.getSize(), d_nodes[t1classId], node1.getSize())) {
+ Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << "): merging " << t1 << " into " << t2 << std::endl;
merge(node2, node1, triggers);
d_assertedEqualities.push_back(Equality(t2classId, t1classId));
} else {
+ Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << "): merging " << t2 << " into " << t1 << std::endl;
merge(node1, node2, triggers);
d_assertedEqualities.push_back(Equality(t1classId, t2classId));
}
// Add the actuall equality to the equality graph
- addGraphEdge(t1Id, t2Id);
+ addGraphEdge(t1Id, t2Id, reason);
// One more equality added
d_assertedEqualitiesCount = d_assertedEqualitiesCount + 1;
@@ -577,8 +474,8 @@ bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
return true;
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-TNode EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::getRepresentative(TNode t) const {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+TNode EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::getRepresentative(TNode t) const {
Debug("equality") << "EqualityEngine::getRepresentative(" << t << ")" << std::endl;
@@ -593,8 +490,8 @@ TNode EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>
return d_nodes[representativeId];
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::areEqual(TNode t1, TNode t2) const {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+bool EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::areEqual(TNode t1, TNode t2) const {
Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl;
Assert(hasTerm(t1));
@@ -610,8 +507,8 @@ bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
return rep1 == rep2;
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::merge(EqualityNode& class1, EqualityNode& class2, std::vector<size_t>& triggers) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+void EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::merge(EqualityNode& class1, EqualityNode& class2, std::vector<size_t>& triggers) {
Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl;
@@ -660,8 +557,8 @@ void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
class1.merge<true>(class2);
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::undoMerge(EqualityNode& class1, EqualityNode& class2, size_t class2Id) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+void EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::undoMerge(EqualityNode& class1, EqualityNode& class2, size_t class2Id) {
Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << ")" << std::endl;
@@ -692,8 +589,8 @@ void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::backtrack() {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+void EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::backtrack() {
// If we need to backtrack then do it
if (d_assertedEqualitiesCount < d_assertedEqualities.size()) {
@@ -721,22 +618,24 @@ void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
}
d_equalityEdges.resize(2 * d_assertedEqualitiesCount);
+ d_equalityReasons.resize(d_assertedEqualitiesCount);
}
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::addGraphEdge(size_t t1, size_t t2) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+void EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::addGraphEdge(size_t t1, size_t t2, Node reason) {
Debug("equality") << "EqualityEngine::addGraphEdge(" << d_nodes[t1] << "," << d_nodes[t2] << ")" << std::endl;
size_t edge = d_equalityEdges.size();
d_equalityEdges.push_back(EqualityEdge(t2, d_equalityGraph[t1]));
d_equalityEdges.push_back(EqualityEdge(t1, d_equalityGraph[t2]));
d_equalityGraph[t1] = edge;
d_equalityGraph[t2] = edge | 1;
+ d_equalityReasons.push_back(reason);
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::getExplanation(TNode t1, TNode t2, std::vector<TNode>& equalities) const {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+void EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::getExplanation(TNode t1, TNode t2, std::vector<TNode>& equalities) const {
Assert(equalities.empty());
Assert(t1 != t2);
Assert(getRepresentative(t1) == getRepresentative(t2));
@@ -784,15 +683,9 @@ void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
// Reconstruct the path
do {
- // Get the left and right hand side from the edge
- size_t firstEdge = (currentEdge >> 1) << 1;
- size_t secondEdge = (currentEdge | 1);
- TNode lhs = d_nodes[d_equalityEdges[secondEdge].getNodeId()];
- TNode rhs = d_nodes[d_equalityEdges[firstEdge].getNodeId()];
// Add the actual equality to the vector
- equalities.push_back(lhs.eqNode(rhs));
-
- Debug("equality") << "EqualityEngine::getExplanation(): adding: " << lhs.eqNode(rhs) << std::endl;
+ equalities.push_back(d_equalityReasons[currentEdge >> 1]);
+ Debug("equality") << "EqualityEngine::getExplanation(): adding: " << d_equalityReasons[currentEdge >> 1] << std::endl;
// Go to the previous
currentEdge = bfsQueue[currentIndex].edgeId;
@@ -816,8 +709,8 @@ void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
}
}
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::addTrigger(TNode t1, TNode t2) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+size_t EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::addTrigger(TNode t1, TNode t2) {
Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ")" << std::endl;
@@ -850,6 +743,64 @@ size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative
return t1NewTriggerId / 2;
}
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+Node EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::normalize(TNode node, std::set<TNode>& assumptions) {
+ d_normalizationCache.clear();
+ Node result = Rewriter::rewrite(normalizeWithCache(node, assumptions));
+ d_normalizationCache.clear();
+ return result;
+}
+
+
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+Node EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::normalizeWithCache(TNode node, std::set<TNode>& assumptions) {
+
+ Debug("equality") << "EqualityEngine::normalize(" << node << ")" << push << std::endl;
+
+ normalization_cache::iterator find = d_normalizationCache.find(node);
+ if (find != d_normalizationCache.end()) {
+ if (find->second.isNull()) {
+ // We are in a cycle
+ return node;
+ } else {
+ // Not in a cycle, return it
+ return find->second;
+ }
+ } else {
+ d_normalizationCache[node] = Node();
+ }
+
+ // Get the representative
+ Node result = hasTerm(node) ? getRepresentative(node) : node;
+ if (node != result) {
+ std::vector<TNode> equalities;
+ getExplanation(result, node, equalities);
+ assumptions.insert(equalities.begin(), equalities.end());
+ }
+
+ // If asked, substitute the children with their representatives
+ if (UnionFindPreferences::descend(result)) {
+ // Make the builder for substitution
+ NodeBuilder<> builder;
+ builder << result.getKind();
+ kind::MetaKind metaKind = result.getMetaKind();
+ if (metaKind == kind::metakind::PARAMETERIZED) {
+ builder << result.getOperator();
+ }
+ for (unsigned i = 0; i < result.getNumChildren(); ++ i) {
+ builder << normalizeWithCache(result[i], assumptions);
+ }
+ result = builder;
+ }
+
+ Debug("equality") << "EqualityEngine::normalize(" << node << ") => " << result << pop << std::endl;
+
+ // Cache the result for real now
+ d_normalizationCache[node] = result;
+
+ return result;
+}
+
} // Namespace bv
} // Namespace theory
} // Namespace CVC4
diff --git a/src/theory/bv/slice_manager.h b/src/theory/bv/slice_manager.h
index 8fc1e0b9d..96a0067dc 100644
--- a/src/theory/bv/slice_manager.h
+++ b/src/theory/bv/slice_manager.h
@@ -13,6 +13,7 @@
#include "theory/bv/cd_set_collection.h"
#include <map>
+#include <set>
#include <vector>
namespace CVC4 {
@@ -134,9 +135,6 @@ private:
/** The equality engine */
EqualityEngine& d_equalityEngine;
- /** The id of the concatenation function */
- size_t d_concatFunctionId;
-
/** The collection of backtrackable sets */
set_collection d_setCollection;
@@ -153,27 +151,35 @@ public:
SliceManager(TheoryBitvector& theoryBitvector, context::Context* context)
: d_theoryBitvector(theoryBitvector), d_equalityEngine(theoryBitvector.getEqualityEngine()), d_setCollection(context) {
- // We register the concatentation with the equality engine
- d_concatFunctionId = d_equalityEngine.newFunction("bv_concat", true);
}
- inline size_t getConcatFunctionId() const { return d_concatFunctionId; }
-
/**
- * Adds the equality (lhs = rhs) to the slice manager. This will not add the equalities to the equality manager,
- * but will slice the equality according to the current slicing in order to align all the slices. The terms that
- * get slices get sent to the theory engine as equalities, i.e if we slice x[10:0] into x[10:5]@x[4:0] equality
- * engine gets the assertion x[10:0] = concat(x[10:5], x[4:0]).
+ * Adds the equality (lhs = rhs) to the slice manager. The equality is first normalized according to the equality
+ * manager, i.e. each base term is taken from the equality manager, replaced in, and then the whole concatenation
+ * normalized and sliced wrt the current slicing. The method will not add the equalities to the equality manager,
+ * but instead will slice the equality according to the current slicing in order to align all the slices.
+ *
+ * The terms that get sliced get sent to the theory engine as equalities, i.e if we slice x[10:0] into x[10:5]@x[4:0]
+ * equality engine gets the assertion x[10:0] = concat(x[10:5], x[4:0]).
+ *
+ * input output slicing
+ * --------------------------------------------------------------------------------------------------------------
+ * x@y = y@x x = y, y = x empty
+ * x[31:0]@x[64:32] = x x = x[31:0]@x[63:32] x:{64,32,0}
+ * x@y = 0000@x@0000 x = 0000@x[7:4], y = x[3:0]@0000 x:{8,4,0}
+ *
*/
- inline void addEquality(TNode lhs, TNode rhs, std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices);
+ inline bool solveEquality(TNode lhs, TNode rhs);
private:
+ inline bool solveEquality(TNode lhs, TNode rhs, const std::set<TNode>& assumptions);
+
/**
- * Slices up lhs and rhs and returns the slices in lhsSlices and rhsSlices
+ * Slices up lhs and rhs and returns the slices in lhsSlices and rhsSlices. The slices are not atomic,
+ * they are sliced in order to make one of lhs or rhs atomic, the other one can be a concatenation.
*/
- inline void slice(std::vector<Node>& lhs, std::vector<Node>& rhs,
- std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices);
+ inline bool sliceAndSolve(std::vector<Node>& lhs, std::vector<Node>& rhs, const std::set<TNode>& assumptions);
/**
* Returns true if the term is already sliced wrt the current slicing. Note that, for example, even though
@@ -184,7 +190,7 @@ private:
/**
* Slices the term wrt the current slicing. When done, isSliced returns true
*/
- inline void slice(TNode node, std::vector<Node>& sliced);
+ inline bool slice(TNode node, std::vector<Node>& sliced);
/**
* Returns the base term in the core theory of the given term, i.e.
@@ -196,20 +202,87 @@ private:
static inline TNode baseTerm(TNode node);
/**
- * Adds a new slice to the slice set of the given base term.
+ * Adds a new slice to the slice set of the given term.
*/
- inline void addSlice(Node baseTerm, unsigned slicePoint);
+ inline bool addSlice(Node term, unsigned slicePoint);
};
template <class TheoryBitvector>
-void SliceManager<TheoryBitvector>::addSlice(Node baseTerm, unsigned slicePoint) {
+bool SliceManager<TheoryBitvector>::addSlice(Node node, unsigned slicePoint) {
+ Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << ")" << std::endl;
+
+ bool ok = true;
+
+ int low = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractLow(node) : 0;
+ int high = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractHigh(node) + 1: utils::getSize(node);
+ slicePoint += low;
+
+ TNode nodeBase = baseTerm(node);
+
+ set_reference sliceSet;
+ slicing_map::iterator find = d_nodeSlicing.find(nodeBase);
+ if (find == d_nodeSlicing.end()) {
+ sliceSet = d_nodeSlicing[nodeBase] = d_setCollection.newSet(slicePoint);
+ d_setCollection.insert(sliceSet, low);
+ d_setCollection.insert(sliceSet, high);
+ } else {
+ sliceSet = find->second;
+ }
+
+ // What are the points surrounding the new slice point
+ int prev = d_setCollection.prev(sliceSet, slicePoint);
+ int next = d_setCollection.next(sliceSet, slicePoint);
+
+ // Add the slice to the set
+ d_setCollection.insert(sliceSet, slicePoint);
+ Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << "): current set " << d_setCollection.toString(sliceSet) << std::endl;
+
+ // Add the terms and the equality to the equality engine
+ Node t1 = utils::mkExtract(nodeBase, next - 1, slicePoint);
+ Node t2 = utils::mkExtract(nodeBase, slicePoint - 1, prev);
+ Node nodeSlice = (next == high && prev == low) ? node : utils::mkExtract(nodeBase, next - 1, prev);
+ Node concat = utils::mkConcat(t1, t2);
+
+ d_equalityEngine.addTerm(t1);
+ d_equalityEngine.addTerm(t2);
+ d_equalityEngine.addTerm(nodeSlice);
+ d_equalityEngine.addTerm(concat);
+
+ // We are free to add this slice, unless the slice has a representative that's already a concat
+ TNode nodeSliceRepresentative = d_equalityEngine.getRepresentative(nodeSlice);
+ if (nodeSliceRepresentative.getKind() != kind::BITVECTOR_CONCAT) {
+ // Add the slice to the equality engine
+ ok = d_equalityEngine.addEquality(nodeSlice, concat, utils::mkTrue());
+ } else {
+ // If the representative is a concat, we must solve it
+ // There is no need do add nodeSlice = concat as we will solve the representative of nodeSlice
+ std::set<TNode> assumptions;
+ std::vector<TNode> equalities;
+ d_equalityEngine.getExplanation(nodeSlice, nodeSliceRepresentative, equalities);
+ assumptions.insert(equalities.begin(), equalities.end());
+ ok = solveEquality(nodeSliceRepresentative, concat, assumptions);
+ }
+
+ Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << ") => " << d_setCollection.toString(d_nodeSlicing[nodeBase]) << std::endl;
+
+ return ok;
}
template <class TheoryBitvector>
-void SliceManager<TheoryBitvector>::addEquality(TNode lhs, TNode rhs, std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices) {
+bool SliceManager<TheoryBitvector>::solveEquality(TNode lhs, TNode rhs) {
+ std::set<TNode> assumptions;
+ assumptions.insert(lhs.eqNode(rhs));
+ bool ok = solveEquality(lhs, rhs, assumptions);
+ return ok;
+}
+
+template <class TheoryBitvector>
+bool SliceManager<TheoryBitvector>::solveEquality(TNode lhs, TNode rhs, const std::set<TNode>& assumptions) {
+
+ Debug("slicing") << "SliceMagager::solveEquality(" << lhs << "," << rhs << "," << utils::setToString(assumptions) << ")" << push << std::endl;
- Debug("slicing") << "SliceMagager::addEquality(" << lhs << "," << rhs << ")" << std::endl;
+ bool ok;
// The concatenations on the left-hand side (reverse order, first is on top)
std::vector<Node> lhsTerms;
@@ -232,60 +305,213 @@ void SliceManager<TheoryBitvector>::addEquality(TNode lhs, TNode rhs, std::vecto
}
// Slice the individual terms to align them
- slice(lhsTerms, rhsTerms, lhsSlices, rhsSlices);
+ ok = sliceAndSolve(lhsTerms, rhsTerms, assumptions);
+
+ Debug("slicing") << "SliceMagager::solveEquality(" << lhs << "," << rhs << "," << utils::setToString(assumptions) << ")" << pop << std::endl;
+
+ return ok;
}
+
template <class TheoryBitvector>
-void SliceManager<TheoryBitvector>::slice(std::vector<Node>& lhs, std::vector<Node>& rhs,
- std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices) {
+bool SliceManager<TheoryBitvector>::sliceAndSolve(std::vector<Node>& lhs, std::vector<Node>& rhs, const std::set<TNode>& assumptions)
+{
- Debug("slicing") << "SliceManager::slice()" << std::endl;
+ Debug("slicing") << "SliceManager::sliceAndSolve()" << std::endl;
- // Go through the work-list and align
+ // Go through the work-list, solve and align
while (!lhs.empty()) {
Assert(!rhs.empty());
+ Debug("slicing") << "SliceManager::sliceAndSolve(): lhs " << utils::vectorToString(lhs) << std::endl;
+ Debug("slicing") << "SliceManager::sliceAndSolve(): rhs " << utils::vectorToString(rhs) << std::endl;
+
// The terms that we need to slice
Node lhsTerm = lhs.back();
Node rhsTerm = rhs.back();
- Debug("slicing") << "slicing: " << lhsTerm << " and " << rhsTerm << std::endl;
+
+ Debug("slicing") << "SliceManager::sliceAndSolve(): " << lhsTerm << " : " << rhsTerm << std::endl;
// If the terms are not sliced wrt the current slicing, we have them sliced
lhs.pop_back();
if (!isSliced(lhsTerm)) {
- slice(lhsTerm, lhs);
+ if (!slice(lhsTerm, lhs)) return false;
+ Debug("slicing") << "SliceManager::sliceAndSolve(): lhs sliced" << std::endl;
continue;
}
rhs.pop_back();
if (!isSliced(rhsTerm)) {
- slice(rhsTerm, rhs);
+ if (!slice(rhsTerm, rhs)) return false;
+ // We also need to put lhs back
+ lhs.push_back(lhsTerm);
+ Debug("slicing") << "SliceManager::sliceAndSolve(): rhs sliced" << std::endl;
+ continue;
}
+ Debug("slicing") << "SliceManager::sliceAndSolve(): both lhs and rhs sliced already" << std::endl;
+
+ // The solving concatenation
+ std::vector<Node> concatTerms;
+
// If the slices are of the same size we do the additional work
- unsigned lhsSize = utils::getSize(lhsTerm);
- unsigned rhsSize = utils::getSize(rhsTerm);
- if (lhsSize == rhsSize) {
- // If they are over the same base terms, we need to do something
- TNode lhsBaseTerm = baseTerm(lhsTerm);
- TNode rhsBaseTerm = baseTerm(rhsTerm);
- if (lhsBaseTerm == rhsBaseTerm) {
- // x[i_1:j_1] vs x[i_2:j_2]
+ int sizeDifference = utils::getSize(lhsTerm) - utils::getSize(rhsTerm);
+
+ // We slice constants immediately
+ if (sizeDifference > 0 && lhsTerm.getKind() == kind::CONST_BITVECTOR) {
+ BitVector low = lhsTerm.getConst<BitVector>().extract(utils::getSize(rhsTerm) - 1, 0);
+ BitVector high = lhsTerm.getConst<BitVector>().extract(utils::getSize(lhsTerm) - 1, utils::getSize(rhsTerm));
+ lhs.push_back(utils::mkConst(low));
+ lhs.push_back(utils::mkConst(high));
+ rhs.push_back(rhsTerm);
+ continue;
+ }
+ if (sizeDifference < 0 && rhsTerm.getKind() == kind::CONST_BITVECTOR) {
+ BitVector low = rhsTerm.getConst<BitVector>().extract(utils::getSize(lhsTerm) - 1, 0);
+ BitVector high = rhsTerm.getConst<BitVector>().extract(utils::getSize(rhsTerm) - 1, utils::getSize(lhsTerm));
+ rhs.push_back(utils::mkConst(low));
+ rhs.push_back(utils::mkConst(high));
+ lhs.push_back(lhsTerm);
+ continue;
+ }
+
+ enum SolvingFor {
+ SOLVING_FOR_LHS,
+ SOLVING_FOR_RHS
+ } solvingFor = sizeDifference < 0 || lhsTerm.getKind() == kind::CONST_BITVECTOR ? SOLVING_FOR_RHS : SOLVING_FOR_LHS;
+
+ Debug("slicing") << "SliceManager::sliceAndSolve(): " << (solvingFor == SOLVING_FOR_LHS ? "solving for LHS" : "solving for RHS") << std::endl;
+
+ // When we slice in order to align, we might have to reslice the one we are solving for
+ bool reslice = false;
+
+ switch (solvingFor) {
+ case SOLVING_FOR_RHS: {
+ concatTerms.push_back(lhsTerm);
+ // Maybe we need to add more lhs to make them equal
+ while (sizeDifference < 0 && !reslice) {
+ Assert(lhs.size() > 0);
+ // Get the next part for lhs
+ lhsTerm = lhs.back();
+ lhs.pop_back();
+ // Slice if necessary
+ if (!isSliced(lhsTerm)) {
+ if (!slice(lhsTerm, lhs)) return false;
+ continue;
+ }
+ // If we go above 0, we need to cut it
+ if (sizeDifference + (int)utils::getSize(lhsTerm) > 0) {
+ // Slice it so it fits
+ addSlice(lhsTerm, (int)utils::getSize(lhsTerm) + sizeDifference);
+ if (!slice(lhsTerm, lhs)) return false;
+ if (!isSliced(rhsTerm)) {
+ if (!slice(rhsTerm, rhs)) return false;
+ while(!concatTerms.empty()) {
+ lhs.push_back(concatTerms.back());
+ concatTerms.pop_back();
+ }
+ reslice = true;
+ }
+ continue;
+ }
+ concatTerms.push_back(lhsTerm);
+ sizeDifference += utils::getSize(lhsTerm);
+ }
+ break;
+ }
+ case SOLVING_FOR_LHS: {
+ concatTerms.push_back(rhsTerm);
+ // Maybe we need to add more rhs to make them equal
+ while (sizeDifference > 0 && !reslice) {
+ Assert(rhs.size() > 0);
+ // Get the next part for lhs
+ rhsTerm = rhs.back();
+ rhs.pop_back();
+ // Slice if necessary
+ if (!isSliced(rhsTerm)) {
+ if (!slice(rhsTerm, rhs)) return false;
+ continue;
+ }
+ // If we go below 0, we need to cut it
+ if (sizeDifference - (int)utils::getSize(rhsTerm) < 0) {
+ // Slice it so it fits
+ addSlice(rhsTerm, (int)utils::getSize(rhsTerm) - sizeDifference);
+ if (!slice(rhsTerm, rhs)) return false;
+ if (!isSliced(lhsTerm)) {
+ if (!slice(lhsTerm, lhs)) return false;
+ while(!concatTerms.empty()) {
+ rhs.push_back(concatTerms.back());
+ concatTerms.pop_back();
+ }
+ reslice = true;
+ }
+ continue;
+ }
+ concatTerms.push_back(rhsTerm);
+ sizeDifference -= utils::getSize(rhsTerm);
+ }
+ break;
+ }
+ }
+
+ // If we need to reslice
+ if (reslice) {
+ continue;
+ }
+
+ Assert(sizeDifference == 0);
+
+ Node concat = utils::mkConcat(concatTerms);
+ Debug("slicing") << "SliceManager::sliceAndSolve(): concatenation " << concat << std::endl;
+
+ // We have them equal size now. If the base term of the one we are solving is solved into a
+ // non-trivial concatenation already, we have to normalize. A concatenation is non-trivial if
+ // it is not a direct slicing, i.e it is a concat, and normalize(x) != x
+ switch (solvingFor) {
+ case SOLVING_FOR_LHS: {
+ TNode lhsTermRepresentative = d_equalityEngine.getRepresentative(lhsTerm);
+ if (lhsTermRepresentative != lhsTerm &&
+ (lhsTermRepresentative.getKind() == kind::BITVECTOR_CONCAT || lhsTermRepresentative.getKind() == kind::CONST_BITVECTOR)) {
+ // We need to normalize and solve the normalized equations
+ std::vector<TNode> explanation;
+ d_equalityEngine.getExplanation(lhsTerm, lhsTermRepresentative, explanation);
+ std::set<TNode> additionalAssumptions(assumptions);
+ additionalAssumptions.insert(explanation.begin(), explanation.end());
+ bool ok = solveEquality(lhsTermRepresentative, concat, additionalAssumptions);
+ if (!ok) return false;
} else {
- // x[i_1:j_1] vs y[i_2:j_2]
+ // We're fine, just add the equality
+ Debug("slicing") << "SliceManager::sliceAndSolve(): adding " << lhsTerm << " = " << concat << " " << utils::setToString(assumptions) << std::endl;
+ d_equalityEngine.addTerm(concat);
+ bool ok = d_equalityEngine.addEquality(lhsTerm, concat, utils::mkConjunction(assumptions));
+ if (!ok) return false;
}
- lhsSlices.push_back(lhsTerm);
- rhsSlices.push_back(rhsTerm);
- continue;
- } else {
- // They are not of equal sizes, so we slice one
- if (lhsSize < rhsSize) {
- // We need to cut a piece of rhs
+ break;
+ }
+ case SOLVING_FOR_RHS: {
+ TNode rhsTermRepresentative = d_equalityEngine.getRepresentative(rhsTerm);
+ if (rhsTermRepresentative != rhsTerm &&
+ (rhsTermRepresentative.getKind() == kind::BITVECTOR_CONCAT || rhsTermRepresentative.getKind() == kind::CONST_BITVECTOR)) {
+ // We need to normalize and solve the normalized equations
+ std::vector<TNode> explanation;
+ d_equalityEngine.getExplanation(rhsTerm, rhsTermRepresentative, explanation);
+ std::set<TNode> additionalAssumptions(assumptions);
+ additionalAssumptions.insert(explanation.begin(), explanation.end());
+ bool ok = solveEquality(rhsTermRepresentative, concat, additionalAssumptions);
+ if (!ok) return false;
} else {
- // We need to cut a piece of lhs
+ // We're fine, just add the equality
+ Debug("slicing") << "SliceManager::sliceAndSolve(): adding " << rhsTerm << " = " << concat << utils::setToString(assumptions) << std::endl;
+ d_equalityEngine.addTerm(concat);
+ bool ok = d_equalityEngine.addEquality(rhsTerm, concat, utils::mkConjunction(assumptions));
+ if (!ok) return false;
}
+ break;
+ }
}
}
+
+ return true;
}
template <class TheoryBitvector>
@@ -315,9 +541,11 @@ bool SliceManager<TheoryBitvector>::isSliced(TNode node) const {
if (find == d_nodeSlicing.end()) {
result = nodeKind != kind::BITVECTOR_EXTRACT;
} else {
- // Check whether there is a slice point in [high, low), if there is the term is not sliced.
- // Hence, if we look for the upper bound of low, and it is higher than high, it is sliced.
- result = d_setCollection.count(find->second, low + 1, high) > 0;
+ // The term is not sliced if one of the borders is not in the slice set or
+ // there is a point between the borders
+ result =
+ d_setCollection.contains(find->second, low) && d_setCollection.contains(find->second, high + 1) &&
+ (low == high || d_setCollection.count(find->second, low + 1, high) == 0);
}
}
@@ -326,7 +554,7 @@ bool SliceManager<TheoryBitvector>::isSliced(TNode node) const {
}
template <class TheoryBitvector>
-inline void SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>& sliced) {
+inline bool SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>& sliced) {
Debug("slicing") << "SliceManager::slice(" << node << ")" << std::endl;
@@ -335,44 +563,64 @@ inline void SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>&
// The indices of the beginning and (one past) end
unsigned high = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractHigh(node) + 1 : utils::getSize(node);
unsigned low = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractLow(node) : 0;
+ Debug("slicing") << "SliceManager::slice(" << node << "): low: " << low << std::endl;
+ Debug("slicing") << "SliceManager::slice(" << node << "): high: " << high << std::endl;
// Get the base term
TNode nodeBase = baseTerm(node);
Assert(nodeBase.getKind() != kind::BITVECTOR_CONCAT);
Assert(nodeBase.getKind() != kind::CONST_BITVECTOR);
- // Get the base term slice set
- set_collection::reference_type nodeSliceSet = d_nodeSlicing[nodeBase];
+ // The nodes slice set
+ set_collection::reference_type nodeSliceSet;
+
+ // Find the current one or construct it
+ slicing_map::iterator findSliceSet = d_nodeSlicing.find(nodeBase);
+ if (findSliceSet == d_nodeSlicing.end()) {
+ nodeSliceSet = d_setCollection.newSet(utils::getSize(nodeBase));
+ d_setCollection.insert(nodeSliceSet, 0);
+ d_nodeSlicing[nodeBase] = nodeSliceSet;
+ } else {
+ nodeSliceSet = d_nodeSlicing[nodeBase];
+ }
+
Debug("slicing") << "SliceManager::slice(" << node << "): current: " << d_setCollection.toString(nodeSliceSet) << std::endl;
std::vector<size_t> slicePoints;
- d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints);
-
+ if (low + 1 < high) {
+ d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints);
+ }
+
// Go through all the points i_0 <= low < i_1 < ... < i_{n-1} < high <= i_n from the slice set
// and generate the slices [i_0:low-1][low:i_1-1] [i_1:i2] ... [i_{n-1}:high-1][high:i_n-1]. They are in reverse order,
// as they should be
size_t i_0 = low == 0 ? 0 : d_setCollection.prev(nodeSliceSet, low + 1);
- size_t i_n = high == utils::getSize(nodeBase) ? high: d_setCollection.next(nodeSliceSet, high);
+ Debug("slicing") << "SliceManager::slice(" << node << "): i_0: " << i_0 << std::endl;
+ size_t i_n = high == utils::getSize(nodeBase) ? high: d_setCollection.next(nodeSliceSet, high - 1);
+ Debug("slicing") << "SliceManager::slice(" << node << "): i_n: " << i_n << std::endl;
// Add the new points to the slice set (they might be there already)
if (high < i_n) {
- std::vector<Node> lastTwoSlices;
- lastTwoSlices.push_back(utils::mkExtract(nodeBase, i_n-1, high));
- lastTwoSlices.push_back(utils::mkExtract(nodeBase, high-1, slicePoints.back()));
- d_equalityEngine.addEquality(utils::mkExtract(nodeBase, i_n-1, slicePoints.back()), utils::mkConcat(lastTwoSlices));
+ if (!addSlice(nodeBase, high)) return false;
}
-
- while (!slicePoints.empty()) {
+ // Construct the actuall slicing
+ if (slicePoints.size() > 0) {
+ Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, slicePoints[0] - 1, low) << std::endl;
+ sliced.push_back(utils::mkExtract(nodeBase, slicePoints[0] - 1, low));
+ for (unsigned i = 1; i < slicePoints.size(); ++ i) {
+ Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, slicePoints[i] - 1, slicePoints[i-1])<< std::endl;
+ sliced.push_back(utils::mkExtract(nodeBase, slicePoints[i] - 1, slicePoints[i-1]));
+ }
+ Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, high-1, slicePoints.back()) << std::endl;
sliced.push_back(utils::mkExtract(nodeBase, high-1, slicePoints.back()));
- high = slicePoints.back();
- slicePoints.pop_back();
+ } else {
+ sliced.push_back(utils::mkExtract(nodeBase, high - 1, low));
}
-
+ // Add the new points to the slice set (they might be there already)
if (i_0 < low) {
- std::vector<Node> firstTwoSlices;
- firstTwoSlices.push_back(utils::mkExtract(nodeBase, high-1, low));
- firstTwoSlices.push_back(utils::mkExtract(nodeBase, low-1, i_0));
- d_equalityEngine.addEquality(utils::mkExtract(nodeBase, high-1, i_0), utils::mkConcat(firstTwoSlices));
+ if (!addSlice(nodeBase, low)) return false;
}
+
+ return true;
}
template <class TheoryBitvector>
diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp
index e183a592c..2d823383c 100644
--- a/src/theory/bv/theory_bv.cpp
+++ b/src/theory/bv/theory_bv.cpp
@@ -35,7 +35,17 @@ void TheoryBV::preRegisterTerm(TNode node) {
if (node.getKind() == kind::EQUAL) {
d_eqEngine.addTerm(node[0]);
+ if (node[0].getKind() == kind::BITVECTOR_CONCAT) {
+ for (unsigned i = 0, i_end = node[0].getNumChildren(); i < i_end; ++ i) {
+ d_eqEngine.addTerm(node[0][i]);
+ }
+ }
d_eqEngine.addTerm(node[1]);
+ if (node[1].getKind() == kind::BITVECTOR_CONCAT) {
+ for (unsigned i = 0, i_end = node[1].getNumChildren(); i < i_end; ++ i) {
+ d_eqEngine.addTerm(node[1][i]);
+ }
+ }
size_t triggerId = d_eqEngine.addTrigger(node[0], node[1]);
Assert(triggerId == d_triggers.size());
d_triggers.push_back(node);
@@ -57,29 +67,27 @@ void TheoryBV::check(Effort e) {
// Do the right stuff
switch (assertion.getKind()) {
case kind::EQUAL: {
-
- // Slice the equality
- std::vector<Node> lhsSlices, rhsSlices;
- d_sliceManager.addEquality(assertion[0], assertion[1], lhsSlices, rhsSlices);
- Assert(lhsSlices.size() == rhsSlices.size());
-
- // Add the equality to the equality engine
- for (int i = 0, i_end = lhsSlices.size(); i != i_end; ++ i) {
- bool ok = d_eqEngine.addEquality(lhsSlices[i], rhsSlices[i]);
- if (!ok) return;
- }
+ // Slice and solve the equality
+ bool ok = d_sliceManager.solveEquality(assertion[0], assertion[1]);
+ if (!ok) return;
break;
}
case kind::NOT: {
// We need to check this as the equality trigger might have been true when we made it
TNode equality = assertion[0];
+ // Assumptions
+ std::set<TNode> assumptions;
+ Node lhsNormalized = d_eqEngine.normalize(equality[0], assumptions);
+ Node rhsNormalized = d_eqEngine.normalize(equality[1], assumptions);
+
+ Debug("bitvector") << "TheoryBV::check(" << e << "): normalizes to " << lhsNormalized << " = " << rhsNormalized << std::endl;
+
// No need to slice the equality, the whole thing *should* be deduced
- if (d_eqEngine.areEqual(equality[0], equality[1])) {
- vector<TNode> assertions;
- d_eqEngine.getExplanation(equality[0], equality[1], assertions);
- assertions.push_back(assertion);
- d_out->conflict(mkAnd(assertions));
+ if (lhsNormalized == rhsNormalized) {
+ Debug("bitvector") << "TheoryBV::check(" << e << "): conflict with " << utils::setToString(assumptions) << std::endl;
+ assumptions.insert(assertion);
+ d_out->conflict(mkConjunction(assumptions));
return;
}
break;
diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h
index ede98004f..fa9762cb7 100644
--- a/src/theory/bv/theory_bv.h
+++ b/src/theory/bv/theory_bv.h
@@ -46,7 +46,34 @@ public:
}
};
- typedef EqualityEngine<TheoryBV, EqualityNotify, true, true> BvEqualityEngine;
+ struct BVEqualitySettings {
+ static inline bool descend(TNode node) {
+ return node.getKind() == kind::BITVECTOR_CONCAT || node.getKind() == kind::BITVECTOR_EXTRACT;
+ }
+
+ /** Returns true if node1 has preference to node2 as a representative, otherwise node2 is used */
+ static inline bool mergePreference(TNode node1, unsigned node1size, TNode node2, unsigned node2size) {
+ if (node1.getKind() == kind::CONST_BITVECTOR) {
+ Assert(node2.getKind() != kind::CONST_BITVECTOR);
+ return true;
+ }
+ if (node2.getKind() == kind::CONST_BITVECTOR) {
+ Assert(node1.getKind() != kind::CONST_BITVECTOR);
+ return false;
+ }
+ if (node1.getKind() == kind::BITVECTOR_CONCAT) {
+ Assert(node2.getKind() != kind::BITVECTOR_CONCAT);
+ return true;
+ }
+ if (node2.getKind() == kind::BITVECTOR_CONCAT) {
+ Assert(node1.getKind() != kind::BITVECTOR_CONCAT);
+ return false;
+ }
+ return node2size < node1size;
+ }
+ };
+
+ typedef EqualityEngine<TheoryBV, EqualityNotify, BVEqualitySettings> BvEqualityEngine;
private:
diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp
index 9b545d25a..9b5c8b0f9 100644
--- a/src/theory/bv/theory_bv_rewriter.cpp
+++ b/src/theory/bv/theory_bv_rewriter.cpp
@@ -51,12 +51,12 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) {
break;
case kind::BITVECTOR_EXTRACT:
result = LinearRewriteStrategy<
+ // Extract over a concatenation is distributed to the appropriate concatenations
+ RewriteRule<ExtractConcat>,
// Extract over a constant gives a constant
RewriteRule<ExtractConstant>,
- // Extract over an extract is simplified to one extract
+ // We could get another extract over extract
RewriteRule<ExtractExtract>,
- // Extract over a concatenation is distributed to the appropriate concatenations
- RewriteRule<ExtractConcat>,
// At this point only Extract-Whole could apply
RewriteRule<ExtractWhole>
>::apply(node);
diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h
index 6e9dbef3e..ad924f8a0 100644
--- a/src/theory/bv/theory_bv_utils.h
+++ b/src/theory/bv/theory_bv_utils.h
@@ -19,7 +19,9 @@
#pragma once
+#include <set>
#include <vector>
+#include <sstream>
#include "expr/node_manager.h"
namespace CVC4 {
@@ -51,6 +53,10 @@ inline Node mkAnd(std::vector<TNode>& children) {
return NodeManager::currentNM()->mkNode(kind::AND, children);
}
+inline Node mkAnd(std::vector<Node>& children) {
+ return NodeManager::currentNM()->mkNode(kind::AND, children);
+}
+
inline Node mkExtract(TNode node, unsigned high, unsigned low) {
Node extractOp = NodeManager::currentNM()->mkConst<BitVectorExtract>(BitVectorExtract(high, low));
std::vector<Node> children;
@@ -65,10 +71,92 @@ inline Node mkConcat(std::vector<Node>& children) {
return children[0];
}
+inline Node mkConcat(TNode t1, TNode t2) {
+ return NodeManager::currentNM()->mkNode(kind::BITVECTOR_CONCAT, t1, t2);
+}
+
+
inline Node mkConst(const BitVector& value) {
return NodeManager::currentNM()->mkConst<BitVector>(value);
}
+inline void getConjuncts(TNode node, std::set<TNode>& conjuncts) {
+ if (node.getKind() != kind::AND) {
+ conjuncts.insert(node);
+ } else {
+ for (unsigned i = 0; i < node.getNumChildren(); ++ i) {
+ getConjuncts(node[i], conjuncts);
+ }
+ }
+}
+
+inline Node mkConjunction(const std::set<TNode> nodes) {
+ std::set<TNode> expandedNodes;
+
+ std::set<TNode>::const_iterator it = nodes.begin();
+ std::set<TNode>::const_iterator it_end = nodes.end();
+ while (it != it_end) {
+ TNode current = *it;
+ if (current != mkTrue()) {
+ Assert(current != mkFalse());
+ if (current.getKind() == kind::AND) {
+ getConjuncts(current, expandedNodes);
+ } else {
+ expandedNodes.insert(current);
+ }
+ }
+ ++ it;
+ }
+
+ Assert(expandedNodes.size() > 0);
+ if (expandedNodes.size() == 1) {
+ return *expandedNodes.begin();
+ }
+
+ NodeBuilder<> conjunction(kind::AND);
+
+ it = expandedNodes.begin();
+ it_end = expandedNodes.end();
+ while (it != it_end) {
+ conjunction << *it;
+ ++ it;
+ }
+
+ return conjunction;
+}
+
+// Turn a set into a string
+inline std::string setToString(const std::set<TNode>& nodeSet) {
+ std::stringstream out;
+ out << "[";
+ std::set<TNode>::const_iterator it = nodeSet.begin();
+ std::set<TNode>::const_iterator it_end = nodeSet.end();
+ bool first = true;
+ while (it != it_end) {
+ if (!first) {
+ out << ",";
+ }
+ first = false;
+ out << *it;
+ ++ it;
+ }
+ out << "]";
+ return out.str();
+}
+
+// Turn a vector into a string
+inline std::string vectorToString(const std::vector<Node>& nodes) {
+ std::stringstream out;
+ out << "[";
+ for (unsigned i = 0; i < nodes.size(); ++ i) {
+ if (i > 0) {
+ out << ",";
+ }
+ out << nodes[i];
+ }
+ out << "]";
+ return out.str();
+}
}
}
diff --git a/src/util/bitvector.h b/src/util/bitvector.h
index d1bfafb00..ca69fb506 100644
--- a/src/util/bitvector.h
+++ b/src/util/bitvector.h
@@ -98,7 +98,7 @@ public:
return BitVector(d_size + other.d_size, (d_value * Integer(2).pow(other.d_size)) + other.d_value);
}
- BitVector extract(unsigned high, unsigned low) {
+ BitVector extract(unsigned high, unsigned low) const {
return BitVector(high - low + 1, (d_value % (Integer(2).pow(high + 1))) / Integer(2).pow(low));
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback