diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2021-04-21 21:42:08 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-22 02:42:08 +0000 |
commit | 89620a0d73e7134437a39d742e91de11a08a4962 (patch) | |
tree | 46b37970a7d3f74317f8e255b6aefa9cfae127b1 /src/theory/arrays | |
parent | 90cde45ee963b994054f96f97111684cce808d82 (diff) |
Move expand definition from Theory to TheoryRewriter (#6408)
This is work towards eliminating global calls to getCurrentSmtEngine()->expandDefinition.
The next step will be to add Rewriter::expandDefinition.
Diffstat (limited to 'src/theory/arrays')
-rw-r--r-- | src/theory/arrays/theory_arrays.cpp | 58 | ||||
-rw-r--r-- | src/theory/arrays/theory_arrays.h | 2 | ||||
-rw-r--r-- | src/theory/arrays/theory_arrays_rewriter.cpp | 616 | ||||
-rw-r--r-- | src/theory/arrays/theory_arrays_rewriter.h | 458 |
4 files changed, 627 insertions, 507 deletions
diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 1a1090f68..e887feccb 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -300,7 +300,7 @@ Node TheoryArrays::solveWrite(TNode term, bool solve1, bool solve2, bool ppCheck TrustNode TheoryArrays::ppRewrite(TNode term, std::vector<SkolemLemma>& lems) { // first, see if we need to expand definitions - TrustNode texp = expandDefinition(term); + TrustNode texp = d_rewriter.expandDefinition(term); if (!texp.isNull()) { return texp; @@ -2068,62 +2068,6 @@ std::string TheoryArrays::TheoryArraysDecisionStrategy::identify() const return std::string("th_arrays_dec"); } -TrustNode TheoryArrays::expandDefinition(Node node) -{ - NodeManager* nm = NodeManager::currentNM(); - Kind kind = node.getKind(); - - /* Expand - * - * (eqrange a b i j) - * - * to - * - * forall k . i <= k <= j => a[k] = b[k] - * - */ - if (kind == kind::EQ_RANGE) - { - TNode a = node[0]; - TNode b = node[1]; - TNode i = node[2]; - TNode j = node[3]; - Node k = nm->mkBoundVar(i.getType()); - Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, k); - TypeNode type = k.getType(); - - Kind kle; - Node range; - if (type.isBitVector()) - { - kle = kind::BITVECTOR_ULE; - } - else if (type.isFloatingPoint()) - { - kle = kind::FLOATINGPOINT_LEQ; - } - else if (type.isInteger() || type.isReal()) - { - kle = kind::LEQ; - } - else - { - Unimplemented() << "Type " << type << " is not supported for predicate " - << kind; - } - - range = nm->mkNode(kind::AND, nm->mkNode(kle, i, k), nm->mkNode(kle, k, j)); - - Node eq = nm->mkNode(kind::EQUAL, - nm->mkNode(kind::SELECT, a, k), - nm->mkNode(kind::SELECT, b, k)); - Node implies = nm->mkNode(kind::IMPLIES, range, eq); - Node ret = nm->mkNode(kind::FORALL, bvl, implies); - return TrustNode::mkTrustRewrite(node, ret, nullptr); - } - return TrustNode::null(); -} - void TheoryArrays::computeRelevantTerms(std::set<Node>& termSet) { NodeManager* nm = NodeManager::currentNM(); diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h index 7cf8d52e3..f9813cd3f 100644 --- a/src/theory/arrays/theory_arrays.h +++ b/src/theory/arrays/theory_arrays.h @@ -158,8 +158,6 @@ class TheoryArrays : public Theory { std::string identify() const override { return std::string("TheoryArrays"); } - TrustNode expandDefinition(Node node) override; - ///////////////////////////////////////////////////////////////////////////// // PREPROCESSING ///////////////////////////////////////////////////////////////////////////// diff --git a/src/theory/arrays/theory_arrays_rewriter.cpp b/src/theory/arrays/theory_arrays_rewriter.cpp index 323dd0046..6269cb5dd 100644 --- a/src/theory/arrays/theory_arrays_rewriter.cpp +++ b/src/theory/arrays/theory_arrays_rewriter.cpp @@ -45,6 +45,622 @@ void setMostFrequentValueCount(TNode store, uint64_t count) { return store.setAttribute(ArrayConstantMostFrequentValueCountAttr(), count); } +Node TheoryArraysRewriter::normalizeConstant(TNode node) +{ + return normalizeConstant(node, node[1].getType().getCardinality()); +} + +// this function is called by printers when using the option "--model-u-dt-enum" +Node TheoryArraysRewriter::normalizeConstant(TNode node, Cardinality indexCard) +{ + TNode store = node[0]; + TNode index = node[1]; + TNode value = node[2]; + + std::vector<TNode> indices; + std::vector<TNode> elements; + + // Normal form for nested stores is just ordering by index - but also need + // to check if we are writing to default value + + // Go through nested stores looking for where to insert index + // Also check whether we are replacing an existing store + TNode replacedValue; + uint32_t depth = 1; + uint32_t valCount = 1; + while (store.getKind() == kind::STORE) + { + if (index == store[1]) + { + replacedValue = store[2]; + store = store[0]; + break; + } + else if (index >= store[1]) + { + break; + } + if (value == store[2]) + { + valCount += 1; + } + depth += 1; + indices.push_back(store[1]); + elements.push_back(store[2]); + store = store[0]; + } + Node n = store; + + // Get the default value at the bottom of the nested stores + while (store.getKind() == kind::STORE) + { + if (value == store[2]) + { + valCount += 1; + } + depth += 1; + store = store[0]; + } + Assert(store.getKind() == kind::STORE_ALL); + ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>(); + Node defaultValue = storeAll.getValue(); + NodeManager* nm = NodeManager::currentNM(); + + // Check if we are writing to default value - if so the store + // to index can be ignored + if (value == defaultValue) + { + if (replacedValue.isNull()) + { + // Quick exit - if writing to default value and nothing was + // replaced, we can just return node[0] + return node[0]; + } + // else rebuild the store without the replaced write and then exit + } + else + { + n = nm->mkNode(kind::STORE, n, index, value); + } + + // Build the rest of the store after inserting/deleting + while (!indices.empty()) + { + n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); + indices.pop_back(); + elements.pop_back(); + } + + // Ready to exit if write was to the default value (see previous comment) + if (value == defaultValue) + { + return n; + } + + if (indexCard.isInfinite()) + { + return n; + } + + // When index sort is finite, we have to check whether there is any value + // that is written to more than the default value. If so, it must become + // the new default value + + TNode mostFrequentValue; + uint32_t mostFrequentValueCount = 0; + store = node[0]; + if (store.getKind() == kind::STORE) + { + mostFrequentValue = getMostFrequentValue(store); + mostFrequentValueCount = getMostFrequentValueCount(store); + } + + // Compute the most frequently written value for n + if (valCount > mostFrequentValueCount + || (valCount == mostFrequentValueCount && value < mostFrequentValue)) + { + mostFrequentValue = value; + mostFrequentValueCount = valCount; + } + + // Need to make sure the default value count is larger, or the same and the + // default value is expression-order-less-than nextValue + Cardinality::CardinalityComparison compare = + indexCard.compare(mostFrequentValueCount + depth); + Assert(compare != Cardinality::UNKNOWN); + if (compare == Cardinality::GREATER + || (compare == Cardinality::EQUAL && (defaultValue < mostFrequentValue))) + { + return n; + } + + // Bad case: have to recompute value counts and/or possibly switch out + // default value + store = n; + std::unordered_set<TNode, TNodeHashFunction> indexSet; + std::unordered_map<TNode, uint32_t, TNodeHashFunction> elementsMap; + std::unordered_map<TNode, uint32_t, TNodeHashFunction>::iterator it; + uint32_t count; + uint32_t max = 0; + TNode maxValue; + while (store.getKind() == kind::STORE) + { + indices.push_back(store[1]); + indexSet.insert(store[1]); + elements.push_back(store[2]); + it = elementsMap.find(store[2]); + if (it != elementsMap.end()) + { + (*it).second = (*it).second + 1; + count = (*it).second; + } + else + { + elementsMap[store[2]] = 1; + count = 1; + } + if (count > max || (count == max && store[2] < maxValue)) + { + max = count; + maxValue = store[2]; + } + store = store[0]; + } + + Assert(depth == indices.size()); + compare = indexCard.compare(max + depth); + Assert(compare != Cardinality::UNKNOWN); + if (compare == Cardinality::GREATER + || (compare == Cardinality::EQUAL && (defaultValue < maxValue))) + { + Assert(!replacedValue.isNull() && mostFrequentValue == replacedValue); + return n; + } + + // Out of luck: have to swap out default value + + // Enumerate values from index type into newIndices and sort + std::vector<Node> newIndices; + TypeEnumerator te(index.getType()); + bool needToSort = false; + uint32_t numTe = 0; + while (!te.isFinished() + && (!indexCard.isFinite() + || numTe < indexCard.getFiniteCardinality().toUnsignedInt())) + { + if (indexSet.find(*te) == indexSet.end()) + { + if (!newIndices.empty() && (!(newIndices.back() < (*te)))) + { + needToSort = true; + } + newIndices.push_back(*te); + } + ++numTe; + ++te; + } + Assert(indexCard.compare(newIndices.size() + depth) == Cardinality::EQUAL); + if (needToSort) + { + std::sort(newIndices.begin(), newIndices.end()); + } + + n = nm->mkConst(ArrayStoreAll(node.getType(), maxValue)); + std::vector<Node>::iterator itNew = newIndices.begin(), + it_end = newIndices.end(); + while (itNew != it_end || !indices.empty()) + { + if (itNew != it_end && (indices.empty() || (*itNew) < indices.back())) + { + n = nm->mkNode(kind::STORE, n, (*itNew), defaultValue); + ++itNew; + } + else if (itNew == it_end || indices.back() < (*itNew)) + { + if (elements.back() != maxValue) + { + n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); + } + indices.pop_back(); + elements.pop_back(); + } + } + return n; +} + +RewriteResponse TheoryArraysRewriter::postRewrite(TNode node) +{ + Trace("arrays-postrewrite") + << "Arrays::postRewrite start " << node << std::endl; + switch (node.getKind()) + { + case kind::SELECT: + { + TNode store = node[0]; + TNode index = node[1]; + Node n; + bool val; + while (store.getKind() == kind::STORE) + { + if (index == store[1]) + { + val = true; + } + else if (index.isConst() && store[1].isConst()) + { + val = false; + } + else + { + n = Rewriter::rewrite(mkEqNode(store[1], index)); + if (n.getKind() != kind::CONST_BOOLEAN) + { + break; + } + val = n.getConst<bool>(); + } + if (val) + { + // select(store(a,i,v),j) = v if i = j + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << store[2] << std::endl; + return RewriteResponse(REWRITE_DONE, store[2]); + } + // select(store(a,i,v),j) = select(a,j) if i /= j + store = store[0]; + } + if (store.getKind() == kind::STORE_ALL) + { + // select(store_all(v),i) = v + ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>(); + n = storeAll.getValue(); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << n << std::endl; + Assert(n.isConst()); + return RewriteResponse(REWRITE_DONE, n); + } + else if (store != node[0]) + { + n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << n << std::endl; + return RewriteResponse(REWRITE_DONE, n); + } + break; + } + case kind::STORE: + { + TNode store = node[0]; + TNode value = node[2]; + // store(a,i,select(a,i)) = a + if (value.getKind() == kind::SELECT && value[0] == store + && value[1] == node[1]) + { + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << store << std::endl; + return RewriteResponse(REWRITE_DONE, store); + } + TNode index = node[1]; + if (store.isConst() && index.isConst() && value.isConst()) + { + // normalize constant + Node n = normalizeConstant(node); + Assert(n.isConst()); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << n << std::endl; + return RewriteResponse(REWRITE_DONE, n); + } + if (store.getKind() == kind::STORE) + { + // store(store(a,i,v),j,w) + bool val; + if (index == store[1]) + { + val = true; + } + else if (index.isConst() && store[1].isConst()) + { + val = false; + } + else + { + Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index)); + if (eqRewritten.getKind() != kind::CONST_BOOLEAN) + { + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << node << std::endl; + return RewriteResponse(REWRITE_DONE, node); + } + val = eqRewritten.getConst<bool>(); + } + NodeManager* nm = NodeManager::currentNM(); + if (val) + { + // store(store(a,i,v),i,w) = store(a,i,w) + Node result = nm->mkNode(kind::STORE, store[0], index, value); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << result << std::endl; + return RewriteResponse(REWRITE_AGAIN, result); + } + else if (index < store[1]) + { + // store(store(a,i,v),j,w) = store(store(a,j,w),i,v) + // IF i != j and j comes before i in the ordering + std::vector<TNode> indices; + std::vector<TNode> elements; + indices.push_back(store[1]); + elements.push_back(store[2]); + store = store[0]; + Node n; + while (store.getKind() == kind::STORE) + { + if (index == store[1]) + { + val = true; + } + else if (index.isConst() && store[1].isConst()) + { + val = false; + } + else + { + n = Rewriter::rewrite(mkEqNode(store[1], index)); + if (n.getKind() != kind::CONST_BOOLEAN) + { + break; + } + val = n.getConst<bool>(); + } + if (val) + { + store = store[0]; + break; + } + else if (!(index < store[1])) + { + break; + } + indices.push_back(store[1]); + elements.push_back(store[2]); + store = store[0]; + } + if (value.getKind() == kind::SELECT && value[0] == store + && value[1] == index) + { + n = store; + } + else + { + n = nm->mkNode(kind::STORE, store, index, value); + } + while (!indices.empty()) + { + n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); + indices.pop_back(); + elements.pop_back(); + } + Assert(n != node); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << n << std::endl; + return RewriteResponse(REWRITE_AGAIN, n); + } + } + break; + } + case kind::EQUAL: + { + if (node[0] == node[1]) + { + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning true" << std::endl; + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConst(true)); + } + else if (node[0].isConst() && node[1].isConst()) + { + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning false" << std::endl; + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConst(false)); + } + if (node[0] > node[1]) + { + Node newNode = + NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << newNode << std::endl; + return RewriteResponse(REWRITE_DONE, newNode); + } + break; + } + default: break; + } + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << node << std::endl; + return RewriteResponse(REWRITE_DONE, node); +} + +RewriteResponse TheoryArraysRewriter::preRewrite(TNode node) +{ + Trace("arrays-prerewrite") + << "Arrays::preRewrite start " << node << std::endl; + switch (node.getKind()) + { + case kind::SELECT: + { + TNode store = node[0]; + TNode index = node[1]; + Node n; + bool val; + while (store.getKind() == kind::STORE) + { + if (index == store[1]) + { + val = true; + } + else if (index.isConst() && store[1].isConst()) + { + val = false; + } + else + { + n = Rewriter::rewrite(mkEqNode(store[1], index)); + if (n.getKind() != kind::CONST_BOOLEAN) + { + break; + } + val = n.getConst<bool>(); + } + if (val) + { + // select(store(a,i,v),j) = v if i = j + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << store[2] << std::endl; + return RewriteResponse(REWRITE_AGAIN, store[2]); + } + // select(store(a,i,v),j) = select(a,j) if i /= j + store = store[0]; + } + if (store.getKind() == kind::STORE_ALL) + { + // select(store_all(v),i) = v + ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>(); + n = storeAll.getValue(); + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << n << std::endl; + Assert(n.isConst()); + return RewriteResponse(REWRITE_DONE, n); + } + else if (store != node[0]) + { + n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index); + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << n << std::endl; + return RewriteResponse(REWRITE_DONE, n); + } + break; + } + case kind::STORE: + { + TNode store = node[0]; + TNode value = node[2]; + // store(a,i,select(a,i)) = a + if (value.getKind() == kind::SELECT && value[0] == store + && value[1] == node[1]) + { + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << store << std::endl; + return RewriteResponse(REWRITE_AGAIN, store); + } + if (store.getKind() == kind::STORE) + { + // store(store(a,i,v),j,w) + TNode index = node[1]; + bool val; + if (index == store[1]) + { + val = true; + } + else if (index.isConst() && store[1].isConst()) + { + val = false; + } + else + { + Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index)); + if (eqRewritten.getKind() != kind::CONST_BOOLEAN) + { + break; + } + val = eqRewritten.getConst<bool>(); + } + NodeManager* nm = NodeManager::currentNM(); + if (val) + { + // store(store(a,i,v),i,w) = store(a,i,w) + Node newNode = nm->mkNode(kind::STORE, store[0], index, value); + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << newNode << std::endl; + return RewriteResponse(REWRITE_DONE, newNode); + } + } + break; + } + case kind::EQUAL: + { + if (node[0] == node[1]) + { + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning true" << std::endl; + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConst(true)); + } + break; + } + default: break; + } + + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << node << std::endl; + return RewriteResponse(REWRITE_DONE, node); +} + +TrustNode TheoryArraysRewriter::expandDefinition(Node node) +{ + NodeManager* nm = NodeManager::currentNM(); + Kind kind = node.getKind(); + + /* Expand + * + * (eqrange a b i j) + * + * to + * + * forall k . i <= k <= j => a[k] = b[k] + * + */ + if (kind == kind::EQ_RANGE) + { + TNode a = node[0]; + TNode b = node[1]; + TNode i = node[2]; + TNode j = node[3]; + Node k = nm->mkBoundVar(i.getType()); + Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, k); + TypeNode type = k.getType(); + + Kind kle; + Node range; + if (type.isBitVector()) + { + kle = kind::BITVECTOR_ULE; + } + else if (type.isFloatingPoint()) + { + kle = kind::FLOATINGPOINT_LEQ; + } + else if (type.isInteger() || type.isReal()) + { + kle = kind::LEQ; + } + else + { + Unimplemented() << "Type " << type << " is not supported for predicate " + << kind; + } + + range = nm->mkNode(kind::AND, nm->mkNode(kle, i, k), nm->mkNode(kle, k, j)); + + Node eq = nm->mkNode(kind::EQUAL, + nm->mkNode(kind::SELECT, a, k), + nm->mkNode(kind::SELECT, b, k)); + Node implies = nm->mkNode(kind::IMPLIES, range, eq); + Node ret = nm->mkNode(kind::FORALL, bvl, implies); + return TrustNode::mkTrustRewrite(node, ret, nullptr); + } + return TrustNode::null(); +} + } // namespace arrays } // namespace theory } // namespace cvc5 diff --git a/src/theory/arrays/theory_arrays_rewriter.h b/src/theory/arrays/theory_arrays_rewriter.h index 0bbfc0846..498266ce3 100644 --- a/src/theory/arrays/theory_arrays_rewriter.h +++ b/src/theory/arrays/theory_arrays_rewriter.h @@ -43,459 +43,21 @@ static inline Node mkEqNode(Node a, Node b) { class TheoryArraysRewriter : public TheoryRewriter { - static Node normalizeConstant(TNode node) { - return normalizeConstant(node, node[1].getType().getCardinality()); - } + /** + * Puts array constant node into normal form. This is so that array constants + * that are distinct nodes are semantically disequal. + */ + static Node normalizeConstant(TNode node); public: - //this function is called by printers when using the option "--model-u-dt-enum" - static Node normalizeConstant(TNode node, Cardinality indexCard) { - TNode store = node[0]; - TNode index = node[1]; - TNode value = node[2]; + /** Normalize a constant whose index type has cardinality indexCard */ + static Node normalizeConstant(TNode node, Cardinality indexCard); - std::vector<TNode> indices; - std::vector<TNode> elements; + RewriteResponse postRewrite(TNode node) override; - // Normal form for nested stores is just ordering by index - but also need - // to check if we are writing to default value + RewriteResponse preRewrite(TNode node) override; - // Go through nested stores looking for where to insert index - // Also check whether we are replacing an existing store - TNode replacedValue; - unsigned depth = 1; - unsigned valCount = 1; - while (store.getKind() == kind::STORE) { - if (index == store[1]) { - replacedValue = store[2]; - store = store[0]; - break; - } - else if (!(index < store[1])) { - break; - } - if (value == store[2]) { - valCount += 1; - } - depth += 1; - indices.push_back(store[1]); - elements.push_back(store[2]); - store = store[0]; - } - Node n = store; - - // Get the default value at the bottom of the nested stores - while (store.getKind() == kind::STORE) { - if (value == store[2]) { - valCount += 1; - } - depth += 1; - store = store[0]; - } - Assert(store.getKind() == kind::STORE_ALL); - ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>(); - Node defaultValue = storeAll.getValue(); - NodeManager* nm = NodeManager::currentNM(); - - // Check if we are writing to default value - if so the store - // to index can be ignored - if (value == defaultValue) { - if (replacedValue.isNull()) { - // Quick exit - if writing to default value and nothing was - // replaced, we can just return node[0] - return node[0]; - } - // else rebuild the store without the replaced write and then exit - } - else { - n = nm->mkNode(kind::STORE, n, index, value); - } - - // Build the rest of the store after inserting/deleting - while (!indices.empty()) { - n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); - indices.pop_back(); - elements.pop_back(); - } - - // Ready to exit if write was to the default value (see previous comment) - if (value == defaultValue) { - return n; - } - - if (indexCard.isInfinite()) { - return n; - } - - // When index sort is finite, we have to check whether there is any value - // that is written to more than the default value. If so, it must become - // the new default value - - TNode mostFrequentValue; - unsigned mostFrequentValueCount = 0; - store = node[0]; - if (store.getKind() == kind::STORE) { - mostFrequentValue = getMostFrequentValue(store); - mostFrequentValueCount = getMostFrequentValueCount(store); - } - - // Compute the most frequently written value for n - if (valCount > mostFrequentValueCount || - (valCount == mostFrequentValueCount && value < mostFrequentValue)) { - mostFrequentValue = value; - mostFrequentValueCount = valCount; - } - - // Need to make sure the default value count is larger, or the same and the default value is expression-order-less-than nextValue - Cardinality::CardinalityComparison compare = indexCard.compare(mostFrequentValueCount + depth); - Assert(compare != Cardinality::UNKNOWN); - if (compare == Cardinality::GREATER || - (compare == Cardinality::EQUAL && (defaultValue < mostFrequentValue))) { - return n; - } - - // Bad case: have to recompute value counts and/or possibly switch out - // default value - store = n; - std::unordered_set<TNode, TNodeHashFunction> indexSet; - std::unordered_map<TNode, unsigned, TNodeHashFunction> elementsMap; - std::unordered_map<TNode, unsigned, TNodeHashFunction>::iterator it; - unsigned count; - unsigned max = 0; - TNode maxValue; - while (store.getKind() == kind::STORE) { - indices.push_back(store[1]); - indexSet.insert(store[1]); - elements.push_back(store[2]); - it = elementsMap.find(store[2]); - if (it != elementsMap.end()) { - (*it).second = (*it).second + 1; - count = (*it).second; - } - else { - elementsMap[store[2]] = 1; - count = 1; - } - if (count > max || - (count == max && store[2] < maxValue)) { - max = count; - maxValue = store[2]; - } - store = store[0]; - } - - Assert(depth == indices.size()); - compare = indexCard.compare(max + depth); - Assert(compare != Cardinality::UNKNOWN); - if (compare == Cardinality::GREATER || - (compare == Cardinality::EQUAL && (defaultValue < maxValue))) { - Assert(!replacedValue.isNull() && mostFrequentValue == replacedValue); - return n; - } - - // Out of luck: have to swap out default value - - // Enumerate values from index type into newIndices and sort - std::vector<Node> newIndices; - TypeEnumerator te(index.getType()); - bool needToSort = false; - unsigned numTe = 0; - while (!te.isFinished() && (!indexCard.isFinite() || numTe<indexCard.getFiniteCardinality().toUnsignedInt())) { - if (indexSet.find(*te) == indexSet.end()) { - if (!newIndices.empty() && (!(newIndices.back() < (*te)))) { - needToSort = true; - } - newIndices.push_back(*te); - } - ++numTe; - ++te; - } - Assert(indexCard.compare(newIndices.size() + depth) == Cardinality::EQUAL); - if (needToSort) { - std::sort(newIndices.begin(), newIndices.end()); - } - - n = nm->mkConst(ArrayStoreAll(node.getType(), maxValue)); - std::vector<Node>::iterator itNew = newIndices.begin(), it_end = newIndices.end(); - while (itNew != it_end || !indices.empty()) { - if (itNew != it_end && (indices.empty() || (*itNew) < indices.back())) { - n = nm->mkNode(kind::STORE, n, (*itNew), defaultValue); - ++itNew; - } - else if (itNew == it_end || indices.back() < (*itNew)) { - if (elements.back() != maxValue) { - n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); - } - indices.pop_back(); - elements.pop_back(); - } - } - return n; - } - - public: - RewriteResponse postRewrite(TNode node) override - { - Trace("arrays-postrewrite") << "Arrays::postRewrite start " << node << std::endl; - switch (node.getKind()) { - case kind::SELECT: { - TNode store = node[0]; - TNode index = node[1]; - Node n; - bool val; - while (store.getKind() == kind::STORE) { - if (index == store[1]) { - val = true; - } - else if (index.isConst() && store[1].isConst()) { - val = false; - } - else { - n = Rewriter::rewrite(mkEqNode(store[1], index)); - if (n.getKind() != kind::CONST_BOOLEAN) { - break; - } - val = n.getConst<bool>(); - } - if (val) { - // select(store(a,i,v),j) = v if i = j - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << store[2] << std::endl; - return RewriteResponse(REWRITE_DONE, store[2]); - } - // select(store(a,i,v),j) = select(a,j) if i /= j - store = store[0]; - } - if (store.getKind() == kind::STORE_ALL) { - // select(store_all(v),i) = v - ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>(); - n = storeAll.getValue(); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl; - Assert(n.isConst()); - return RewriteResponse(REWRITE_DONE, n); - } - else if (store != node[0]) { - n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl; - return RewriteResponse(REWRITE_DONE, n); - } - break; - } - case kind::STORE: { - TNode store = node[0]; - TNode value = node[2]; - // store(a,i,select(a,i)) = a - if (value.getKind() == kind::SELECT && - value[0] == store && - value[1] == node[1]) { - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << store << std::endl; - return RewriteResponse(REWRITE_DONE, store); - } - TNode index = node[1]; - if (store.isConst() && index.isConst() && value.isConst()) { - // normalize constant - Node n = normalizeConstant(node); - Assert(n.isConst()); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl; - return RewriteResponse(REWRITE_DONE, n); - } - if (store.getKind() == kind::STORE) { - // store(store(a,i,v),j,w) - bool val; - if (index == store[1]) { - val = true; - } - else if (index.isConst() && store[1].isConst()) { - val = false; - } - else { - Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index)); - if (eqRewritten.getKind() != kind::CONST_BOOLEAN) { - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << node << std::endl; - return RewriteResponse(REWRITE_DONE, node); - } - val = eqRewritten.getConst<bool>(); - } - NodeManager* nm = NodeManager::currentNM(); - if (val) { - // store(store(a,i,v),i,w) = store(a,i,w) - Node result = nm->mkNode(kind::STORE, store[0], index, value); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << result << std::endl; - return RewriteResponse(REWRITE_AGAIN, result); - } - else if (index < store[1]) { - // store(store(a,i,v),j,w) = store(store(a,j,w),i,v) - // IF i != j and j comes before i in the ordering - std::vector<TNode> indices; - std::vector<TNode> elements; - indices.push_back(store[1]); - elements.push_back(store[2]); - store = store[0]; - Node n; - while (store.getKind() == kind::STORE) { - if (index == store[1]) { - val = true; - } - else if (index.isConst() && store[1].isConst()) { - val = false; - } - else { - n = Rewriter::rewrite(mkEqNode(store[1], index)); - if (n.getKind() != kind::CONST_BOOLEAN) { - break; - } - val = n.getConst<bool>(); - } - if (val) { - store = store[0]; - break; - } - else if (!(index < store[1])) { - break; - } - indices.push_back(store[1]); - elements.push_back(store[2]); - store = store[0]; - } - if (value.getKind() == kind::SELECT && - value[0] == store && - value[1] == index) { - n = store; - } - else { - n = nm->mkNode(kind::STORE, store, index, value); - } - while (!indices.empty()) { - n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); - indices.pop_back(); - elements.pop_back(); - } - Assert(n != node); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl; - return RewriteResponse(REWRITE_AGAIN, n); - } - } - break; - } - case kind::EQUAL:{ - if(node[0] == node[1]) { - Trace("arrays-postrewrite") << "Arrays::postRewrite returning true" << std::endl; - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); - } - else if (node[0].isConst() && node[1].isConst()) { - Trace("arrays-postrewrite") << "Arrays::postRewrite returning false" << std::endl; - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(false)); - } - if (node[0] > node[1]) { - Node newNode = NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << newNode << std::endl; - return RewriteResponse(REWRITE_DONE, newNode); - } - break; - } - default: - break; - } - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << node << std::endl; - return RewriteResponse(REWRITE_DONE, node); - } - - RewriteResponse preRewrite(TNode node) override - { - Trace("arrays-prerewrite") << "Arrays::preRewrite start " << node << std::endl; - switch (node.getKind()) { - case kind::SELECT: { - TNode store = node[0]; - TNode index = node[1]; - Node n; - bool val; - while (store.getKind() == kind::STORE) { - if (index == store[1]) { - val = true; - } - else if (index.isConst() && store[1].isConst()) { - val = false; - } - else { - n = Rewriter::rewrite(mkEqNode(store[1], index)); - if (n.getKind() != kind::CONST_BOOLEAN) { - break; - } - val = n.getConst<bool>(); - } - if (val) { - // select(store(a,i,v),j) = v if i = j - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store[2] << std::endl; - return RewriteResponse(REWRITE_AGAIN, store[2]); - } - // select(store(a,i,v),j) = select(a,j) if i /= j - store = store[0]; - } - if (store.getKind() == kind::STORE_ALL) { - // select(store_all(v),i) = v - ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>(); - n = storeAll.getValue(); - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl; - Assert(n.isConst()); - return RewriteResponse(REWRITE_DONE, n); - } - else if (store != node[0]) { - n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index); - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl; - return RewriteResponse(REWRITE_DONE, n); - } - break; - } - case kind::STORE: { - TNode store = node[0]; - TNode value = node[2]; - // store(a,i,select(a,i)) = a - if (value.getKind() == kind::SELECT && - value[0] == store && - value[1] == node[1]) { - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store << std::endl; - return RewriteResponse(REWRITE_AGAIN, store); - } - if (store.getKind() == kind::STORE) { - // store(store(a,i,v),j,w) - TNode index = node[1]; - bool val; - if (index == store[1]) { - val = true; - } - else if (index.isConst() && store[1].isConst()) { - val = false; - } - else { - Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index)); - if (eqRewritten.getKind() != kind::CONST_BOOLEAN) { - break; - } - val = eqRewritten.getConst<bool>(); - } - NodeManager* nm = NodeManager::currentNM(); - if (val) { - // store(store(a,i,v),i,w) = store(a,i,w) - Node newNode = nm->mkNode(kind::STORE, store[0], index, value); - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << newNode << std::endl; - return RewriteResponse(REWRITE_DONE, newNode); - } - } - break; - } - case kind::EQUAL:{ - if(node[0] == node[1]) { - Trace("arrays-prerewrite") << "Arrays::preRewrite returning true" << std::endl; - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); - } - break; - } - default: - break; - } - - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << node << std::endl; - return RewriteResponse(REWRITE_DONE, node); - } + TrustNode expandDefinition(Node node) override; static inline void init() {} static inline void shutdown() {} |