diff options
Diffstat (limited to 'src')
26 files changed, 2523 insertions, 261 deletions
diff --git a/src/expr/type_checker_template.cpp b/src/expr/type_checker_template.cpp index 16f9ba917..4d9cbc60d 100644 --- a/src/expr/type_checker_template.cpp +++ b/src/expr/type_checker_template.cpp @@ -39,6 +39,9 @@ TypeNode TypeChecker::computeType(NodeManager* nodeManager, TNode n, bool check) case kind::BUILTIN: typeNode = nodeManager->builtinOperatorType(); break; + case kind::BITVECTOR_EXTRACT_OP : + typeNode = nodeManager->builtinOperatorType(); + break; ${typerules} diff --git a/src/theory/bv/Makefile.am b/src/theory/bv/Makefile.am index ab6382770..a70521791 100644 --- a/src/theory/bv/Makefile.am +++ b/src/theory/bv/Makefile.am @@ -13,12 +13,18 @@ libbv_la_SOURCES = \ bitblaster.h \ bitblaster.cpp \ bv_subtheory.h \ - bv_subtheory_eq.h \ - bv_subtheory_eq.cpp \ + bv_subtheory_core.h \ + bv_subtheory_core.cpp \ bv_subtheory_bitblast.h \ bv_subtheory_bitblast.cpp \ + bv_subtheory_inequality.h \ + bv_subtheory_inequality.cpp \ + bv_inequality_graph.h \ + bv_inequality_graph.cpp \ bitblast_strategies.h \ bitblast_strategies.cpp \ + slicer.h \ + slicer.cpp \ theory_bv.h \ theory_bv.cpp \ theory_bv_rewrite_rules.h \ diff --git a/src/theory/bv/bitblast_strategies.cpp b/src/theory/bv/bitblast_strategies.cpp index a952b2929..773685997 100644 --- a/src/theory/bv/bitblast_strategies.cpp +++ b/src/theory/bv/bitblast_strategies.cpp @@ -346,7 +346,7 @@ void DefaultVarBB (TNode node, Bits& bits, Bitblaster* bb) { Debug("bitvector-bb") << " with bits " << toString(bits); } - bb->storeVariable(node); + bb->storeVariable(node); } void DefaultConstBB (TNode node, Bits& bits, Bitblaster* bb) { diff --git a/src/theory/bv/bitblaster.cpp b/src/theory/bv/bitblaster.cpp index 4f5325e10..cc589c5c3 100644 --- a/src/theory/bv/bitblaster.cpp +++ b/src/theory/bv/bitblaster.cpp @@ -412,6 +412,23 @@ bool Bitblaster::isSharedTerm(TNode node) { return d_bv->d_sharedTermsSet.find(node) != d_bv->d_sharedTermsSet.end(); } +bool Bitblaster::hasValue(TNode a) { + Assert (d_termCache.find(a) != d_termCache.end()); + Bits bits = d_termCache[a]; + for (int i = bits.size() -1; i >= 0; --i) { + SatValue bit_value; + if (d_cnfStream->hasLiteral(bits[i])) { + SatLiteral bit = d_cnfStream->getLiteral(bits[i]); + bit_value = d_satSolver->value(bit); + if (bit_value == SAT_VALUE_UNKNOWN) + return false; + } else { + return false; + } + } + return true; +} + Node Bitblaster::getVarValue(TNode a) { Assert (d_termCache.find(a) != d_termCache.end()); Bits bits = d_termCache[a]; @@ -436,7 +453,7 @@ void Bitblaster::collectModelInfo(TheoryModel* m) { __gnu_cxx::hash_set<TNode, TNodeHashFunction>::iterator it = d_variables.begin(); for (; it!= d_variables.end(); ++it) { TNode var = *it; - if (Theory::theoryOf(var) == theory::THEORY_BV || isSharedTerm(var)) { + if ((Theory::theoryOf(var) == theory::THEORY_BV || isSharedTerm(var)) && hasValue(var)) { Node const_value = getVarValue(var); Debug("bitvector-model") << "Bitblaster::collectModelInfo (assert (= " << var << " " diff --git a/src/theory/bv/bitblaster.h b/src/theory/bv/bitblaster.h index 21b389508..84a67e884 100644 --- a/src/theory/bv/bitblaster.h +++ b/src/theory/bv/bitblaster.h @@ -124,7 +124,8 @@ class Bitblaster { // division is bitblasted in terms of constraints // so it needs to use private bitblaster interface void bbUdiv(TNode node, Bits& bits); - void bbUrem(TNode node, Bits& bits); + void bbUrem(TNode node, Bits& bits); + bool hasValue(TNode a); public: void cacheTermDef(TNode node, Bits def); // public so we can cache remainder for division void bbTerm(TNode node, Bits& bits); @@ -164,9 +165,9 @@ public: } bool isSharedTerm(TNode node); -private: - +private: + class Statistics { public: IntStat d_numTermClauses, d_numAtomClauses; diff --git a/src/theory/bv/bv_inequality_graph.cpp b/src/theory/bv/bv_inequality_graph.cpp new file mode 100644 index 000000000..e29ce2014 --- /dev/null +++ b/src/theory/bv/bv_inequality_graph.cpp @@ -0,0 +1,425 @@ +/********************* */ +/*! \file bv_inequality_graph.cpp + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009-2012 New York University and The University of Iowa + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief A graph representation of the currently asserted bv inequalities. + ** + ** A graph representation of the currently asserted bv inequalities. + **/ + +#include "theory/bv/bv_inequality_graph.h" +#include "theory/bv/theory_bv_utils.h" + +using namespace std; +using namespace CVC4; +using namespace CVC4::context; +using namespace CVC4::theory; +using namespace CVC4::theory::bv; +using namespace CVC4::theory::bv::utils; + +const TermId CVC4::theory::bv::UndefinedTermId = -1; +const ReasonId CVC4::theory::bv::UndefinedReasonId = -1; +const ReasonId CVC4::theory::bv::AxiomReasonId = -2; + +BitVector InequalityGraph::maxValue(unsigned bitwidth) { + if (d_signed) { + return BitVector(1, 0u).concat(~BitVector(bitwidth - 1, 0u)); + } + return ~BitVector(bitwidth, 0u); +} + +BitVector InequalityGraph::minValue(unsigned bitwidth) { + if (d_signed) { + return ~BitVector(bitwidth, 0u); + } + return BitVector(bitwidth, 0u); +} + +TermId InequalityGraph::getMaxValueId(unsigned bitwidth) { + BitVector bv = maxValue(bitwidth); + Node max = utils::mkConst(bv); + + if (d_termNodeToIdMap.find(max) == d_termNodeToIdMap.end()) { + TermId id = d_termNodes.size(); + d_termNodes.push_back(max); + d_termNodeToIdMap[max] = id; + InequalityNode node(id, bitwidth, true, bv); + d_ineqNodes.push_back(node); + + // although it will never have out edges we need this to keep the size of + // d_termNodes and d_ineqEdges in sync + d_ineqEdges.push_back(Edges()); + return id; + } + return d_termNodeToIdMap[max]; +} + +TermId InequalityGraph::getMinValueId(unsigned bitwidth) { + BitVector bv = minValue(bitwidth); + Node min = utils::mkConst(bv); + + if (d_termNodeToIdMap.find(min) == d_termNodeToIdMap.end()) { + TermId id = d_termNodes.size(); + d_termNodes.push_back(min); + d_termNodeToIdMap[min] = id; + d_ineqEdges.push_back(Edges()); + InequalityNode node = InequalityNode(id, bitwidth, true, bv); + d_ineqNodes.push_back(node); + return id; + } + return d_termNodeToIdMap[min]; +} + +bool InequalityGraph::addInequality(TNode a, TNode b, bool strict, TNode reason) { + Debug("bv-inequality") << "InequlityGraph::addInequality " << a << " " << b << "\n"; + + TermId id_a = registerTerm(a); + TermId id_b = registerTerm(b); + ReasonId id_reason = registerReason(reason); + + Assert (!(isConst(id_a) && isConst(id_b))); + BitVector a_val = getValue(id_a); + BitVector b_val = getValue(id_b); + + unsigned bitwidth = utils::getSize(a); + BitVector diff = strict ? BitVector(bitwidth, 1u) : BitVector(bitwidth, 0u); + if (a_val + diff <= b_val) { + // the inequality is true in the current partial model + // we still add the edge because it may not be true later (cardinality) + addEdge(id_a, id_b, strict, id_reason); + return true; + } + + if (isConst(id_b) && a_val + diff > b_val) { + // we must be in a conflict since a has the minimum value that + // satisifes the constraints + std::vector<ReasonId> conflict; + conflict.push_back(id_reason); + computeExplanation(UndefinedTermId, id_a, conflict); + Debug("bv-inequality") << "InequalityGraph::addInequality conflict: constant UB \n"; + setConflict(conflict); + return false; + } + + // add the inequality edge + addEdge(id_a, id_b, strict, id_reason); + BFSQueue queue; + queue.push(PQueueElement(id_a, getValue(id_a), getValue(id_a), + (hasExplanation(id_a) ? getExplanation(id_a) : Explanation()))); + TermIdSet seen; + return computeValuesBFS(queue, id_a, seen); +} + +bool InequalityGraph::updateValue(const PQueueElement& el, TermId start, const TermIdSet& seen) { + TermId id = el.id; + const BitVector& lower_bound = el.lower_bound; + InequalityNode& ineqNode = getInequalityNode(id); + + if (ineqNode.isConstant()) { + if (ineqNode.getValue() < lower_bound) { + Debug("bv-inequality") << "Conflict: constant " << ineqNode.getValue() << "\n"; + std::vector<ReasonId> conflict; + TermId parent = el.explanation.parent; + ReasonId reason = el.explanation.reason; + conflict.push_back(reason); + computeExplanation(UndefinedTermId, parent, conflict); + Debug("bv-inequality") << "InequalityGraph::addInequality conflict: constant\n"; + setConflict(conflict); + return false; + } + } else { + // if not constant we can update the value + if (ineqNode.getValue() < lower_bound) { + // if we are updating the term we started with we must be in a cycle + if (seen.count(id)) { + TermId parent = el.explanation.parent; + ReasonId reason = el.explanation.reason; + std::vector<TermId> conflict; + conflict.push_back(reason); + computeExplanation(id, parent, conflict); + Debug("bv-inequality") << "InequalityGraph::addInequality conflict: cycle \n"; + setConflict(conflict); + return false; + } + Debug("bv-inequality-internal") << "Updating " << getTermNode(id) + << " from " << ineqNode.getValue() << "\n" + << " to " << lower_bound << "\n"; + ineqNode.setValue(lower_bound); + setExplanation(id, el.explanation); + } + } + return true; +} + +bool InequalityGraph::computeValuesBFS(BFSQueue& queue, TermId start, TermIdSet& seen) { + if (queue.empty()) + return true; + + const PQueueElement current = queue.top(); + queue.pop(); + + if (!updateValue(current, start, seen)) { + return false; + } + if (seen.count(current.id) && current.id != getMaxValueId(getBitwidth(current.id))) { + Debug("bv-inequality-internal") << "InequalityGraph::computeValuesBFS equal cycle."; + // this means we are in a cycle where all the values are forced to be equal + return computeValuesBFS(queue, start, seen); + } + + seen.insert(current.id); + const BitVector& current_value = getValue(current.id); + + unsigned size = getBitwidth(current.id); + const BitVector zero(size, 0u); + const BitVector one(size, 1u); + + const Edges& edges = getEdges(current.id); + for (Edges::const_iterator it = edges.begin(); it!= edges.end(); ++it) { + TermId next = it->next; + const BitVector increment = it->strict ? one : zero; + const BitVector& next_lower_bound = current_value + increment; + const BitVector& value = getValue(next); + queue.push(PQueueElement(next, value, next_lower_bound, Explanation(current.id, it->reason))); + } + return computeValuesBFS(queue, start, seen); +} + +void InequalityGraph::computeExplanation(TermId from, TermId to, std::vector<ReasonId>& explanation) { + while(hasExplanation(to) && from != to) { + const Explanation& exp = getExplanation(to); + Assert (exp.reason != UndefinedReasonId); + explanation.push_back(exp.reason); + + Assert (exp.parent != UndefinedTermId); + to = exp.parent; + } +} + +void InequalityGraph::addEdge(TermId a, TermId b, bool strict, TermId reason) { + Edges& edges = getEdges(a); + edges.push_back(InequalityEdge(b, strict, reason)); +} + +TermId InequalityGraph::registerTerm(TNode term) { + Debug("bv-inequality-internal") << "InequalityGraph::registerTerm " << term << "\n"; + + + if (d_termNodeToIdMap.find(term) != d_termNodeToIdMap.end()) { + return d_termNodeToIdMap[term]; + } + + // store in node mapping + TermId id = d_termNodes.size(); + Debug("bv-inequality-internal") << " with id " << id << "\n"; + + d_termNodes.push_back(term); + d_termNodeToIdMap[term] = id; + + // create InequalityNode + unsigned size = utils::getSize(term); + bool isConst = term.getKind() == kind::CONST_BITVECTOR; + BitVector value = isConst? term.getConst<BitVector>() : minValue(size); + + InequalityNode ineq = InequalityNode(id, size, isConst, value); + Assert (d_ineqNodes.size() == id); + d_ineqNodes.push_back(ineq); + + Assert (d_ineqEdges.size() == id); + d_ineqEdges.push_back(Edges()); + + // add the default edges min <= term <= max + addEdge(getMinValueId(size), id, false, AxiomReasonId); + addEdge(id, getMaxValueId(size), false, AxiomReasonId); + + return id; +} + +ReasonId InequalityGraph::registerReason(TNode reason) { + if (d_reasonToIdMap.find(reason) != d_reasonToIdMap.end()) { + return d_reasonToIdMap[reason]; + } + ReasonId id = d_reasonNodes.size(); + d_reasonNodes.push_back(reason); + d_reasonToIdMap[reason] = id; + return id; +} + +TNode InequalityGraph::getReasonNode(ReasonId id) const { + Assert (d_reasonNodes.size() > id); + return d_reasonNodes[id]; +} + +TNode InequalityGraph::getTermNode(TermId id) const { + Assert (d_termNodes.size() > id); + return d_termNodes[id]; +} + +void InequalityGraph::setConflict(const std::vector<ReasonId>& conflict) { + Assert (!d_inConflict); + d_inConflict = true; + d_conflict.clear(); + for (unsigned i = 0; i < conflict.size(); ++i) { + if (conflict[i] != AxiomReasonId) { + d_conflict.push_back(getReasonNode(conflict[i])); + } + } + if (Debug.isOn("bv-inequality")) { + Debug("bv-inequality") << "InequalityGraph::setConflict \n"; + for (unsigned i = 0; i < d_conflict.size(); ++i) { + Debug("bv-inequality") << " " << d_conflict[i] <<"\n"; + } + } +} + +void InequalityGraph::getConflict(std::vector<TNode>& conflict) { + for (unsigned i = 0; i < d_conflict.size(); ++i) { + conflict.push_back(d_conflict[i]); + } +} + +// bool InequalityGraph::initializeValues(TNode a, TNode b, bool strict, TermId reason_id) { +// TermId id_a = registerTerm(a); +// TermId id_b = registerTerm(b); + +// InequalityNode& ineq_a = getInequalityNode(id_a); +// InequalityNode& ineq_b = getInequalityNode(id_b); + +// unsigned size = utils::getSize(a); +// BitVector one = mkOne(size); +// BitVector zero = mkZero(size); +// BitVector diff = strict? one : zero; + +// // FIXME: dumb case splitting +// if (ineq_a.isConstant() && ineq_b.isConstant()) { +// Assert (a.getConst<BitVector>() + diff <= b.getConst<BitVector>()); +// ineq_a.setValue(a.getConst<BitVector>()); +// ineq_b.setValue(b.getConst<BitVector>()); +// return true; +// } + +// if (ineq_a.isConstant()) { +// ineq_a.setValue(a.getConst<BitVector>()); +// } + +// if (ineq_b.isConstant()) { +// const BitVector& const_val = b.getConst<BitVector>(); +// ineq_b.setValue(const_val); + +// if (hasValue(id_a) && !(ineq_a.getValue() + diff <= const_val)) { +// // must be a conflict because we have as an invariant that a will have the min +// // possible value for a. +// std::vector<ReasonId> conflict; +// conflict.push_back(reason_id); +// // FIXME: this will not compute the most precise conflict +// // could be fixed by giving computeExplanation a bound (i.e. the size of const_val) +// computeExplanation(UndefinedTermId, id_a, conflict); +// setConflict(conflict); +// return false; +// } +// } + +// if (!hasValue(id_a) && !hasValue(id_b)) { +// // initialize to the minimum possible values +// if (strict) { +// ineq_a.setValue(MinValue(size)); +// ineq_b.setValue(MinValue(size) + one); +// } else { +// ineq_a.setValue(MinValue(size)); +// ineq_b.setValue(MinValue(size)); +// } +// } +// else if (!hasValue(id_a) && hasValue(id_b)) { +// const BitVector& b_value = ineq_b.getValue(); +// if (strict && b_value == MinValue(size) && ineq_b.isConstant()) { +// Debug("bv-inequality") << "Conflict: underflow " << getTerm(id_a) <<"\n"; +// std::vector<ReasonId> conflict; +// conflict.push_back(reason_id); +// setConflict(conflict); +// return false; +// } +// // otherwise we attempt to increment b +// ineq_b.setValue(one); +// } +// // if a has no value then we can assign it to whatever we want +// // to maintain the invariant that each value has the lowest value +// // we assign it to zero +// ineq_a.setValue(zero); +// } else if (hasValue(id_a) && !hasValue(id_b)) { +// const BitVector& a_value = ineq_a.getValue(); +// if (a_value + one < a_value) { +// return false; +// } +// ineq_b.setValue(a_value + one); +// } +// return true; +// } + +// bool InequalityGraph::canReach(TermId from, TermId to) { +// if (from == to ) +// return true; + +// TermIdSet seen; +// TermIdQueue queue; +// queue.push(from); +// bfs(seen, queue); +// if (seen.count(to)) { +// return true; +// } +// return false; +// } + +// void InequalityGraph::bfs(TermIdSet& seen, TermIdQueue& queue) { +// if (queue.empty()) +// return; + +// TermId current = queue.front(); +// queue.pop(); + +// const Edges& edges = getOutEdges(current); +// for (Edges::const_iterator it = edges.begin(); it!= edges.end(); ++it) { +// TermId next = it->next; +// if(seen.count(next) == 0) { +// seen.insert(next); +// queue.push(next); +// } +// } +// bfs(seen, queue); +// } + +// void InequalityGraph::getPath(TermId to, TermId from, const TermIdSet& seen, std::vector<ReasonId> explanation) { +// // traverse parent edges +// const Edges& out = getOutEdges(to); +// for (Edges::const_iterator it = out.begin(); it != out.end(); ++it) { +// if (seen.find(it->next)) { +// path.push_back(it->reason); +// getPath(it->next, from, seen, path); +// return; +// } +// } +// } + +// TermId InequalityGraph::getMaxParent(TermId id) { +// const Edges& in_edges = getInEdges(id); +// Assert (in_edges.size() != 0); + +// BitVector max(getBitwidth(0), 0u); +// TermId max_id = UndefinedTermId; +// for (Edges::const_iterator it = in_edges.begin(); it!= in_edges.end(); ++it) { +// // Assert (seen.count(it->next)); +// const BitVector& value = getInequalityNode(it->next).getValue(); +// if (value >= max) { +// max = value; +// max_id = it->next; +// } +// } +// Assert (max_id != UndefinedTermId); +// return max_id; +// } diff --git a/src/theory/bv/bv_inequality_graph.h b/src/theory/bv/bv_inequality_graph.h new file mode 100644 index 000000000..18bd75726 --- /dev/null +++ b/src/theory/bv/bv_inequality_graph.h @@ -0,0 +1,224 @@ +/********************* */ +/*! \file bv_inequality_graph.h + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009-2012 New York University and The University of Iowa + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief Algebraic solver. + ** + ** Algebraic solver. + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__BV__BV_INEQUALITY__GRAPH_H +#define __CVC4__THEORY__BV__BV_INEQUALITY__GRAPH_H + +#include "context/context.h" +#include "context/cdqueue.h" +#include "theory/uf/equality_engine.h" +#include "theory/theory.h" +#include <queue> + +namespace CVC4 { +namespace theory { + + +namespace bv { + +typedef unsigned TermId; +typedef unsigned ReasonId; +extern const TermId UndefinedTermId; +extern const ReasonId UndefinedReasonId; +extern const ReasonId AxiomReasonId; + +class InequalityGraph { + + + context::Context* d_context; + + struct InequalityEdge { + TermId next; + ReasonId reason; + bool strict; + InequalityEdge(TermId n, bool s, ReasonId r) + : next(n), + reason(r), + strict(s) + {} + }; + + class InequalityNode { + TermId d_id; + unsigned d_bitwidth; + bool d_isConstant; + BitVector d_value; + public: + InequalityNode(TermId id, unsigned bitwidth, bool isConst, const BitVector val) + : d_id(id), + d_bitwidth(bitwidth), + d_isConstant(isConst), + d_value(val) {} + TermId getId() const { return d_id; } + unsigned getBitwidth() const { return d_bitwidth; } + bool isConstant() const { return d_isConstant; } + const BitVector& getValue() const { return d_value; } + void setValue(const BitVector& val) { Assert (val.getSize() == d_bitwidth); d_value = val; } + }; + + struct Explanation { + TermId parent; + ReasonId reason; + + Explanation() + : parent(UndefinedTermId), + reason(UndefinedReasonId) + {} + + Explanation(TermId p, ReasonId r) + : parent(p), + reason(r) + {} + }; + + struct PQueueElement { + TermId id; + BitVector value; + BitVector lower_bound; + Explanation explanation; + PQueueElement(TermId id, const BitVector v, const BitVector& lb, Explanation exp) + : id(id), + value(v), + lower_bound(lb), + explanation(exp) + {} + + bool operator< (const PQueueElement& other) const { + return value > other.value; + } + }; + + typedef __gnu_cxx::hash_map<TNode, ReasonId, TNodeHashFunction> ReasonToIdMap; + typedef __gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction> TermNodeToIdMap; + + typedef std::vector<InequalityEdge> Edges; + typedef __gnu_cxx::hash_set<TermId> TermIdSet; + + typedef std::priority_queue<PQueueElement> BFSQueue; + typedef __gnu_cxx::hash_map<TermId, Explanation> TermIdToExplanationMap; + + std::vector<InequalityNode> d_ineqNodes; + std::vector< Edges > d_ineqEdges; + + std::vector<TNode> d_reasonNodes; + ReasonToIdMap d_reasonToIdMap; + + std::vector<Node> d_termNodes; + TermNodeToIdMap d_termNodeToIdMap; + + TermIdToExplanationMap d_termToExplanation; + + context::CDO<bool> d_inConflict; + std::vector<TNode> d_conflict; + bool d_signed; + + + /** + * Registers the term by creating its corresponding InequalityNode + * and adding the min <= term <= max default edges. + * + * @param term + * + * @return + */ + TermId registerTerm(TNode term); + TNode getTermNode(TermId id) const; + TermId getTermId(TNode node) const; + + ReasonId registerReason(TNode reason); + TNode getReasonNode(ReasonId id) const; + + bool hasExplanation(TermId id) const { return d_termToExplanation.find(id) != d_termToExplanation.end(); } + Explanation getExplanation(TermId id) const { Assert (hasExplanation(id)); return d_termToExplanation.find(id)->second; } + void setExplanation(TermId id, Explanation exp) { d_termToExplanation[id] = exp; } + + Edges& getEdges(TermId id) { Assert (id < d_ineqEdges.size()); return d_ineqEdges[id]; } + InequalityNode& getInequalityNode(TermId id) { Assert (id < d_ineqNodes.size()); return d_ineqNodes[id]; } + const InequalityNode& getInequalityNode(TermId id) const { Assert (id < d_ineqNodes.size()); return d_ineqNodes[id]; } + unsigned getBitwidth(TermId id) const { return getInequalityNode(id).getBitwidth(); } + bool isConst(TermId id) const { return getInequalityNode(id).isConstant(); } + BitVector maxValue(unsigned bitwidth); + BitVector minValue(unsigned bitwidth); + TermId getMaxValueId(unsigned bitwidth); + TermId getMinValueId(unsigned bitwidth); + + const BitVector& getValue(TermId id) const { return getInequalityNode(id).getValue(); } + + void addEdge(TermId a, TermId b, bool strict, TermId reason); + + void setConflict(const std::vector<ReasonId>& conflict); + /** + * If necessary update the value in the model of the current queue element. + * + * @param el current queue element we are updating + * @param start node we started with, to detect cycles + * @param seen + * + * @return + */ + bool updateValue(const PQueueElement& el, TermId start, const TermIdSet& seen); + /** + * Update the current model starting with the start term. + * + * @param queue + * @param start + * @param seen + * + * @return + */ + bool computeValuesBFS(BFSQueue& queue, TermId start, TermIdSet& seen); + /** + * Return the reasons why from <= to. If from is undefined we just + * explain the current value of to. + * + * @param from + * @param to + * @param explanation + */ + void computeExplanation(TermId from, TermId to, std::vector<ReasonId>& explanation); + +public: + + InequalityGraph(context::Context* c, bool s = false) + : d_context(c), + d_ineqNodes(), + d_ineqEdges(), + d_inConflict(c, false), + d_conflict(), + d_signed(s) + {} + /** + * + * + * @param a + * @param b + * @param diff + * @param reason + * + * @return + */ + bool addInequality(TNode a, TNode b, bool strict, TNode reason); + bool areLessThan(TNode a, TNode b); + void getConflict(std::vector<TNode>& conflict); +}; + +} +} +} + +#endif /* __CVC4__THEORY__BV__BV_INEQUALITY__GRAPH_H */ diff --git a/src/theory/bv/bv_subtheory.h b/src/theory/bv/bv_subtheory.h index 4dbba0797..c442fa6dd 100644 --- a/src/theory/bv/bv_subtheory.h +++ b/src/theory/bv/bv_subtheory.h @@ -32,8 +32,9 @@ class TheoryModel; namespace bv { enum SubTheory { - SUB_EQUALITY = 1, - SUB_BITBLAST = 2 + SUB_CORE = 1, + SUB_BITBLAST = 2, + SUB_INEQUALITY = 3 }; inline std::ostream& operator << (std::ostream& out, SubTheory subtheory) { @@ -41,9 +42,11 @@ inline std::ostream& operator << (std::ostream& out, SubTheory subtheory) { case SUB_BITBLAST: out << "BITBLASTER"; break; - case SUB_EQUALITY: - out << "EQUALITY"; + case SUB_CORE: + out << "BV_CORE_SUBTHEORY"; break; + case SUB_INEQUALITY: + out << "BV_INEQUALITY_SUBTHEORY"; default: Unreachable(); break; @@ -58,6 +61,7 @@ const bool d_useSatPropagation = true; // forward declaration class TheoryBV; +typedef context::CDQueue<Node> AssertionQueue; /** * Abstract base class for bit-vector subtheory solvers * @@ -71,19 +75,31 @@ protected: /** The bit-vector theory */ TheoryBV* d_bv; - + AssertionQueue d_assertionQueue; + context::CDO<uint32_t> d_assertionIndex; public: SubtheorySolver(context::Context* c, TheoryBV* bv) : d_context(c), - d_bv(bv) + d_bv(bv), + d_assertionQueue(c), + d_assertionIndex(c, 0) {} virtual ~SubtheorySolver() {} + virtual bool check(Theory::Effort e) = 0; + virtual void explain(TNode literal, std::vector<TNode>& assumptions) = 0; + virtual void preRegister(TNode node) {} + virtual void propagate(Theory::Effort e) {} + virtual void collectModelInfo(TheoryModel* m) = 0; + bool done() { return d_assertionQueue.size() == d_assertionIndex; } + TNode get() { + Assert (!done()); + TNode res = d_assertionQueue[d_assertionIndex]; + d_assertionIndex = d_assertionIndex + 1; + return res; + } + void assertFact(TNode fact) { d_assertionQueue.push_back(fact); } - virtual bool addAssertions(const std::vector<TNode>& assertions, Theory::Effort e) = 0; - virtual void explain(TNode literal, std::vector<TNode>& assumptions) = 0; - virtual void preRegister(TNode node) {} - virtual void collectModelInfo(TheoryModel* m) = 0; }; } diff --git a/src/theory/bv/bv_subtheory_bitblast.cpp b/src/theory/bv/bv_subtheory_bitblast.cpp index 501aafb29..2f76e32d3 100644 --- a/src/theory/bv/bv_subtheory_bitblast.cpp +++ b/src/theory/bv/bv_subtheory_bitblast.cpp @@ -52,22 +52,21 @@ void BitblastSolver::explain(TNode literal, std::vector<TNode>& assumptions) { d_bitblaster->explain(literal, assumptions); } -bool BitblastSolver::addAssertions(const std::vector<TNode>& assertions, Theory::Effort e) { - Debug("bitvector::bitblaster") << "BitblastSolver::addAssertions (" << e << ")" << std::endl; +bool BitblastSolver::check(Theory::Effort e) { //// Eager bit-blasting if (options::bitvectorEagerBitblast()) { - for (unsigned i = 0; i < assertions.size(); ++i) { - TNode atom = assertions[i].getKind() == kind::NOT ? assertions[i][0] : assertions[i]; + while (!done()) { + TNode assertion = get(); + TNode atom = assertion.getKind() == kind::NOT ? assertion[0] : assertion; if (atom.getKind() != kind::BITVECTOR_BITOF) { d_bitblaster->bbAtom(atom); } + return true; } - return true; } //// Lazy bit-blasting - // bit-blast enqueued nodes while (!d_bitblastQueue.empty()) { TNode atom = d_bitblastQueue.front(); @@ -75,9 +74,9 @@ bool BitblastSolver::addAssertions(const std::vector<TNode>& assertions, Theory: d_bitblastQueue.pop(); } - // propagation - for (unsigned i = 0; i < assertions.size(); ++i) { - TNode fact = assertions[i]; + // Processinga ssertions + while (!done()) { + TNode fact = get(); if (!d_bv->inConflict() && !d_bv->propagatedBy(fact, SUB_BITBLAST)) { // Some atoms have not been bit-blasted yet d_bitblaster->bbAtom(fact); @@ -103,7 +102,7 @@ bool BitblastSolver::addAssertions(const std::vector<TNode>& assertions, Theory: } } - // solving + // Solving if (e == Theory::EFFORT_FULL || options::bitvectorEagerFullcheck()) { Assert(!d_bv->inConflict()); Debug("bitvector::bitblaster") << "BitblastSolver::addAssertions solving. \n"; diff --git a/src/theory/bv/bv_subtheory_bitblast.h b/src/theory/bv/bv_subtheory_bitblast.h index 3396d813b..318fdd230 100644 --- a/src/theory/bv/bv_subtheory_bitblast.h +++ b/src/theory/bv/bv_subtheory_bitblast.h @@ -42,7 +42,7 @@ public: ~BitblastSolver(); void preRegister(TNode node); - bool addAssertions(const std::vector<TNode>& assertions, Theory::Effort e); + bool check(Theory::Effort e); void explain(TNode literal, std::vector<TNode>& assumptions); EqualityStatus getEqualityStatus(TNode a, TNode b); void collectModelInfo(TheoryModel* m); diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp new file mode 100644 index 000000000..2af0e47b8 --- /dev/null +++ b/src/theory/bv/bv_subtheory_core.cpp @@ -0,0 +1,300 @@ +/********************* */ +/*! \file bv_subtheory_eq.cpp + ** \verbatim + ** Original author: dejan + ** Major contributors: none + ** Minor contributors (to current version): lianah + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009-2012 New York University and The University of Iowa + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief Algebraic solver. + ** + ** Algebraic solver. + **/ + +#include "theory/bv/bv_subtheory_core.h" + +#include "theory/bv/theory_bv.h" +#include "theory/bv/theory_bv_utils.h" +#include "theory/bv/slicer.h" +#include "theory/model.h" + +using namespace std; +using namespace CVC4; +using namespace CVC4::context; +using namespace CVC4::theory; +using namespace CVC4::theory::bv; +using namespace CVC4::theory::bv::utils; + +CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv) + : SubtheorySolver(c, bv), + d_notify(*this), + d_equalityEngine(d_notify, c, "theory::bv::TheoryBV"), + d_slicer(new Slicer(c, this)), + d_isCoreTheory(c, true), + d_reasons(c) +{ + if (d_useEqualityEngine) { + + // The kinds we are treating as function application in congruence + d_equalityEngine.addFunctionKind(kind::BITVECTOR_CONCAT, true); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_AND); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_OR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XOR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NAND); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XNOR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_COMP); + d_equalityEngine.addFunctionKind(kind::BITVECTOR_MULT, true); + d_equalityEngine.addFunctionKind(kind::BITVECTOR_PLUS, true); + d_equalityEngine.addFunctionKind(kind::BITVECTOR_EXTRACT, true); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SUB); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NEG); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UDIV); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UREM); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SDIV); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SREM); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SMOD); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SHL); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_LSHR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ASHR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULE); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGE); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLE); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGE); + } +} + +CoreSolver::~CoreSolver() { + delete d_slicer; +} +void CoreSolver::setMasterEqualityEngine(eq::EqualityEngine* eq) { + d_equalityEngine.setMasterEqualityEngine(eq); +} + +void CoreSolver::preRegister(TNode node) { + if (!d_useEqualityEngine) + return; + + if (node.getKind() == kind::EQUAL) { + d_equalityEngine.addTriggerEquality(node); + // d_slicer->processEquality(node); + } else { + d_equalityEngine.addTerm(node); + } +} + + +void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) { + bool polarity = literal.getKind() != kind::NOT; + TNode atom = polarity ? literal : literal[0]; + if (atom.getKind() == kind::EQUAL) { + d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + } else { + d_equalityEngine.explainPredicate(atom, polarity, assumptions); + } +} + +Node CoreSolver::getBaseDecomposition(TNode a, std::vector<Node>& explanation) { + std::vector<Node> a_decomp; + d_slicer->getBaseDecomposition(a, a_decomp, explanation); + Node new_a = utils::mkConcat(a_decomp); + Debug("bv-slicer") << "CoreSolver::getBaseDecomposition " << a <<" => " << new_a << "\n"; + return new_a; +} + +bool CoreSolver::decomposeFact(TNode fact) { + Debug("bv-slicer") << "CoreSolver::decomposeFact fact=" << fact << endl; + // assert decompositions since the equality engine does not know the semantics of + // concat: + // a == a_1 concat ... concat a_k + // b == b_1 concat ... concat b_k + + if (fact.getKind() == kind::EQUAL) { + TNode a = fact[0]; + TNode b = fact[1]; + + d_slicer->processEquality(fact); + std::vector<Node> explanation; + Node new_a = getBaseDecomposition(a, explanation); + Node new_b = getBaseDecomposition(b, explanation); + + explanation.push_back(fact); + Node reason = utils::mkAnd(explanation); + d_reasons.insert(reason); + + Assert (utils::getSize(new_a) == utils::getSize(new_b) && + utils::getSize(new_a) == utils::getSize(a)); + // FIXME: do we still need to assert these? + NodeManager* nm = NodeManager::currentNM(); + Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a); + Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b); + + d_reasons.insert(a_eq_new_a); + d_reasons.insert(b_eq_new_b); + + bool ok = true; + ok = assertFactToEqualityEngine(a_eq_new_a, utils::mkTrue()); + if (!ok) return false; + ok = assertFactToEqualityEngine(b_eq_new_b, utils::mkTrue()); + if (!ok) return false; + // assert the individual equalities as well + // a_i == b_i + if (new_a.getKind() == kind::BITVECTOR_CONCAT && + new_b.getKind() == kind::BITVECTOR_CONCAT) { + Assert (new_a.getNumChildren() == new_b.getNumChildren()); + for (unsigned i = 0; i < new_a.getNumChildren(); ++i) { + Node eq_i = nm->mkNode(kind::EQUAL, new_a[i], new_b[i]); + ok = assertFactToEqualityEngine(eq_i, reason); + d_reasons.insert(eq_i); + if (!ok) return false; + } + } + // merge the two terms in the slicer as well + d_slicer->assertEquality(fact); + } else { + // still need to register the terms + d_slicer->processEquality(fact[0]); + TNode a = fact[0][0]; + TNode b = fact[0][1]; + std::vector<Node> explanation_a; + Node new_a = getBaseDecomposition(a, explanation_a); + Node reason_a = explanation_a.empty()? mkTrue() : mkAnd(explanation_a); + assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, a, new_a), reason_a); + + std::vector<Node> explanation_b; + Node new_b = getBaseDecomposition(b, explanation_b); + Node reason_b = explanation_b.empty()? mkTrue() : mkAnd(explanation_b); + assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, b, new_b), reason_b); + d_reasons.insert(reason_a); + d_reasons.insert(reason_b); + } + // finally assert the actual fact to the equality engine + return assertFactToEqualityEngine(fact, fact); +} + +bool CoreSolver::check(Theory::Effort e) { + Trace("bitvector::core") << "CoreSolver::check \n"; + Assert (!d_bv->inConflict()); + + bool ok = true; + std::vector<Node> core_eqs; + while (! done()) { + TNode fact = get(); + + // update whether we are in the core fragment + if (d_isCoreTheory && !d_slicer->isCoreTerm(fact)) { + d_isCoreTheory = false; + } + + // only reason about equalities + if (fact.getKind() == kind::EQUAL || (fact.getKind() == kind::NOT && fact[0].getKind() == kind::EQUAL)) { + ok = decomposeFact(fact); + } else { + ok = assertFactToEqualityEngine(fact, fact); + } + if (!ok) + return false; + } + + // make sure to assert the new splits + std::vector<Node> new_splits; + d_slicer->getNewSplits(new_splits); + for (unsigned i = 0; i < new_splits.size(); ++i) { + ok = assertFactToEqualityEngine(new_splits[i], utils::mkTrue()); + if (!ok) + return false; + } + return true; +} + +bool CoreSolver::assertFactToEqualityEngine(TNode fact, TNode reason) { + // Notify the equality engine + if (d_useEqualityEngine && !d_bv->inConflict() && !d_bv->propagatedBy(fact, SUB_CORE) ) { + Debug("bv-slicer-eq") << "CoreSolver::assertFactToEqualityEngine fact=" << fact << endl; + // Debug("bv-slicer-eq") << " reason=" << reason << endl; + bool negated = fact.getKind() == kind::NOT; + TNode predicate = negated ? fact[0] : fact; + if (predicate.getKind() == kind::EQUAL) { + if (negated) { + // dis-equality + d_equalityEngine.assertEquality(predicate, false, reason); + } else { + // equality + d_equalityEngine.assertEquality(predicate, true, reason); + } + } else { + // Adding predicate if the congruence over it is turned on + if (d_equalityEngine.isFunctionKind(predicate.getKind())) { + d_equalityEngine.assertPredicate(predicate, !negated, reason); + } + } + } + + // checking for a conflict + if (d_bv->inConflict()) { + return false; + } + return true; +} + +bool CoreSolver::NotifyClass::eqNotifyTriggerEquality(TNode equality, bool value) { + Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_solver.storePropagation(equality); + } else { + return d_solver.storePropagation(equality.notNode()); + } +} + +bool CoreSolver::NotifyClass::eqNotifyTriggerPredicate(TNode predicate, bool value) { + Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" ) << ")" << std::endl; + if (value) { + return d_solver.storePropagation(predicate); + } else { + return d_solver.storePropagation(predicate.notNode()); + } +} + +bool CoreSolver::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) { + Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << ")" << std::endl; + if (value) { + return d_solver.storePropagation(t1.eqNode(t2)); + } else { + return d_solver.storePropagation(t1.eqNode(t2).notNode()); + } +} + +void CoreSolver::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) { + d_solver.conflict(t1, t2); +} + +bool CoreSolver::storePropagation(TNode literal) { + return d_bv->storePropagation(literal, SUB_CORE); +} + +void CoreSolver::conflict(TNode a, TNode b) { + std::vector<TNode> assumptions; + d_equalityEngine.explainEquality(a, b, true, assumptions); + d_bv->setConflict(mkAnd(assumptions)); +} + +void CoreSolver::collectModelInfo(TheoryModel* m) { + if (Debug.isOn("bitvector-model")) { + context::CDQueue<Node>::const_iterator it = d_assertionQueue.begin(); + for (; it!= d_assertionQueue.end(); ++it) { + Debug("bitvector-model") << "CoreSolver::collectModelInfo (assert " + << *it << ")\n"; + } + } + set<Node> termSet; + d_bv->computeRelevantTerms(termSet); + m->assertEqualityEngine(&d_equalityEngine, &termSet); +} diff --git a/src/theory/bv/bv_subtheory_eq.h b/src/theory/bv/bv_subtheory_core.h index 2b024cfd4..4f2d7a279 100644 --- a/src/theory/bv/bv_subtheory_eq.h +++ b/src/theory/bv/bv_subtheory_core.h @@ -18,23 +18,26 @@ #include "cvc4_private.h" #include "theory/bv/bv_subtheory.h" +#include "context/cdhashmap.h" +#include "context/cdhashset.h" namespace CVC4 { namespace theory { namespace bv { +class Slicer; +class Base; /** * Bitvector equality solver */ -class EqualitySolver : public SubtheorySolver { +class CoreSolver : public SubtheorySolver { // NotifyClass: handles call-back from congruence closure module - class NotifyClass : public eq::EqualityEngineNotify { - EqualitySolver& d_solver; + CoreSolver& d_solver; public: - NotifyClass(EqualitySolver& solver): d_solver(solver) {} + NotifyClass(CoreSolver& solver): d_solver(solver) {} bool eqNotifyTriggerEquality(TNode equality, bool value); bool eqNotifyTriggerPredicate(TNode predicate, bool value); bool eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value); @@ -43,12 +46,12 @@ class EqualitySolver : public SubtheorySolver { void eqNotifyPreMerge(TNode t1, TNode t2) { } void eqNotifyPostMerge(TNode t1, TNode t2) { } void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) { } -}; + }; /** The notify class for d_equalityEngine */ NotifyClass d_notify; - + /** Equality engine */ eq::EqualityEngine d_equalityEngine; @@ -58,14 +61,20 @@ class EqualitySolver : public SubtheorySolver { /** Store a conflict from merging two constants */ void conflict(TNode a, TNode b); - /** FIXME: for debugging purposes only */ - context::CDList<TNode> d_assertions; -public: - - EqualitySolver(context::Context* c, TheoryBV* bv); - void setMasterEqualityEngine(eq::EqualityEngine* eq); + Slicer* d_slicer; + context::CDO<bool> d_isCoreTheory; + /** To make sure we keep the explanations */ + context::CDHashSet<Node, NodeHashFunction> d_reasons; + bool assertFactToEqualityEngine(TNode fact, TNode reason); + bool decomposeFact(TNode fact); + Node getBaseDecomposition(TNode a, std::vector<Node>& explanation); +public: + CoreSolver(context::Context* c, TheoryBV* bv); + ~CoreSolver(); + bool isCoreTheory() { return d_isCoreTheory; } + void setMasterEqualityEngine(eq::EqualityEngine* eq); void preRegister(TNode node); - bool addAssertions(const std::vector<TNode>& assertions, Theory::Effort e); + bool check(Theory::Effort e); void explain(TNode literal, std::vector<TNode>& assumptions); void collectModelInfo(TheoryModel* m); void addSharedTerm(TNode t) { @@ -82,6 +91,8 @@ public: } return EQUALITY_UNKNOWN; } + bool hasTerm(TNode node) const { return d_equalityEngine.hasTerm(node); } + void addTermToEqualityEngine(TNode node) { d_equalityEngine.addTerm(node); } }; diff --git a/src/theory/bv/bv_subtheory_eq.cpp b/src/theory/bv/bv_subtheory_eq.cpp deleted file mode 100644 index 385c2e555..000000000 --- a/src/theory/bv/bv_subtheory_eq.cpp +++ /dev/null @@ -1,185 +0,0 @@ -/********************* */ -/*! \file bv_subtheory_eq.cpp - ** \verbatim - ** Original author: dejan - ** Major contributors: none - ** Minor contributors (to current version): lianah - ** This file is part of the CVC4 prototype. - ** Copyright (c) 2009-2012 New York University and The University of Iowa - ** See the file COPYING in the top-level source directory for licensing - ** information.\endverbatim - ** - ** \brief Algebraic solver. - ** - ** Algebraic solver. - **/ - -#include "theory/bv/bv_subtheory_eq.h" -#include "theory/bv/theory_bv.h" -#include "theory/bv/theory_bv_utils.h" -#include "theory/model.h" - -using namespace std; -using namespace CVC4; -using namespace CVC4::context; -using namespace CVC4::theory; -using namespace CVC4::theory::bv; -using namespace CVC4::theory::bv::utils; - -EqualitySolver::EqualitySolver(context::Context* c, TheoryBV* bv) - : SubtheorySolver(c, bv), - d_notify(*this), - d_equalityEngine(d_notify, c, "theory::bv::TheoryBV"), - d_assertions(c) -{ - if (d_useEqualityEngine) { - - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::BITVECTOR_CONCAT, true); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_AND); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_OR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XOR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NAND); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XNOR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_COMP); - d_equalityEngine.addFunctionKind(kind::BITVECTOR_MULT, true); - d_equalityEngine.addFunctionKind(kind::BITVECTOR_PLUS, true); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SUB); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NEG); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UDIV); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UREM); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SDIV); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SREM); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SMOD); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SHL); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_LSHR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ASHR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULE); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGE); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLE); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGE); - } -} - -void EqualitySolver::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); -} - -void EqualitySolver::preRegister(TNode node) { - if (!d_useEqualityEngine) - return; - - if (node.getKind() == kind::EQUAL) { - d_equalityEngine.addTriggerEquality(node); - } else { - d_equalityEngine.addTerm(node); - } -} - -void EqualitySolver::explain(TNode literal, std::vector<TNode>& assumptions) { - bool polarity = literal.getKind() != kind::NOT; - TNode atom = polarity ? literal : literal[0]; - if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); - } else { - d_equalityEngine.explainPredicate(atom, polarity, assumptions); - } -} - -bool EqualitySolver::addAssertions(const std::vector<TNode>& assertions, Theory::Effort e) { - Trace("bitvector::equality") << "EqualitySolver::addAssertions \n"; - Assert (!d_bv->inConflict()); - - for (unsigned i = 0; i < assertions.size(); ++i) { - TNode fact = assertions[i]; - - // Notify the equality engine - if (d_useEqualityEngine && !d_bv->inConflict() && !d_bv->propagatedBy(fact, SUB_EQUALITY) ) { - Trace("bitvector::equality") << " (assert " << fact << ")\n"; - d_assertions.push_back(fact); - bool negated = fact.getKind() == kind::NOT; - TNode predicate = negated ? fact[0] : fact; - if (predicate.getKind() == kind::EQUAL) { - if (negated) { - // dis-equality - d_equalityEngine.assertEquality(predicate, false, fact); - } else { - // equality - d_equalityEngine.assertEquality(predicate, true, fact); - } - } else { - // Adding predicate if the congruence over it is turned on - if (d_equalityEngine.isFunctionKind(predicate.getKind())) { - d_equalityEngine.assertPredicate(predicate, !negated, fact); - } - } - } - - // checking for a conflict - if (d_bv->inConflict()) { - return false; - } - } - - return true; -} - -bool EqualitySolver::NotifyClass::eqNotifyTriggerEquality(TNode equality, bool value) { - Debug("bitvector::equality") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl; - if (value) { - return d_solver.storePropagation(equality); - } else { - return d_solver.storePropagation(equality.notNode()); - } -} - -bool EqualitySolver::NotifyClass::eqNotifyTriggerPredicate(TNode predicate, bool value) { - Debug("bitvector::equality") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" ) << ")" << std::endl; - if (value) { - return d_solver.storePropagation(predicate); - } else { - return d_solver.storePropagation(predicate.notNode()); - } -} - -bool EqualitySolver::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) { - Debug("bitvector::equality") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << ")" << std::endl; - if (value) { - return d_solver.storePropagation(t1.eqNode(t2)); - } else { - return d_solver.storePropagation(t1.eqNode(t2).notNode()); - } -} - -void EqualitySolver::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) { - d_solver.conflict(t1, t2); -} - -bool EqualitySolver::storePropagation(TNode literal) { - return d_bv->storePropagation(literal, SUB_EQUALITY); -} - -void EqualitySolver::conflict(TNode a, TNode b) { - std::vector<TNode> assumptions; - d_equalityEngine.explainEquality(a, b, true, assumptions); - d_bv->setConflict(mkAnd(assumptions)); -} - -void EqualitySolver::collectModelInfo(TheoryModel* m) { - if (Debug.isOn("bitvector-model")) { - context::CDList<TNode>::const_iterator it = d_assertions.begin(); - for (; it!= d_assertions.end(); ++it) { - Debug("bitvector-model") << "EqualitySolver::collectModelInfo (assert " - << *it << ")\n"; - } - } - set<Node> termSet; - d_bv->computeRelevantTerms(termSet); - m->assertEqualityEngine(&d_equalityEngine, &termSet); -} diff --git a/src/theory/bv/bv_subtheory_inequality.cpp b/src/theory/bv/bv_subtheory_inequality.cpp new file mode 100644 index 000000000..6b4b1a134 --- /dev/null +++ b/src/theory/bv/bv_subtheory_inequality.cpp @@ -0,0 +1,68 @@ +/********************* */ +/*! \file bv_subtheory_inequality.cpp + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009-2012 New York University and The University of Iowa + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief Algebraic solver. + ** + ** Algebraic solver. + **/ + +#include "theory/bv/bv_subtheory_inequality.h" +#include "theory/bv/theory_bv.h" +#include "theory/bv/theory_bv_utils.h" +#include "theory/model.h" + +using namespace std; +using namespace CVC4; +using namespace CVC4::context; +using namespace CVC4::theory; +using namespace CVC4::theory::bv; +using namespace CVC4::theory::bv::utils; + +bool InequalitySolver::check(Theory::Effort e) { + bool ok = true; + while (!done() && ok) { + TNode fact = get(); + + if (fact.getKind() == kind::NOT && fact[0].getKind() == kind::BITVECTOR_ULE) { + TNode a = fact[0][1]; + TNode b = fact[0][0]; + ok = d_inequalityGraph.addInequality(a, b, true, fact); + } else if (fact.getKind() == kind::NOT && fact[0].getKind() == kind::BITVECTOR_ULT) { + TNode a = fact[0][1]; + TNode b = fact[0][0]; + ok = d_inequalityGraph.addInequality(a, b, false, fact); + } else if (fact.getKind() == kind::BITVECTOR_ULT) { + TNode a = fact[0]; + TNode b = fact[1]; + ok = d_inequalityGraph.addInequality(a, b, true, fact); + } else if (fact.getKind() == kind::BITVECTOR_ULE) { + TNode a = fact[0]; + TNode b = fact[1]; + ok = d_inequalityGraph.addInequality(a, b, false, fact); + } + } + if (!ok) { + std::vector<TNode> conflict; + d_inequalityGraph.getConflict(conflict); + d_bv->setConflict(utils::mkConjunction(conflict)); + return false; + } + return true; +} + +void InequalitySolver::explain(TNode literal, std::vector<TNode>& assumptions) { + Assert (false); +} + +void InequalitySolver::propagate(Theory::Effort e) { + Assert (false); +} + diff --git a/src/theory/bv/bv_subtheory_inequality.h b/src/theory/bv/bv_subtheory_inequality.h new file mode 100644 index 000000000..07c561c84 --- /dev/null +++ b/src/theory/bv/bv_subtheory_inequality.h @@ -0,0 +1,49 @@ +/********************* */ +/*! \file bv_subtheory_inequality.h + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009-2012 New York University and The University of Iowa + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief Algebraic solver. + ** + ** Algebraic solver. + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__BV__BV_SUBTHEORY__INEQUALITY_H +#define __CVC4__THEORY__BV__BV_SUBTHEORY__INEQUALITY_H + +#include "theory/bv/bv_subtheory.h" +#include "theory/bv/bv_inequality_graph.h" + +namespace CVC4 { +namespace theory { +namespace bv { + +class InequalitySolver: public SubtheorySolver { + InequalityGraph d_inequalityGraph; +public: + + InequalitySolver(context::Context* c, TheoryBV* bv) + : SubtheorySolver(c, bv), + d_inequalityGraph(c) + {} + + bool check(Theory::Effort e); + void propagate(Theory::Effort e); + void explain(TNode literal, std::vector<TNode>& assumptions); + bool isInequalityTheory() { return true; } + virtual void collectModelInfo(TheoryModel* m) {} +}; + +} +} +} + +#endif /* __CVC4__THEORY__BV__BV_SUBTHEORY__INEQUALITY_H */ diff --git a/src/theory/bv/kinds b/src/theory/bv/kinds index 2faa12437..052e477ea 100644 --- a/src/theory/bv/kinds +++ b/src/theory/bv/kinds @@ -8,7 +8,7 @@ theory THEORY_BV ::CVC4::theory::bv::TheoryBV "theory/bv/theory_bv.h" typechecker "theory/bv/theory_bv_type_rules.h" properties finite -properties check propagate +properties check propagate presolve rewriter ::CVC4::theory::bv::TheoryBVRewriter "theory/bv/theory_bv_rewriter.h" diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp new file mode 100644 index 000000000..5d376ea50 --- /dev/null +++ b/src/theory/bv/slicer.cpp @@ -0,0 +1,861 @@ +/********************* */ +/*! \file slicer.h + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief Bitvector theory. + ** + ** Bitvector theory. + **/ + +#include "theory/bv/slicer.h" +#include "theory/bv/theory_bv_utils.h" +#include "theory/rewriter.h" +#include "theory/bv/bv_subtheory_core.h" +using namespace CVC4; +using namespace CVC4::theory; +using namespace CVC4::theory::bv; +using namespace std; + + +const TermId CVC4::theory::bv::UndefinedId = -1; + +/** + * Base + * + */ +Base::Base(uint32_t size) + : d_size(size), + d_repr(size/32 + (size % 32 == 0? 0 : 1), 0) +{ + Assert (d_size > 0); +} + + +void Base::sliceAt(Index index) { + if (index == d_size) + return; + Assert(index < d_size); + Index vector_index = index / 32; + Assert (vector_index < d_repr.size()); + Index int_index = index % 32; + uint32_t bit_mask = utils::pow2(int_index); + d_repr[vector_index] = d_repr[vector_index] | bit_mask; +} + +void Base::undoSliceAt(Index index) { + Index vector_index = index / 32; + Assert (vector_index < d_size); + Index int_index = index % 32; + uint32_t bit_mask = utils::pow2(int_index); + d_repr[vector_index] = d_repr[vector_index] ^ bit_mask; +} + +void Base::sliceWith(const Base& other) { + Assert (d_size == other.d_size); + for (unsigned i = 0; i < d_repr.size(); ++i) { + d_repr[i] = d_repr[i] | other.d_repr[i]; + } +} + +bool Base::isCutPoint (Index index) const { + // there is an implicit cut point at the end and begining of the bv + if (index == d_size || index == 0) + return true; + + Index vector_index = index / 32; + Assert (vector_index < d_size); + Index int_index = index % 32; + uint32_t bit_mask = utils::pow2(int_index); + + return (bit_mask & d_repr[vector_index]) != 0; +} + +void Base::diffCutPoints(const Base& other, Base& res) const { + Assert (d_size == other.d_size && res.d_size == d_size); + for (unsigned i = 0; i < d_repr.size(); ++i) { + Assert (res.d_repr[i] == 0); + res.d_repr[i] = d_repr[i] ^ other.d_repr[i]; + } +} + +bool Base::isEmpty() const { + for (unsigned i = 0; i< d_repr.size(); ++i) { + if (d_repr[i] != 0) + return false; + } + return true; +} + +std::string Base::debugPrint() const { + std::ostringstream os; + os << "["; + bool first = true; + for (int i = d_size - 1; i >= 0; --i) { + if (isCutPoint(i)) { + if (first) + first = false; + else + os <<"| "; + + os << i ; + } + } + os << "]"; + return os.str(); +} + +/** + * ExtractTerm + * + */ + +std::string ExtractTerm::debugPrint() const { + ostringstream os; + os << "id" << id << "[" << high << ":" << low <<"] "; + return os.str(); +} + +/** + * NormalForm + * + */ + +std::pair<TermId, Index> NormalForm::getTerm(Index index, const UnionFind& uf) const { + Assert (index < base.getBitwidth()); + Index count = 0; + for (unsigned i = 0; i < decomp.size(); ++i) { + Index size = uf.getBitwidth(decomp[i]); + if ( count + size > index && index >= count) { + return pair<TermId, Index>(decomp[i], count); + } + count += size; + } + Unreachable(); +} + + + +std::string NormalForm::debugPrint(const UnionFind& uf) const { + ostringstream os; + os << "NF " << base.debugPrint() << endl; + os << "("; + for (int i = decomp.size() - 1; i>= 0; --i) { + os << decomp[i] << "[" << uf.getBitwidth(decomp[i]) <<"]"; + os << (i != 0? ", " : ""); + } + os << ") \n"; + return os.str(); +} +/** + * UnionFind::Node + * + */ + +std::string UnionFind::Node::debugPrint() const { + ostringstream os; + os << "Repr " << d_edge.repr << " ["<< d_bitwidth << "] "; + os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl; + return os.str(); +} + + +/** + * UnionFind + * + */ +TermId UnionFind::addNode(Index bitwidth) { + Assert (bitwidth > 0); + Node node(bitwidth); + d_nodes.push_back(node); + + ++(d_statistics.d_numNodes); + + TermId id = d_nodes.size() - 1; + // d_representatives.insert(id); + ++(d_statistics.d_numRepresentatives); + Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl; + return id; +} + +TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) { + if (isExtractTerm(topLevel)) { + ExtractTerm top = getExtractTerm(topLevel); + Index top_high = top.high; + Index top_low = top.low; + Assert (top_high - top_low + 1 > high); + high += top_low; + low += top_low; + topLevel = top.id; + } + ExtractTerm extract(topLevel, high, low); + if (d_extractToId.find(extract) != d_extractToId.end()) { + return d_extractToId[extract]; + } + + Assert (high >= low); + + TermId id = addNode(high - low + 1); + d_idToExtract[id] = extract; + d_extractToId[extract] = id; + return id; +} + +/** + * At this point we assume the slicings of the two terms are properly aligned. + * + * @param t1 + * @param t2 + */ +void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason) { + Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n" + << " " << t2.debugPrint() << "\n" + << " with reason " << reason << endl; + Assert (t1.getBitwidth() == t2.getBitwidth()); + + NormalForm nf1(t1.getBitwidth()); + NormalForm nf2(t2.getBitwidth()); + + getNormalForm(t1, nf1); + getNormalForm(t2, nf2); + + Assert (nf1.decomp.size() == nf2.decomp.size()); + Assert (nf1.base == nf2.base); + + for (unsigned i = 0; i < nf1.decomp.size(); ++i) { + merge (nf1.decomp[i], nf2.decomp[i], reason); + } +} + + +/** + * Merge the two terms in the union find. Both t1 and t2 + * should be root terms. + * + * @param t1 + * @param t2 + */ +void UnionFind::merge(TermId t1, TermId t2, TermId reason) { + Debug("bv-slicer-uf") << "UnionFind::merge (" << t1 <<", " << t2 << ")" << endl; + ++(d_statistics.d_numMerges); + t1 = find(t1); + t2 = find(t2); + + if (t1 == t2) + return; + + Assert (! hasChildren(t1) && ! hasChildren(t2)); + setRepr(t1, t2, reason); + recordOperation(UnionFind::MERGE, t1); + //d_representatives.erase(t1); + d_statistics.d_numRepresentatives += -1; +} + +TermId UnionFind::find(TermId id) { + TermId repr = getRepr(id); + if (repr != UndefinedId) { + TermId find_id = find(repr); + return find_id; + } + return id; +} + +TermId UnionFind::findWithExplanation(TermId id, std::vector<ExplanationId>& explanation) { + TermId repr = getRepr(id); + + if (repr != UndefinedId) { + TermId reason = getReason(id); + Assert (reason != UndefinedId); + explanation.push_back(reason); + + TermId find_id = findWithExplanation(repr, explanation); + return find_id; + } + return id; +} + + +/** + * Splits the representative of the term between i-1 and i + * + * @param id the id of the term + * @param i the index we are splitting at + * + * @return + */ +void UnionFind::split(TermId id, Index i) { + Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl; + id = find(id); + Debug("bv-slicer-uf") << " node: " << d_nodes[id].debugPrint() << endl; + + if (i == 0 || i == getBitwidth(id)) { + // nothing to do + return; + } + + Assert (i < getBitwidth(id)); + if (!hasChildren(id)) { + // first time we split this term + TermId bottom_id = addExtract(id, i - 1, 0); + TermId top_id = addExtract(id, getBitwidth(id) - 1, i); + setChildren(id, top_id, bottom_id); + recordOperation(UnionFind::SPLIT, id); + + if (d_slicer->termInEqualityEngine(id)) { + d_slicer->enqueueSplit(id, i, top_id, bottom_id); + } + } else { + Index cut = getCutPoint(id); + if (i < cut ) + split(getChild(id, 0), i); + else + split(getChild(id, 1), i - cut); + } + ++(d_statistics.d_numSplits); +} + +// TermId UnionFind::getTopLevel(TermId id) const { +// __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> >::const_iterator it = d_idToExtract.find(id); +// if (it != d_idToExtract.end()) { +// return (*it).second.id; +// } +// return id; +// } + +void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) { + nf.clear(); + getDecomposition(term, nf.decomp); + // update nf base + Index count = 0; + for (unsigned i = 0; i < nf.decomp.size(); ++i) { + count += getBitwidth(nf.decomp[i]); + nf.base.sliceAt(count); + } + Debug("bv-slicer-uf") << "UnionFind::getNormalFrom term: " << term.debugPrint() << endl; + Debug("bv-slicer-uf") << " nf: " << nf.debugPrint(*this) << endl; +} + +void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp) { + // making sure the term is aligned + TermId id = find(term.id); + + Assert (term.high < getBitwidth(id)); + // because we split the node, this must be the whole extract + if (!hasChildren(id)) { + Assert (term.high == getBitwidth(id) - 1 && + term.low == 0); + decomp.push_back(id); + return; + } + + Index cut = getCutPoint(id); + + if (term.low < cut && term.high < cut) { + // the extract falls entirely on the low child + ExtractTerm child_ex(getChild(id, 0), term.high, term.low); + getDecomposition(child_ex, decomp); + } + else if (term.low >= cut && term.high >= cut){ + // the extract falls entirely on the high child + ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut); + getDecomposition(child_ex, decomp); + } + else { + // the extract is split over the two children + ExtractTerm low_child(getChild(id, 0), cut - 1, term.low); + getDecomposition(low_child, decomp); + ExtractTerm high_child(getChild(id, 1), term.high - cut, 0); + getDecomposition(high_child, decomp); + } +} + +void UnionFind::getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, + std::vector<ExplanationId>& explanation) { + nf.clear(); + getDecompositionWithExplanation(term, nf.decomp, explanation); + // update nf base + Index count = 0; + for (unsigned i = 0; i < nf.decomp.size(); ++i) { + count += getBitwidth(nf.decomp[i]); + nf.base.sliceAt(count); + } + Debug("bv-slicer-uf") << "UnionFind::getNormalFrom term: " << term.debugPrint() << endl; + Debug("bv-slicer-uf") << " nf: " << nf.debugPrint(*this) << endl; +} + +void UnionFind::getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, + std::vector<ExplanationId>& explanation) { + // making sure the term is aligned + TermId id = findWithExplanation(term.id, explanation); + + Assert (term.high < getBitwidth(id)); + // because we split the node, this must be the whole extract + if (!hasChildren(id)) { + Assert (term.high == getBitwidth(id) - 1 && + term.low == 0); + decomp.push_back(id); + return; + } + + Index cut = getCutPoint(id); + + if (term.low < cut && term.high < cut) { + // the extract falls entirely on the low child + ExtractTerm child_ex(getChild(id, 0), term.high, term.low); + getDecompositionWithExplanation(child_ex, decomp, explanation); + } + else if (term.low >= cut && term.high >= cut){ + // the extract falls entirely on the high child + ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut); + getDecompositionWithExplanation(child_ex, decomp, explanation); + } + else { + // the extract is split over the two children + ExtractTerm low_child(getChild(id, 0), cut - 1, term.low); + getDecompositionWithExplanation(low_child, decomp, explanation); + ExtractTerm high_child(getChild(id, 1), term.high - cut, 0); + getDecompositionWithExplanation(high_child, decomp, explanation); + } +} + +/** + * May cause reslicings of the decompositions. Must not assume the decompositons + * are the current normal form. + * + * @param d1 + * @param d2 + * @param common + */ +void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposition& decomp2, TermId common) { + Debug("bv-slicer") << "UnionFind::handleCommonSlice common = " << common << endl; + Index common_size = getBitwidth(common); + // find starting points of common slice + Index start1 = 0; + for (unsigned j = 0; j < decomp1.size(); ++j) { + if (decomp1[j] == common) + break; + start1 += getBitwidth(decomp1[j]); + } + + Index start2 = 0; + for (unsigned j = 0; j < decomp2.size(); ++j) { + if (decomp2[j] == common) + break; + start2 += getBitwidth(decomp2[j]); + } + if (start1 > start2) { + Index temp = start1; + start1 = start2; + start2 = temp; + } + + if (start2 - start1 < common_size) { + Index overlap = start1 + common_size - start2; + Assert (overlap > 0); + Index diff = common_size - overlap; + Assert (diff >= 0); + Index granularity = utils::gcd(diff, overlap); + // split the common part + for (unsigned i = 0; i < common_size; i+= granularity) { + split(common, i); + } + } + +} + +void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) { + Debug("bv-slicer") << "UnionFind::alignSlicings " << term1.debugPrint() << endl; + Debug("bv-slicer") << " " << term2.debugPrint() << endl; + NormalForm nf1(term1.getBitwidth()); + NormalForm nf2(term2.getBitwidth()); + + getNormalForm(term1, nf1); + getNormalForm(term2, nf2); + + Assert (nf1.base.getBitwidth() == nf2.base.getBitwidth()); + + // first check if the two have any common slices + std::vector<TermId> intersection; + utils::intersect(nf1.decomp, nf2.decomp, intersection); + for (unsigned i = 0; i < intersection.size(); ++i) { + // handle common slice may change the normal form + handleCommonSlice(nf1.decomp, nf2.decomp, intersection[i]); + } + // propagate cuts to a fixpoint + bool changed; + Base cuts(term1.getBitwidth()); + do { + changed = false; + // we need to update the normal form which may have changed + getNormalForm(term1, nf1); + getNormalForm(term2, nf2); + + // align the cuts points of the two slicings + // FIXME: this can be done more efficiently + cuts.sliceWith(nf1.base); + cuts.sliceWith(nf2.base); + + for (unsigned i = 0; i < cuts.getBitwidth(); ++i) { + if (cuts.isCutPoint(i)) { + if (!nf1.base.isCutPoint(i)) { + pair<TermId, Index> pair1 = nf1.getTerm(i, *this); + split(pair1.first, i - pair1.second); + changed = true; + } + if (!nf2.base.isCutPoint(i)) { + pair<TermId, Index> pair2 = nf2.getTerm(i, *this); + split(pair2.first, i - pair2.second); + changed = true; + } + } + } + } while (changed); +} +/** + * Given an extract term a[i:j] makes sure a is sliced + * at indices i and j. + * + * @param term + */ +void UnionFind::ensureSlicing(const ExtractTerm& term) { + //Debug("bv-slicer") << "Slicer::ensureSlicing " << term.debugPrint() << endl; + TermId id = find(term.id); + split(id, term.high + 1); + split(id, term.low); +} + +void UnionFind::backtrack() { + int size = d_undoStack.size(); + for (int i = size; i > (int)d_undoStackIndex.get(); --i) { + Operation op = d_undoStack.back(); + Assert (!d_undoStack.empty()); + d_undoStack.pop_back(); + if (op.op == UnionFind::MERGE) { + undoMerge(op.id); + } else { + Assert (op.op == UnionFind::SPLIT); + undoSplit(op.id); + } + } +} + +void UnionFind::undoMerge(TermId id) { + TermId repr = getRepr(id); + Assert (repr != UndefinedId); + setRepr(id, UndefinedId, UndefinedId); +} + +void UnionFind::undoSplit(TermId id) { + Assert (hasChildren(id)); + setChildren(id, UndefinedId, UndefinedId); +} + +void UnionFind::recordOperation(OperationKind op, TermId term) { + d_undoStackIndex.set(d_undoStackIndex.get() + 1); + d_undoStack.push_back(Operation(op, term)); + Assert (d_undoStack.size() == d_undoStackIndex); +} + +void UnionFind::getBase(TermId id, Base& base, Index offset) { + id = find(id); + if (!hasChildren(id)) + return; + TermId id1 = find(getChild(id, 1)); + TermId id0 = find(getChild(id, 0)); + Index cut = getCutPoint(id); + base.sliceAt(cut + offset); + getBase(id1, base, cut + offset); + getBase(id0, base, offset); +} + + +/** + * Slicer + * + */ + +ExtractTerm Slicer::registerTerm(TNode node) { + Index low = 0, high = utils::getSize(node) - 1; + TNode n = node; + if (node.getKind() == kind::BITVECTOR_EXTRACT) { + n = node[0]; + high = utils::getExtractHigh(node); + low = utils::getExtractLow(node); + } + if (d_nodeToId.find(n) == d_nodeToId.end()) { + TermId id = d_unionFind.addNode(utils::getSize(n)); + d_nodeToId[n] = id; + d_idToNode[id] = n; + } + TermId id = d_nodeToId[n]; + d_unionFind.addExtract(id, high, low); + ExtractTerm res(id, high, low); + Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl; + return res; +} + +void Slicer::processEquality(TNode eq) { + Debug("bv-slicer") << "Slicer::processEquality: " << eq << endl; + + registerEquality(eq); + Assert (eq.getKind() == kind::EQUAL); + TNode a = eq[0]; + TNode b = eq[1]; + ExtractTerm a_ex= registerTerm(a); + ExtractTerm b_ex= registerTerm(b); + + d_unionFind.ensureSlicing(a_ex); + d_unionFind.ensureSlicing(b_ex); + + d_unionFind.alignSlicings(a_ex, b_ex); + + Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl; + Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl; + Debug("bv-slicer") << "Slicer::processEquality done. " << endl; +} + +void Slicer::assertEquality(TNode eq) { + Assert (eq.getKind() == kind::EQUAL); + ExtractTerm a = registerTerm(eq[0]); + ExtractTerm b = registerTerm(eq[1]); + ExplanationId reason = getExplanationId(eq); + d_unionFind.unionTerms(a, b, reason); +} + +TermId Slicer::getId(TNode node) const { + __gnu_cxx::hash_map<Node, TermId, NodeHashFunction >::const_iterator it = d_nodeToId.find(node); + Assert (it != d_nodeToId.end()); + return it->second; +} + +void Slicer::registerEquality(TNode eq) { + if (d_explanationToId.find(eq) == d_explanationToId.end()) { + ExplanationId id = d_explanations.size(); + d_explanations.push_back(eq); + d_explanationToId[eq] = id; + } +} + +void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation) { + Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl; + + Index high = utils::getSize(node) - 1; + Index low = 0; + TNode top = node; + if (node.getKind() == kind::BITVECTOR_EXTRACT) { + high = utils::getExtractHigh(node); + low = utils::getExtractLow(node); + top = node[0]; + } + + Assert (d_nodeToId.find(top) != d_nodeToId.end()); + TermId id = d_nodeToId[top]; + NormalForm nf(high-low+1); + std::vector<ExplanationId> explanation_ids; + d_unionFind.getNormalFormWithExplanation(ExtractTerm(id, high, low), nf, explanation_ids); + + for (unsigned i = 0; i < explanation_ids.size(); ++i) { + Assert (hasExplanation(explanation_ids[i])); + TNode exp = getExplanation(explanation_ids[i]); + explanation.push_back(exp); + } + + for (int i = nf.decomp.size() - 1; i>=0 ; --i) { + Node current = getNode(nf.decomp[i]); + decomp.push_back(current); + } + + + Debug("bv-slicer") << "as ["; + for (unsigned i = 0; i < decomp.size(); ++i) { + Debug("bv-slicer") << decomp[i] <<" "; + } + Debug("bv-slicer") << "]" << endl; + +} + +bool Slicer::isCoreTerm(TNode node) { + if (d_coreTermCache.find(node) == d_coreTermCache.end()) { + Kind kind = node.getKind(); + if (kind != kind::BITVECTOR_EXTRACT && + kind != kind::BITVECTOR_CONCAT && + kind != kind::EQUAL && kind != kind::NOT && + node.getMetaKind() != kind::metakind::VARIABLE && + kind != kind::CONST_BITVECTOR) { + d_coreTermCache[node] = false; + return false; + } else { + // we need to recursively check whether the term is a root term or not + bool isCore = true; + for (unsigned i = 0; i < node.getNumChildren(); ++i) { + isCore = isCore && isCoreTerm(node[i]); + } + d_coreTermCache[node] = isCore; + return isCore; + } + } + return d_coreTermCache[node]; +} +unsigned Slicer::d_numAddedEqualities = 0; + +void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) { + Assert (node.getKind() == kind::EQUAL); + TNode t1 = node[0]; + TNode t2 = node[1]; + + uint32_t width = utils::getSize(t1); + + Base base1(width); + if (t1.getKind() == kind::BITVECTOR_CONCAT) { + int size = 0; + // no need to count the last child since the end cut point is implicit + for (int i = t1.getNumChildren() - 1; i >= 1 ; --i) { + size = size + utils::getSize(t1[i]); + base1.sliceAt(size); + } + } + + Base base2(width); + if (t2.getKind() == kind::BITVECTOR_CONCAT) { + unsigned size = 0; + for (int i = t2.getNumChildren() - 1; i >= 1; --i) { + size = size + utils::getSize(t2[i]); + base2.sliceAt(size); + } + } + + base1.sliceWith(base2); + if (!base1.isEmpty()) { + // we split the equalities according to the base + int last = 0; + for (unsigned i = 1; i <= utils::getSize(t1); ++i) { + if (base1.isCutPoint(i)) { + Node extract1 = utils::mkExtract(t1, i-1, last); + Node extract2 = utils::mkExtract(t2, i-1, last); + last = i; + Assert (utils::getSize(extract1) == utils::getSize(extract2)); + equalities.push_back(utils::mkNode(kind::EQUAL, extract1, extract2)); + } + } + } else { + // just return same equality + equalities.push_back(node); + } + d_numAddedEqualities += equalities.size() - 1; +} + + +ExtractTerm UnionFind::getExtractTerm(TermId id) const { + Assert (isExtractTerm(id)); + + return (d_idToExtract.find(id))->second; +} + +bool UnionFind::isExtractTerm(TermId id) const { + return d_idToExtract.find(id) != d_idToExtract.end(); +} + +bool Slicer::hasNode(TermId id) const { + return d_idToNode.find(id) != d_idToNode.end(); +} + +Node Slicer::getNode(TermId id) const { + if (hasNode(id)) { + return d_idToNode.find(id)->second; + } + // otherwise must be an extract + Assert (d_unionFind.isExtractTerm(id)); + ExtractTerm extract = d_unionFind.getExtractTerm(id); + Assert (hasNode(extract.id)); + TNode node = d_idToNode.find(extract.id)->second; + Node ex = utils::mkExtract(node, extract.high, extract.low); + return ex; +} + +bool Slicer::termInEqualityEngine(TermId id) { + Node node = getNode(id); + return d_coreSolver->hasTerm(node); +} + +void Slicer::enqueueSplit(TermId id, Index i, TermId top_id, TermId bottom_id) { + Node node = getNode(id); + Node bottom = Rewriter::rewrite(utils::mkExtract(node, i -1 , 0)); + Node top = Rewriter::rewrite(utils::mkExtract(node, utils::getSize(node) - 1, i)); + // must add terms to equality engine so we get notified when they get split more + d_coreSolver->addTermToEqualityEngine(bottom); + d_coreSolver->addTermToEqualityEngine(top); + + Node eq = utils::mkNode(kind::EQUAL, node, utils::mkConcat(top, bottom)); + d_newSplits.push_back(eq); + Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl; + Debug("bv-slicer") << " " << id << "=" << top_id << " " << bottom_id << endl; +} + +void Slicer::getNewSplits(std::vector<Node>& splits) { + for (unsigned i = d_newSplitsIndex; i < d_newSplits.size(); ++i) { + splits.push_back(d_newSplits[i]); + } + d_newSplitsIndex = d_newSplits.size(); +} + +bool Slicer::hasExplanation(ExplanationId id) const { + return id < d_explanations.size(); +} + +TNode Slicer::getExplanation(ExplanationId id) const { + Assert(hasExplanation(id)); + return d_explanations[id]; +} + +ExplanationId Slicer::getExplanationId(TNode reason) const { + Assert (d_explanationToId.find(reason) != d_explanationToId.end()); + return d_explanationToId.find(reason)->second; +} + +std::string UnionFind::debugPrint(TermId id) { + ostringstream os; + if (hasChildren(id)) { + TermId id1 = find(getChild(id, 1)); + TermId id0 = find(getChild(id, 0)); + os << debugPrint(id1); + os << debugPrint(id0); + } else { + if (getRepr(id) == UndefinedId) { + os <<"id"<< id <<"[" << getBitwidth(id) <<"] "; + } else { + os << debugPrint(find(id)); + } + } + return os.str(); +} + +UnionFind::Statistics::Statistics(): + d_numNodes("theory::bv::slicer::NumberOfNodes", 0), + d_numRepresentatives("theory::bv::slicer::NumberOfRepresentatives", 0), + d_numSplits("theory::bv::slicer::NumberOfSplits", 0), + d_numMerges("theory::bv::slicer::NumberOfMerges", 0), + d_avgFindDepth("theory::bv::slicer::AverageFindDepth"), + d_numAddedEqualities("theory::bv::slicer::NumberOfEqualitiesAdded", Slicer::d_numAddedEqualities) +{ + StatisticsRegistry::registerStat(&d_numRepresentatives); + StatisticsRegistry::registerStat(&d_numSplits); + StatisticsRegistry::registerStat(&d_numMerges); + StatisticsRegistry::registerStat(&d_avgFindDepth); + StatisticsRegistry::registerStat(&d_numAddedEqualities); +} + +UnionFind::Statistics::~Statistics() { + StatisticsRegistry::unregisterStat(&d_numRepresentatives); + StatisticsRegistry::unregisterStat(&d_numSplits); + StatisticsRegistry::unregisterStat(&d_numMerges); + StatisticsRegistry::unregisterStat(&d_avgFindDepth); + StatisticsRegistry::unregisterStat(&d_numAddedEqualities); +} diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h new file mode 100644 index 000000000..ab2d5e88f --- /dev/null +++ b/src/theory/bv/slicer.h @@ -0,0 +1,368 @@ +/********************* */ +/*! \file slicer.h + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief Bitvector theory. + ** + ** Bitvector theory. + **/ + +#include "cvc4_private.h" + + +#include <vector> +#include <list> +#include <ext/hash_map> +#include <math.h> + +#include "util/bitvector.h" +#include "util/statistics_registry.h" +#include "util/index.h" +#include "expr/node.h" +#include "theory/bv/theory_bv_utils.h" +#include "context/context.h" +#include "context/cdhashset.h" +#include "context/cdo.h" +#include "context/cdqueue.h" +#ifndef __CVC4__THEORY__BV__SLICER_BV_H +#define __CVC4__THEORY__BV__SLICER_BV_H + + +namespace CVC4 { + +namespace theory { +namespace bv { + + + +typedef Index TermId; +typedef TermId ExplanationId; +extern const TermId UndefinedId; + +class CDBase; + +/** + * Base + * + */ +class Base { + Index d_size; + std::vector<uint32_t> d_repr; + void undoSliceAt(Index index); +public: + Base (Index size); + void sliceAt(Index index); + + void sliceWith(const Base& other); + bool isCutPoint(Index index) const; + void diffCutPoints(const Base& other, Base& res) const; + bool isEmpty() const; + std::string debugPrint() const; + Index getBitwidth() const { return d_size; } + void clear() { + for (unsigned i = 0; i < d_repr.size(); ++i) { + d_repr[i] = 0; + } + } + bool operator==(const Base& other) const { + if (other.getBitwidth() != getBitwidth()) + return false; + for (unsigned i = 0; i < d_repr.size(); ++i) { + if (d_repr[i] != other.d_repr[i]) + return false; + } + return true; + } +}; + + +/** + * UnionFind + * + */ +typedef context::CDHashSet<uint32_t, std::hash<uint32_t> > CDTermSet; +typedef std::vector<TermId> Decomposition; + +struct ExtractTerm { + TermId id; + Index high; + Index low; + ExtractTerm() + : id (UndefinedId), + high(UndefinedId), + low(UndefinedId) + {} + ExtractTerm(TermId i, Index h, Index l) + : id (i), + high(h), + low(l) + { + Assert (h >= l && id != UndefinedId); + } + bool operator== (const ExtractTerm& other) const { + return id == other.id && high == other.high && low == other.low; + } + Index getBitwidth() const { return high - low + 1; } + std::string debugPrint() const; + friend class ExtractTermHashFunction; +}; + +struct ExtractTermHashFunction { + ::std::size_t operator() (const ExtractTerm& t) const { + __gnu_cxx::hash<unsigned> h; + unsigned id = t.id; + unsigned high = t.high; + unsigned low = t.low; + return (h(id) * 7919 + h(high))* 4391 + h(low); + } +}; + +class UnionFind; + +struct NormalForm { + Base base; + Decomposition decomp; + + NormalForm(Index bitwidth) + : base(bitwidth), + decomp() + {} + /** + * Returns the term in the decomposition on which the index i + * falls in + * @param i + * + * @return + */ + std::pair<TermId, Index> getTerm(Index i, const UnionFind& uf) const; + std::string debugPrint(const UnionFind& uf) const; + void clear() { base.clear(); decomp.clear(); } +}; + +class Slicer; + +class UnionFind : public context::ContextNotifyObj { + + struct ReprEdge { + TermId repr; + ExplanationId reason; + ReprEdge() + : repr(UndefinedId), + reason(UndefinedId) + {} + }; + + class Node { + Index d_bitwidth; + TermId d_ch1, d_ch0; // the ids of the two children if they exist + ReprEdge d_edge; // points to the representative and stores the explanation + + public: + Node(Index b) + : d_bitwidth(b), + d_ch1(UndefinedId), + d_ch0(UndefinedId), + d_edge() + {} + + TermId getRepr() const { return d_edge.repr; } + ExplanationId getReason() const { return d_edge.reason; } + Index getBitwidth() const { return d_bitwidth; } + bool hasChildren() const { return d_ch1 != UndefinedId && d_ch0 != UndefinedId; } + + TermId getChild(Index i) const { + Assert (i < 2); + return i == 0? d_ch0 : d_ch1; + } + void setRepr(TermId repr, ExplanationId reason) { + Assert (! hasChildren()); + d_edge.repr = repr; + d_edge.reason = reason; + } + void setChildren(TermId ch1, TermId ch0) { + // Assert (d_repr == UndefinedId && !hasChildren()); + d_ch1 = ch1; + d_ch0 = ch0; + } + std::string debugPrint() const; + }; + + /// map from TermId to the nodes that represent them + std::vector<Node> d_nodes; + __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> > d_idToExtract; + __gnu_cxx::hash_map<ExtractTerm, TermId, ExtractTermHashFunction > d_extractToId; + + void getDecomposition(const ExtractTerm& term, Decomposition& decomp); + void getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, std::vector<ExplanationId>& explanation); + void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common); + /// getter methods for the internal nodes + TermId getRepr(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getRepr(); + } + ExplanationId getReason(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getReason(); + } + TermId getChild(TermId id, Index i) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getChild(i); + } + Index getCutPoint(TermId id) const { + return getBitwidth(getChild(id, 0)); + } + bool hasChildren(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].hasChildren(); + } + // TermId getTopLevel(TermId id) const; + + /// setter methods for the internal nodes + void setRepr(TermId id, TermId new_repr, ExplanationId reason) { + Assert (id < d_nodes.size()); + d_nodes[id].setRepr(new_repr, reason); + } + void setChildren(TermId id, TermId ch1, TermId ch0) { + Assert ((ch1 == UndefinedId && ch0 == UndefinedId) || + (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0))); + d_nodes[id].setChildren(ch1, ch0); + } + + /* Backtracking mechanisms */ + + enum OperationKind { + MERGE, + SPLIT + }; + + struct Operation { + OperationKind op; + TermId id; + Operation(OperationKind o, TermId i) + : op(o), id(i) {} + }; + + std::vector<Operation> d_undoStack; + context::CDO<unsigned> d_undoStackIndex; + + void backtrack(); + void undoMerge(TermId id); + void undoSplit(TermId id); + void recordOperation(OperationKind op, TermId term); + virtual ~UnionFind() throw(AssertionException) {} + class Statistics { + public: + IntStat d_numNodes; + IntStat d_numRepresentatives; + IntStat d_numSplits; + IntStat d_numMerges; + AverageStat d_avgFindDepth; + ReferenceStat<unsigned> d_numAddedEqualities; + Statistics(); + ~Statistics(); + }; + Statistics d_statistics; + Slicer* d_slicer; +public: + UnionFind(context::Context* ctx, Slicer* slicer) + : ContextNotifyObj(ctx), + d_nodes(), + d_idToExtract(), + d_extractToId(), + d_undoStack(), + d_undoStackIndex(ctx), + d_statistics(), + d_slicer(slicer) + {} + + TermId addNode(Index bitwidth); + TermId addExtract(Index topLevel, Index high, Index low); + ExtractTerm getExtractTerm(TermId id) const; + bool isExtractTerm(TermId id) const; + + void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason); + void merge(TermId t1, TermId t2, TermId reason); + TermId find(TermId t1); + TermId findWithExplanation(TermId id, std::vector<ExplanationId>& explanation); + void split(TermId term, Index i); + void getNormalForm(const ExtractTerm& term, NormalForm& nf); + void getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, std::vector<ExplanationId>& explanation); + void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2); + void ensureSlicing(const ExtractTerm& term); + Index getBitwidth(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getBitwidth(); + } + void getBase(TermId id, Base& base, Index offset); + std::string debugPrint(TermId id); + + void contextNotifyPop() { + backtrack(); + } + friend class Slicer; +}; + +class CoreSolver; + +class Slicer { + __gnu_cxx::hash_map<TermId, TNode, __gnu_cxx::hash<TermId> > d_idToNode; + __gnu_cxx::hash_map<Node, TermId, NodeHashFunction> d_nodeToId; + __gnu_cxx::hash_map<Node, bool, NodeHashFunction> d_coreTermCache; + __gnu_cxx::hash_map<Node, ExplanationId, NodeHashFunction> d_explanationToId; + std::vector<Node> d_explanations; + UnionFind d_unionFind; + + context::CDQueue<Node> d_newSplits; + context::CDO<unsigned> d_newSplitsIndex; + CoreSolver* d_coreSolver; + TermId d_termIdCount; +public: + Slicer(context::Context* ctx, CoreSolver* coreSolver) + : d_idToNode(), + d_nodeToId(), + d_coreTermCache(), + d_explanationToId(), + d_explanations(), + d_unionFind(ctx, this), + d_newSplits(ctx), + d_newSplitsIndex(ctx, 0), + d_coreSolver(coreSolver) + {} + + void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation); + void registerEquality(TNode eq); + ExtractTerm registerTerm(TNode node); + void processEquality(TNode eq); + void assertEquality(TNode eq); + bool isCoreTerm (TNode node); + + bool hasNode(TermId id) const; + Node getNode(TermId id) const; + TermId getId(TNode node) const; + + bool hasExplanation(ExplanationId id) const; + TNode getExplanation(ExplanationId id) const; + ExplanationId getExplanationId(TNode reason) const; + + bool termInEqualityEngine(TermId id); + void enqueueSplit(TermId id, Index i, TermId top, TermId bottom); + void getNewSplits(std::vector<Node>& splits); + static void splitEqualities(TNode node, std::vector<Node>& equalities); + static unsigned d_numAddedEqualities; +}; + + +}/* CVC4::theory::bv namespace */ +}/* CVC4::theory namespace */ +}/* CVC4 namespace */ + +#endif /* __CVC4__THEORY__BV__SLICER_BV_H */ diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 57a77c0d2..a794d63a3 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -17,6 +17,7 @@ #include "theory/bv/theory_bv.h" #include "theory/bv/theory_bv_utils.h" +#include "theory/bv/slicer.h" #include "theory/valuation.h" #include "theory/bv/bitblaster.h" #include "theory/bv/options.h" @@ -39,8 +40,9 @@ TheoryBV::TheoryBV(context::Context* c, context::UserContext* u, OutputChannel& d_context(c), d_alreadyPropagatedSet(c), d_sharedTermsSet(c), + d_coreSolver(c, this), + d_inequalitySolver(c, this), d_bitblastSolver(c, this), - d_equalitySolver(c, this), d_statistics(), d_conflict(c, false), d_literalsToPropagate(c), @@ -52,7 +54,7 @@ TheoryBV::~TheoryBV() {} void TheoryBV::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalitySolver.setMasterEqualityEngine(eq); + d_coreSolver.setMasterEqualityEngine(eq); } TheoryBV::Statistics::Statistics(): @@ -71,6 +73,8 @@ TheoryBV::Statistics::~Statistics() { StatisticsRegistry::unregisterStat(&d_solveTimer); } + + void TheoryBV::preRegisterTerm(TNode node) { Debug("bitvector-preregister") << "TheoryBV::preRegister(" << node << ")" << std::endl; @@ -78,9 +82,9 @@ void TheoryBV::preRegisterTerm(TNode node) { // don't use the equality engine in the eager bit-blasting return; } - + d_bitblastSolver.preRegister(node); - d_equalitySolver.preRegister(node); + d_coreSolver.preRegister(node); } void TheoryBV::sendConflict() { @@ -105,25 +109,30 @@ void TheoryBV::check(Effort e) return; } - // getting the new assertions - std::vector<TNode> new_assertions; while (!done()) { - Assertion assertion = get(); - TNode fact = assertion.assertion; - new_assertions.push_back(fact); - Debug("bitvector-assertions") << "TheoryBV::check assertion " << fact << "\n"; + TNode fact = get().assertion; + d_coreSolver.assertFact(fact); + d_inequalitySolver.assertFact(fact); + d_bitblastSolver.assertFact(fact); } + bool ok = true; if (!inConflict()) { - // sending assertions to the equality solver first - d_equalitySolver.addAssertions(new_assertions, e); + ok = d_coreSolver.check(e); } + Assert (!ok == inConflict()); - if (!inConflict()) { - // sending assertions to the bitblast solver - d_bitblastSolver.addAssertions(new_assertions, e); - } + // if (!inConflict() && !d_coreSolver.isCoreTheory()) { + // ok = d_inequalitySolver.check(e); + // } + Assert (!ok == inConflict()); + if (!inConflict() && !d_coreSolver.isCoreTheory()) { + // if (!inConflict() && !d_inequalitySolver.isInequalityTheory()) { + ok = d_bitblastSolver.check(e); + } + + Assert (!ok == inConflict()); if (inConflict()) { sendConflict(); } @@ -132,9 +141,8 @@ void TheoryBV::check(Effort e) void TheoryBV::collectModelInfo( TheoryModel* m, bool fullModel ){ Assert(!inConflict()); // Assert (fullModel); // can only query full model - d_equalitySolver.collectModelInfo(m); + d_coreSolver.collectModelInfo(m); d_bitblastSolver.collectModelInfo(m); - } void TheoryBV::propagate(Effort e) { @@ -187,16 +195,25 @@ Theory::PPAssertStatus TheoryBV::ppAssert(TNode in, SubstitutionMap& outSubstitu return PP_ASSERT_STATUS_UNSOLVED; } - Node TheoryBV::ppRewrite(TNode t) { if (RewriteRule<BitwiseEq>::applies(t)) { Node result = RewriteRule<BitwiseEq>::run<false>(t); return Rewriter::rewrite(result); } + + if (t.getKind() == kind::EQUAL) { + std::vector<Node> equalities; + Slicer::splitEqualities(t, equalities); + return utils::mkAnd(equalities); + } + return t; } +void TheoryBV::presolve() { + Debug("bitvector") << "TheoryBV::presolve" << endl; +} bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory) { @@ -227,7 +244,7 @@ bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory) // * bitblaster needs to be left alone until it's done, otherwise it doesn't know how to explain // * equality engine can propagate eagerly bool ok = true; - if (subtheory == SUB_EQUALITY) { + if (subtheory == SUB_CORE) { d_out->propagate(literal); if (!ok) { setConflict(); @@ -242,9 +259,9 @@ bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory) void TheoryBV::explain(TNode literal, std::vector<TNode>& assumptions) { // Ask the appropriate subtheory for the explanation - if (propagatedBy(literal, SUB_EQUALITY)) { - Debug("bitvector::explain") << "TheoryBV::explain(" << literal << "): EQUALITY" << std::endl; - d_equalitySolver.explain(literal, assumptions); + if (propagatedBy(literal, SUB_CORE)) { + Debug("bitvector::explain") << "TheoryBV::explain(" << literal << "): CORE" << std::endl; + d_coreSolver.explain(literal, assumptions); } else { Assert(propagatedBy(literal, SUB_BITBLAST)); Debug("bitvector::explain") << "TheoryBV::explain(" << literal << ") : BITBLASTER" << std::endl; @@ -274,7 +291,7 @@ void TheoryBV::addSharedTerm(TNode t) { Debug("bitvector::sharing") << indent() << "TheoryBV::addSharedTerm(" << t << ")" << std::endl; d_sharedTermsSet.insert(t); if (!options::bitvectorEagerBitblast() && d_useEqualityEngine) { - d_equalitySolver.addSharedTerm(t); + d_coreSolver.addSharedTerm(t); } } @@ -285,7 +302,7 @@ EqualityStatus TheoryBV::getEqualityStatus(TNode a, TNode b) return EQUALITY_UNKNOWN; } - EqualityStatus status = d_equalitySolver.getEqualityStatus(a, b); + EqualityStatus status = d_coreSolver.getEqualityStatus(a, b); if (status == EQUALITY_UNKNOWN) { status = d_bitblastSolver.getEqualityStatus(a, b); } diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index e38f3568c..13a475d3d 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -25,10 +25,11 @@ #include "context/cdhashset.h" #include "theory/bv/theory_bv_utils.h" #include "util/statistics_registry.h" -#include "context/cdqueue.h" #include "theory/bv/bv_subtheory.h" -#include "theory/bv/bv_subtheory_eq.h" +#include "theory/bv/bv_subtheory_core.h" #include "theory/bv/bv_subtheory_bitblast.h" +#include "theory/bv/bv_subtheory_inequality.h" +#include "theory/bv/slicer.h" namespace CVC4 { namespace theory { @@ -42,9 +43,10 @@ class TheoryBV : public Theory { /** Context dependent set of atoms we already propagated */ context::CDHashSet<Node, NodeHashFunction> d_alreadyPropagatedSet; context::CDHashSet<Node, NodeHashFunction> d_sharedTermsSet; - - BitblastSolver d_bitblastSolver; - EqualitySolver d_equalitySolver; + + CoreSolver d_coreSolver; + InequalitySolver d_inequalitySolver; + BitblastSolver d_bitblastSolver; public: TheoryBV(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo, QuantifiersEngine* qe); @@ -67,6 +69,7 @@ public: PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions); Node ppRewrite(TNode t); + void presolve(); private: class Statistics { @@ -137,7 +140,8 @@ private: friend class Bitblaster; friend class BitblastSolver; friend class EqualitySolver; - + friend class CoreSolver; + friend class InequalitySolver; };/* class TheoryBV */ }/* CVC4::theory::bv namespace */ diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h index 7d851d0fb..e5a7bbb84 100644 --- a/src/theory/bv/theory_bv_utils.h +++ b/src/theory/bv/theory_bv_utils.h @@ -31,6 +31,12 @@ namespace theory { namespace bv { namespace utils { +inline uint32_t pow2(uint32_t power) { + Assert (power < 32); + uint32_t one = 1; + return one << power; +} + inline unsigned getExtractHigh(TNode node) { return node.getOperator().getConst<BitVectorExtract>().high; } @@ -67,11 +73,11 @@ inline Node mkAnd(std::vector<TNode>& children) { std::set<TNode> distinctChildren; distinctChildren.insert(children.begin(), children.end()); - if (children.size() == 0) { + if (distinctChildren.size() == 0) { return mkTrue(); } - if (children.size() == 1) { + if (distinctChildren.size() == 1) { return *children.begin(); } @@ -406,6 +412,35 @@ inline std::string vectorToString(const std::vector<Node>& nodes) { return out.str(); } +// FIXME: dumb code +inline void intersect(const std::vector<uint32_t>& v1, + const std::vector<uint32_t>& v2, + std::vector<uint32_t>& intersection) { + for (unsigned i = 0; i < v1.size(); ++i) { + bool found = false; + for (unsigned j = 0; j < v2.size(); ++j) { + if (v2[j] == v1[i]) { + found = true; + break; + } + } + if (found) { + intersection.push_back(v1[i]); + } + } +} + +template <class T> +inline T gcd(T a, T b) { + while (b != 0) { + T t = b; + b = a % t; + a = t; + } + return a; +} + + } } } diff --git a/src/util/bitvector.h b/src/util/bitvector.h index 4cbcba50e..c9661c0c7 100644 --- a/src/util/bitvector.h +++ b/src/util/bitvector.h @@ -178,10 +178,23 @@ public: Integer prod = d_value * y.d_value; return BitVector(d_size, prod); } + + BitVector setBit(uint32_t i) const { + CheckArgument(i < d_size, i); + Integer res = d_value.setBit(i); + return BitVector(d_size, res); + } + + bool isBitSet(uint32_t i) const { + CheckArgument(i < d_size, i); + return d_value.isBitSet(i); + } + /** * Total division function that returns 0 when the denominator is 0. */ BitVector unsignedDivTotal (const BitVector& y) const { + CheckArgument(d_size == y.d_size, y); if (y.d_value == 0) { return BitVector(d_size, 0u); @@ -190,6 +203,7 @@ public: CheckArgument(y.d_value > 0, y); return BitVector(d_size, d_value.floorDivideQuotient(y.d_value)); } + /** * Total division function that returns 0 when the denominator is 0. */ diff --git a/src/util/index.h b/src/util/index.h index 4c03af5b0..252f7066b 100644 --- a/src/util/index.h +++ b/src/util/index.h @@ -21,6 +21,7 @@ #include <stdint.h> #include <boost/static_assert.hpp> +#include <limits> namespace CVC4 { diff --git a/src/util/integer_cln_imp.h b/src/util/integer_cln_imp.h index b5452ae00..81c0428cb 100644 --- a/src/util/integer_cln_imp.h +++ b/src/util/integer_cln_imp.h @@ -218,6 +218,16 @@ public: return Integer( d_value << ipow); } + bool isBitSet(uint32_t i) const { + return !extractBitRange(1, i).isZero(); + } + + Integer setBit(uint32_t i) const { + cln::cl_I mask(1); + mask = mask << i; + return Integer(cln::logior(d_value, mask)); + } + Integer oneExtend(uint32_t size, uint32_t amount) const { DebugCheckArgument((*this) < Integer(1).multiplyByPow2(size), size); cln::cl_byte range(amount, size); diff --git a/src/util/integer_gmp_imp.h b/src/util/integer_gmp_imp.h index 176604268..85d49f921 100644 --- a/src/util/integer_gmp_imp.h +++ b/src/util/integer_gmp_imp.h @@ -137,6 +137,7 @@ public: return *this; } + Integer bitwiseOr(const Integer& y) const { mpz_class result; mpz_ior(result.get_mpz_t(), d_value.get_mpz_t(), y.d_value.get_mpz_t()); @@ -170,6 +171,24 @@ public: return Integer( result ); } + /** + * Returns the Integer obtained by setting the ith bit of the + * current Integer to 1. + * + * @param bit + * + * @return + */ + Integer setBit(uint32_t i) const { + mpz_class res = d_value; + mpz_setbit(res.get_mpz_t(), i); + return Integer(res); + } + + bool isBitSet(uint32_t i) const { + return !extractBitRange(1, i).isZero(); + } + /** * Returns the integer with the binary representation of size bits * extended with amount 1's diff --git a/src/util/utility.h b/src/util/utility.h index 5ce185b5b..089be478d 100644 --- a/src/util/utility.h +++ b/src/util/utility.h @@ -67,7 +67,6 @@ inline InputIterator find_if_unique(InputIterator first, InputIterator last, Pre return (match2 == last) ? match : last; } - }/* CVC4 namespace */ #endif /* __CVC4__UTILITY_H */ |