diff options
Diffstat (limited to 'src/theory/bv/slicer.cpp')
-rw-r--r-- | src/theory/bv/slicer.cpp | 250 |
1 files changed, 175 insertions, 75 deletions
diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp index 5d376ea50..b24702635 100644 --- a/src/theory/bv/slicer.cpp +++ b/src/theory/bv/slicer.cpp @@ -156,11 +156,11 @@ std::string NormalForm::debugPrint(const UnionFind& uf) const { return os.str(); } /** - * UnionFind::Node + * UnionFind::EqualityNode * */ -std::string UnionFind::Node::debugPrint() const { +std::string UnionFind::EqualityNode::debugPrint() const { ostringstream os; os << "Repr " << d_edge.repr << " ["<< d_bitwidth << "] "; os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl; @@ -172,41 +172,80 @@ std::string UnionFind::Node::debugPrint() const { * UnionFind * */ -TermId UnionFind::addNode(Index bitwidth) { + +TermId UnionFind::registerTopLevelTerm(Index bitwidth) { + TermId id = mkEqualityNode(bitwidth); + d_topLevelIds.insert(id); + return id; +} + +TermId UnionFind::mkEqualityNode(Index bitwidth) { Assert (bitwidth > 0); - Node node(bitwidth); - d_nodes.push_back(node); + EqualityNode node(bitwidth); + d_equalityNodes.push_back(node); ++(d_statistics.d_numNodes); - TermId id = d_nodes.size() - 1; + TermId id = d_equalityNodes.size() - 1; // d_representatives.insert(id); ++(d_statistics.d_numRepresentatives); Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl; return id; } - -TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) { - if (isExtractTerm(topLevel)) { - ExtractTerm top = getExtractTerm(topLevel); - Index top_high = top.high; - Index top_low = top.low; - Assert (top_high - top_low + 1 > high); - high += top_low; - low += top_low; - topLevel = top.id; +/** + * Create an extract term making sure there are no nested extracts. + * + * @param id + * @param high + * @param low + * + * @return + */ +ExtractTerm UnionFind::mkExtractTerm(TermId id, Index high, Index low) { + if (d_topLevelIds.find(id) != d_topLevelIds.end()) { + return ExtractTerm(id, high, low); } - ExtractTerm extract(topLevel, high, low); + Assert (isExtractTerm(id)); + ExtractTerm top = getExtractTerm(id); + Assert (d_topLevelIds.find(top.id) != d_topLevelIds.end()); + + Index top_high = top.high; + Index top_low = top.low; + Assert (top_high - top_low + 1 > high); + high += top_low; + low += top_low; + id = top.id; + return ExtractTerm(id, high, low); +} + +/** + * Associate the given extract term with the given id. + * + * @param id + * @param extract + */ +void UnionFind::storeExtractTerm(TermId id, const ExtractTerm& extract) { if (d_extractToId.find(extract) != d_extractToId.end()) { - return d_extractToId[extract]; + Assert (d_extractToId[extract] == id); + return; } - - Assert (high >= low); - - TermId id = addNode(high - low + 1); + Debug("bv-slicer") << "UnionFind::storeExtract " << extract.debugPrint() << " => id" << id << "\n"; d_idToExtract[id] = extract; d_extractToId[extract] = id; - return id; + } + +TermId UnionFind::addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low) { + ExtractTerm extract(id, high, low); + if (d_extractToId.find(extract) != d_extractToId.end()) { + // if the extract already exists we don't need to make a new node + TermId extract_id = d_extractToId[extract]; + Assert (extract_id < d_equalityNodes.size()); + return extract_id; + } + // otherwise make an equality node for it and store the extract + TermId node_id = mkEqualityNode(bitwidth); + storeExtractTerm(node_id, extract); + return node_id; } /** @@ -215,7 +254,10 @@ TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) { * @param t1 * @param t2 */ -void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason) { +void UnionFind::unionTerms(TermId id1, TermId id2, TermId reason) { + const ExtractTerm& t1 = getExtractTerm(id1); + const ExtractTerm& t2 = getExtractTerm(id2); + Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n" << " " << t2.debugPrint() << "\n" << " with reason " << reason << endl; @@ -294,7 +336,7 @@ TermId UnionFind::findWithExplanation(TermId id, std::vector<ExplanationId>& exp void UnionFind::split(TermId id, Index i) { Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl; id = find(id); - Debug("bv-slicer-uf") << " node: " << d_nodes[id].debugPrint() << endl; + Debug("bv-slicer-uf") << " node: " << d_equalityNodes[id].debugPrint() << endl; if (i == 0 || i == getBitwidth(id)) { // nothing to do @@ -303,9 +345,15 @@ void UnionFind::split(TermId id, Index i) { Assert (i < getBitwidth(id)); if (!hasChildren(id)) { - // first time we split this term - TermId bottom_id = addExtract(id, i - 1, 0); - TermId top_id = addExtract(id, getBitwidth(id) - 1, i); + // first time we split this term + ExtractTerm bottom_extract = mkExtractTerm(id, i-1, 0); + ExtractTerm top_extract = mkExtractTerm(id, getBitwidth(id) - 1, i); + + TermId bottom_id = extractHasId(bottom_extract)? getExtractId(bottom_extract) : mkEqualityNode(i); + TermId top_id = extractHasId(top_extract)? getExtractId(top_extract) : mkEqualityNode(getBitwidth(id) - i); + storeExtractTerm(bottom_id, bottom_extract); + storeExtractTerm(top_id, top_extract); + setChildren(id, top_id, bottom_id); recordOperation(UnionFind::SPLIT, id); @@ -471,7 +519,10 @@ void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposit } -void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) { +void UnionFind::alignSlicings(TermId id1, TermId id2) { + const ExtractTerm& term1 = getExtractTerm(id1); + const ExtractTerm& term2 = getExtractTerm(id2); + Debug("bv-slicer") << "UnionFind::alignSlicings " << term1.debugPrint() << endl; Debug("bv-slicer") << " " << term2.debugPrint() << endl; NormalForm nf1(term1.getBitwidth()); @@ -519,15 +570,18 @@ void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2 } } while (changed); } + + /** * Given an extract term a[i:j] makes sure a is sliced * at indices i and j. * * @param term */ -void UnionFind::ensureSlicing(const ExtractTerm& term) { +void UnionFind::ensureSlicing(TermId t) { + ExtractTerm term = getExtractTerm(t); //Debug("bv-slicer") << "Slicer::ensureSlicing " << term.debugPrint() << endl; - TermId id = find(term.id); + TermId id = term.id; split(id, term.high + 1); split(id, term.low); } @@ -576,30 +630,69 @@ void UnionFind::getBase(TermId id, Base& base, Index offset) { getBase(id0, base, offset); } +/// getter methods for the internal nodes +TermId UnionFind::getRepr(TermId id) const { + Assert (id < d_equalityNodes.size()); + return d_equalityNodes[id].getRepr(); +} +ExplanationId UnionFind::getReason(TermId id) const { + Assert (id < d_equalityNodes.size()); + return d_equalityNodes[id].getReason(); +} +TermId UnionFind::getChild(TermId id, Index i) const { + Assert (id < d_equalityNodes.size()); + return d_equalityNodes[id].getChild(i); +} +Index UnionFind::getCutPoint(TermId id) const { + return getBitwidth(getChild(id, 0)); +} +bool UnionFind::hasChildren(TermId id) const { + Assert (id < d_equalityNodes.size()); + return d_equalityNodes[id].hasChildren(); +} + +/// setter methods for the internal nodes +void UnionFind::setRepr(TermId id, TermId new_repr, ExplanationId reason) { + Assert (id < d_equalityNodes.size()); + d_equalityNodes[id].setRepr(new_repr, reason); +} +void UnionFind::setChildren(TermId id, TermId ch1, TermId ch0) { + Assert ((ch1 == UndefinedId && ch0 == UndefinedId) || + (id < d_equalityNodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0))); + d_equalityNodes[id].setChildren(ch1, ch0); +} + /** * Slicer * */ -ExtractTerm Slicer::registerTerm(TNode node) { - Index low = 0, high = utils::getSize(node) - 1; - TNode n = node; +TermId Slicer::registerTerm(TNode node) { if (node.getKind() == kind::BITVECTOR_EXTRACT) { - n = node[0]; - high = utils::getExtractHigh(node); - low = utils::getExtractLow(node); - } - if (d_nodeToId.find(n) == d_nodeToId.end()) { - TermId id = d_unionFind.addNode(utils::getSize(n)); - d_nodeToId[n] = id; - d_idToNode[id] = n; + TNode n = node[0]; + TermId top_id = registerTopLevelTerm(n); + Index high = utils::getExtractHigh(node); + Index low = utils::getExtractLow(node); + TermId id = d_unionFind.addEqualityNode(utils::getSize(node), top_id, high, low); + return id; + } + TermId id = registerTopLevelTerm(node); + return id; +} + +TermId Slicer::registerTopLevelTerm(TNode node) { + Assert (node.getKind() != kind::BITVECTOR_EXTRACT || + node.getKind() != kind::BITVECTOR_CONCAT); + + if (d_nodeToId.find(node) == d_nodeToId.end()) { + TermId id = d_unionFind.registerTopLevelTerm(utils::getSize(node)); + d_idToNode[id] = node; + d_nodeToId[node] = id; + Debug("bv-slicer") << "Slicer::registerTopLevelTerm " << node << " => id" << id << endl; + return id; } - TermId id = d_nodeToId[n]; - d_unionFind.addExtract(id, high, low); - ExtractTerm res(id, high, low); - Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl; - return res; + return d_nodeToId[node]; } void Slicer::processEquality(TNode eq) { @@ -609,42 +702,38 @@ void Slicer::processEquality(TNode eq) { Assert (eq.getKind() == kind::EQUAL); TNode a = eq[0]; TNode b = eq[1]; - ExtractTerm a_ex= registerTerm(a); - ExtractTerm b_ex= registerTerm(b); + TermId a_id = registerTerm(a); + TermId b_id = registerTerm(b); - d_unionFind.ensureSlicing(a_ex); - d_unionFind.ensureSlicing(b_ex); + d_unionFind.ensureSlicing(a_id); + d_unionFind.ensureSlicing(b_id); - d_unionFind.alignSlicings(a_ex, b_ex); + d_unionFind.alignSlicings(a_id, b_id); - Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl; - Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl; - Debug("bv-slicer") << "Slicer::processEquality done. " << endl; + // Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl; + // Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl; + // Debug("bv-slicer") << "Slicer::processEquality done. " << endl; } void Slicer::assertEquality(TNode eq) { Assert (eq.getKind() == kind::EQUAL); - ExtractTerm a = registerTerm(eq[0]); - ExtractTerm b = registerTerm(eq[1]); + TermId a = registerTerm(eq[0]); + TermId b = registerTerm(eq[1]); ExplanationId reason = getExplanationId(eq); d_unionFind.unionTerms(a, b, reason); } -TermId Slicer::getId(TNode node) const { - __gnu_cxx::hash_map<Node, TermId, NodeHashFunction >::const_iterator it = d_nodeToId.find(node); - Assert (it != d_nodeToId.end()); - return it->second; -} void Slicer::registerEquality(TNode eq) { if (d_explanationToId.find(eq) == d_explanationToId.end()) { ExplanationId id = d_explanations.size(); d_explanations.push_back(eq); - d_explanationToId[eq] = id; + d_explanationToId[eq] = id; + Debug("bv-slicer-explanation") << "Slicer::registerEquality " << eq << " => id"<< id << "\n"; } } -void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation) { +void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation) { Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl; Index high = utils::getSize(node) - 1; @@ -672,13 +761,18 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::ve Node current = getNode(nf.decomp[i]); decomp.push_back(current); } - - - Debug("bv-slicer") << "as ["; - for (unsigned i = 0; i < decomp.size(); ++i) { - Debug("bv-slicer") << decomp[i] <<" "; + if (Debug.isOn("bv-slicer-explanation")) { + Debug("bv-slicer-explanation") << "Slicer::getBaseDecomposition for " << node << "\n" + << "as "; + for (unsigned i = 0; i < decomp.size(); ++i) { + Debug("bv-slicer-explanation") << decomp[i] <<" " ; + } + Debug("bv-slicer-explanation") << "\n Explanation : \n"; + for (unsigned i = 0; i < explanation.size(); ++i) { + Debug("bv-slicer-explanation") << " " << explanation[i] << "\n"; + } + } - Debug("bv-slicer") << "]" << endl; } @@ -754,6 +848,10 @@ void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) { ExtractTerm UnionFind::getExtractTerm(TermId id) const { + if (d_topLevelIds.find(id) != d_topLevelIds.end()) { + // if it's a top level term so we don't have an extract stored for it + return ExtractTerm(id, getBitwidth(id) - 1, 0); + } Assert (isExtractTerm(id)); return (d_idToExtract.find(id))->second; @@ -763,19 +861,21 @@ bool UnionFind::isExtractTerm(TermId id) const { return d_idToExtract.find(id) != d_idToExtract.end(); } -bool Slicer::hasNode(TermId id) const { +bool Slicer::isTopLevelNode(TermId id) const { return d_idToNode.find(id) != d_idToNode.end(); } Node Slicer::getNode(TermId id) const { - if (hasNode(id)) { + if (isTopLevelNode(id)) { return d_idToNode.find(id)->second; } - // otherwise must be an extract Assert (d_unionFind.isExtractTerm(id)); - ExtractTerm extract = d_unionFind.getExtractTerm(id); - Assert (hasNode(extract.id)); + const ExtractTerm& extract = d_unionFind.getExtractTerm(id); + Assert (isTopLevelNode(extract.id)); TNode node = d_idToNode.find(extract.id)->second; + if (extract.high == utils::getSize(node) -1 && extract.low == 0) { + return node; + } Node ex = utils::mkExtract(node, extract.high, extract.low); return ex; } |