summaryrefslogtreecommitdiff
path: root/src/theory/arrays
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2021-04-21 21:42:08 -0500
committerGitHub <noreply@github.com>2021-04-22 02:42:08 +0000
commit89620a0d73e7134437a39d742e91de11a08a4962 (patch)
tree46b37970a7d3f74317f8e255b6aefa9cfae127b1 /src/theory/arrays
parent90cde45ee963b994054f96f97111684cce808d82 (diff)
Move expand definition from Theory to TheoryRewriter (#6408)
This is work towards eliminating global calls to getCurrentSmtEngine()->expandDefinition. The next step will be to add Rewriter::expandDefinition.
Diffstat (limited to 'src/theory/arrays')
-rw-r--r--src/theory/arrays/theory_arrays.cpp58
-rw-r--r--src/theory/arrays/theory_arrays.h2
-rw-r--r--src/theory/arrays/theory_arrays_rewriter.cpp616
-rw-r--r--src/theory/arrays/theory_arrays_rewriter.h458
4 files changed, 627 insertions, 507 deletions
diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp
index 1a1090f68..e887feccb 100644
--- a/src/theory/arrays/theory_arrays.cpp
+++ b/src/theory/arrays/theory_arrays.cpp
@@ -300,7 +300,7 @@ Node TheoryArrays::solveWrite(TNode term, bool solve1, bool solve2, bool ppCheck
TrustNode TheoryArrays::ppRewrite(TNode term, std::vector<SkolemLemma>& lems)
{
// first, see if we need to expand definitions
- TrustNode texp = expandDefinition(term);
+ TrustNode texp = d_rewriter.expandDefinition(term);
if (!texp.isNull())
{
return texp;
@@ -2068,62 +2068,6 @@ std::string TheoryArrays::TheoryArraysDecisionStrategy::identify() const
return std::string("th_arrays_dec");
}
-TrustNode TheoryArrays::expandDefinition(Node node)
-{
- NodeManager* nm = NodeManager::currentNM();
- Kind kind = node.getKind();
-
- /* Expand
- *
- * (eqrange a b i j)
- *
- * to
- *
- * forall k . i <= k <= j => a[k] = b[k]
- *
- */
- if (kind == kind::EQ_RANGE)
- {
- TNode a = node[0];
- TNode b = node[1];
- TNode i = node[2];
- TNode j = node[3];
- Node k = nm->mkBoundVar(i.getType());
- Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, k);
- TypeNode type = k.getType();
-
- Kind kle;
- Node range;
- if (type.isBitVector())
- {
- kle = kind::BITVECTOR_ULE;
- }
- else if (type.isFloatingPoint())
- {
- kle = kind::FLOATINGPOINT_LEQ;
- }
- else if (type.isInteger() || type.isReal())
- {
- kle = kind::LEQ;
- }
- else
- {
- Unimplemented() << "Type " << type << " is not supported for predicate "
- << kind;
- }
-
- range = nm->mkNode(kind::AND, nm->mkNode(kle, i, k), nm->mkNode(kle, k, j));
-
- Node eq = nm->mkNode(kind::EQUAL,
- nm->mkNode(kind::SELECT, a, k),
- nm->mkNode(kind::SELECT, b, k));
- Node implies = nm->mkNode(kind::IMPLIES, range, eq);
- Node ret = nm->mkNode(kind::FORALL, bvl, implies);
- return TrustNode::mkTrustRewrite(node, ret, nullptr);
- }
- return TrustNode::null();
-}
-
void TheoryArrays::computeRelevantTerms(std::set<Node>& termSet)
{
NodeManager* nm = NodeManager::currentNM();
diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h
index 7cf8d52e3..f9813cd3f 100644
--- a/src/theory/arrays/theory_arrays.h
+++ b/src/theory/arrays/theory_arrays.h
@@ -158,8 +158,6 @@ class TheoryArrays : public Theory {
std::string identify() const override { return std::string("TheoryArrays"); }
- TrustNode expandDefinition(Node node) override;
-
/////////////////////////////////////////////////////////////////////////////
// PREPROCESSING
/////////////////////////////////////////////////////////////////////////////
diff --git a/src/theory/arrays/theory_arrays_rewriter.cpp b/src/theory/arrays/theory_arrays_rewriter.cpp
index 323dd0046..6269cb5dd 100644
--- a/src/theory/arrays/theory_arrays_rewriter.cpp
+++ b/src/theory/arrays/theory_arrays_rewriter.cpp
@@ -45,6 +45,622 @@ void setMostFrequentValueCount(TNode store, uint64_t count) {
return store.setAttribute(ArrayConstantMostFrequentValueCountAttr(), count);
}
+Node TheoryArraysRewriter::normalizeConstant(TNode node)
+{
+ return normalizeConstant(node, node[1].getType().getCardinality());
+}
+
+// this function is called by printers when using the option "--model-u-dt-enum"
+Node TheoryArraysRewriter::normalizeConstant(TNode node, Cardinality indexCard)
+{
+ TNode store = node[0];
+ TNode index = node[1];
+ TNode value = node[2];
+
+ std::vector<TNode> indices;
+ std::vector<TNode> elements;
+
+ // Normal form for nested stores is just ordering by index - but also need
+ // to check if we are writing to default value
+
+ // Go through nested stores looking for where to insert index
+ // Also check whether we are replacing an existing store
+ TNode replacedValue;
+ uint32_t depth = 1;
+ uint32_t valCount = 1;
+ while (store.getKind() == kind::STORE)
+ {
+ if (index == store[1])
+ {
+ replacedValue = store[2];
+ store = store[0];
+ break;
+ }
+ else if (index >= store[1])
+ {
+ break;
+ }
+ if (value == store[2])
+ {
+ valCount += 1;
+ }
+ depth += 1;
+ indices.push_back(store[1]);
+ elements.push_back(store[2]);
+ store = store[0];
+ }
+ Node n = store;
+
+ // Get the default value at the bottom of the nested stores
+ while (store.getKind() == kind::STORE)
+ {
+ if (value == store[2])
+ {
+ valCount += 1;
+ }
+ depth += 1;
+ store = store[0];
+ }
+ Assert(store.getKind() == kind::STORE_ALL);
+ ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
+ Node defaultValue = storeAll.getValue();
+ NodeManager* nm = NodeManager::currentNM();
+
+ // Check if we are writing to default value - if so the store
+ // to index can be ignored
+ if (value == defaultValue)
+ {
+ if (replacedValue.isNull())
+ {
+ // Quick exit - if writing to default value and nothing was
+ // replaced, we can just return node[0]
+ return node[0];
+ }
+ // else rebuild the store without the replaced write and then exit
+ }
+ else
+ {
+ n = nm->mkNode(kind::STORE, n, index, value);
+ }
+
+ // Build the rest of the store after inserting/deleting
+ while (!indices.empty())
+ {
+ n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
+ indices.pop_back();
+ elements.pop_back();
+ }
+
+ // Ready to exit if write was to the default value (see previous comment)
+ if (value == defaultValue)
+ {
+ return n;
+ }
+
+ if (indexCard.isInfinite())
+ {
+ return n;
+ }
+
+ // When index sort is finite, we have to check whether there is any value
+ // that is written to more than the default value. If so, it must become
+ // the new default value
+
+ TNode mostFrequentValue;
+ uint32_t mostFrequentValueCount = 0;
+ store = node[0];
+ if (store.getKind() == kind::STORE)
+ {
+ mostFrequentValue = getMostFrequentValue(store);
+ mostFrequentValueCount = getMostFrequentValueCount(store);
+ }
+
+ // Compute the most frequently written value for n
+ if (valCount > mostFrequentValueCount
+ || (valCount == mostFrequentValueCount && value < mostFrequentValue))
+ {
+ mostFrequentValue = value;
+ mostFrequentValueCount = valCount;
+ }
+
+ // Need to make sure the default value count is larger, or the same and the
+ // default value is expression-order-less-than nextValue
+ Cardinality::CardinalityComparison compare =
+ indexCard.compare(mostFrequentValueCount + depth);
+ Assert(compare != Cardinality::UNKNOWN);
+ if (compare == Cardinality::GREATER
+ || (compare == Cardinality::EQUAL && (defaultValue < mostFrequentValue)))
+ {
+ return n;
+ }
+
+ // Bad case: have to recompute value counts and/or possibly switch out
+ // default value
+ store = n;
+ std::unordered_set<TNode, TNodeHashFunction> indexSet;
+ std::unordered_map<TNode, uint32_t, TNodeHashFunction> elementsMap;
+ std::unordered_map<TNode, uint32_t, TNodeHashFunction>::iterator it;
+ uint32_t count;
+ uint32_t max = 0;
+ TNode maxValue;
+ while (store.getKind() == kind::STORE)
+ {
+ indices.push_back(store[1]);
+ indexSet.insert(store[1]);
+ elements.push_back(store[2]);
+ it = elementsMap.find(store[2]);
+ if (it != elementsMap.end())
+ {
+ (*it).second = (*it).second + 1;
+ count = (*it).second;
+ }
+ else
+ {
+ elementsMap[store[2]] = 1;
+ count = 1;
+ }
+ if (count > max || (count == max && store[2] < maxValue))
+ {
+ max = count;
+ maxValue = store[2];
+ }
+ store = store[0];
+ }
+
+ Assert(depth == indices.size());
+ compare = indexCard.compare(max + depth);
+ Assert(compare != Cardinality::UNKNOWN);
+ if (compare == Cardinality::GREATER
+ || (compare == Cardinality::EQUAL && (defaultValue < maxValue)))
+ {
+ Assert(!replacedValue.isNull() && mostFrequentValue == replacedValue);
+ return n;
+ }
+
+ // Out of luck: have to swap out default value
+
+ // Enumerate values from index type into newIndices and sort
+ std::vector<Node> newIndices;
+ TypeEnumerator te(index.getType());
+ bool needToSort = false;
+ uint32_t numTe = 0;
+ while (!te.isFinished()
+ && (!indexCard.isFinite()
+ || numTe < indexCard.getFiniteCardinality().toUnsignedInt()))
+ {
+ if (indexSet.find(*te) == indexSet.end())
+ {
+ if (!newIndices.empty() && (!(newIndices.back() < (*te))))
+ {
+ needToSort = true;
+ }
+ newIndices.push_back(*te);
+ }
+ ++numTe;
+ ++te;
+ }
+ Assert(indexCard.compare(newIndices.size() + depth) == Cardinality::EQUAL);
+ if (needToSort)
+ {
+ std::sort(newIndices.begin(), newIndices.end());
+ }
+
+ n = nm->mkConst(ArrayStoreAll(node.getType(), maxValue));
+ std::vector<Node>::iterator itNew = newIndices.begin(),
+ it_end = newIndices.end();
+ while (itNew != it_end || !indices.empty())
+ {
+ if (itNew != it_end && (indices.empty() || (*itNew) < indices.back()))
+ {
+ n = nm->mkNode(kind::STORE, n, (*itNew), defaultValue);
+ ++itNew;
+ }
+ else if (itNew == it_end || indices.back() < (*itNew))
+ {
+ if (elements.back() != maxValue)
+ {
+ n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
+ }
+ indices.pop_back();
+ elements.pop_back();
+ }
+ }
+ return n;
+}
+
+RewriteResponse TheoryArraysRewriter::postRewrite(TNode node)
+{
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite start " << node << std::endl;
+ switch (node.getKind())
+ {
+ case kind::SELECT:
+ {
+ TNode store = node[0];
+ TNode index = node[1];
+ Node n;
+ bool val;
+ while (store.getKind() == kind::STORE)
+ {
+ if (index == store[1])
+ {
+ val = true;
+ }
+ else if (index.isConst() && store[1].isConst())
+ {
+ val = false;
+ }
+ else
+ {
+ n = Rewriter::rewrite(mkEqNode(store[1], index));
+ if (n.getKind() != kind::CONST_BOOLEAN)
+ {
+ break;
+ }
+ val = n.getConst<bool>();
+ }
+ if (val)
+ {
+ // select(store(a,i,v),j) = v if i = j
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning " << store[2] << std::endl;
+ return RewriteResponse(REWRITE_DONE, store[2]);
+ }
+ // select(store(a,i,v),j) = select(a,j) if i /= j
+ store = store[0];
+ }
+ if (store.getKind() == kind::STORE_ALL)
+ {
+ // select(store_all(v),i) = v
+ ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
+ n = storeAll.getValue();
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning " << n << std::endl;
+ Assert(n.isConst());
+ return RewriteResponse(REWRITE_DONE, n);
+ }
+ else if (store != node[0])
+ {
+ n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning " << n << std::endl;
+ return RewriteResponse(REWRITE_DONE, n);
+ }
+ break;
+ }
+ case kind::STORE:
+ {
+ TNode store = node[0];
+ TNode value = node[2];
+ // store(a,i,select(a,i)) = a
+ if (value.getKind() == kind::SELECT && value[0] == store
+ && value[1] == node[1])
+ {
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning " << store << std::endl;
+ return RewriteResponse(REWRITE_DONE, store);
+ }
+ TNode index = node[1];
+ if (store.isConst() && index.isConst() && value.isConst())
+ {
+ // normalize constant
+ Node n = normalizeConstant(node);
+ Assert(n.isConst());
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning " << n << std::endl;
+ return RewriteResponse(REWRITE_DONE, n);
+ }
+ if (store.getKind() == kind::STORE)
+ {
+ // store(store(a,i,v),j,w)
+ bool val;
+ if (index == store[1])
+ {
+ val = true;
+ }
+ else if (index.isConst() && store[1].isConst())
+ {
+ val = false;
+ }
+ else
+ {
+ Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index));
+ if (eqRewritten.getKind() != kind::CONST_BOOLEAN)
+ {
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning " << node << std::endl;
+ return RewriteResponse(REWRITE_DONE, node);
+ }
+ val = eqRewritten.getConst<bool>();
+ }
+ NodeManager* nm = NodeManager::currentNM();
+ if (val)
+ {
+ // store(store(a,i,v),i,w) = store(a,i,w)
+ Node result = nm->mkNode(kind::STORE, store[0], index, value);
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning " << result << std::endl;
+ return RewriteResponse(REWRITE_AGAIN, result);
+ }
+ else if (index < store[1])
+ {
+ // store(store(a,i,v),j,w) = store(store(a,j,w),i,v)
+ // IF i != j and j comes before i in the ordering
+ std::vector<TNode> indices;
+ std::vector<TNode> elements;
+ indices.push_back(store[1]);
+ elements.push_back(store[2]);
+ store = store[0];
+ Node n;
+ while (store.getKind() == kind::STORE)
+ {
+ if (index == store[1])
+ {
+ val = true;
+ }
+ else if (index.isConst() && store[1].isConst())
+ {
+ val = false;
+ }
+ else
+ {
+ n = Rewriter::rewrite(mkEqNode(store[1], index));
+ if (n.getKind() != kind::CONST_BOOLEAN)
+ {
+ break;
+ }
+ val = n.getConst<bool>();
+ }
+ if (val)
+ {
+ store = store[0];
+ break;
+ }
+ else if (!(index < store[1]))
+ {
+ break;
+ }
+ indices.push_back(store[1]);
+ elements.push_back(store[2]);
+ store = store[0];
+ }
+ if (value.getKind() == kind::SELECT && value[0] == store
+ && value[1] == index)
+ {
+ n = store;
+ }
+ else
+ {
+ n = nm->mkNode(kind::STORE, store, index, value);
+ }
+ while (!indices.empty())
+ {
+ n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
+ indices.pop_back();
+ elements.pop_back();
+ }
+ Assert(n != node);
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning " << n << std::endl;
+ return RewriteResponse(REWRITE_AGAIN, n);
+ }
+ }
+ break;
+ }
+ case kind::EQUAL:
+ {
+ if (node[0] == node[1])
+ {
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning true" << std::endl;
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(true));
+ }
+ else if (node[0].isConst() && node[1].isConst())
+ {
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning false" << std::endl;
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(false));
+ }
+ if (node[0] > node[1])
+ {
+ Node newNode =
+ NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]);
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning " << newNode << std::endl;
+ return RewriteResponse(REWRITE_DONE, newNode);
+ }
+ break;
+ }
+ default: break;
+ }
+ Trace("arrays-postrewrite")
+ << "Arrays::postRewrite returning " << node << std::endl;
+ return RewriteResponse(REWRITE_DONE, node);
+}
+
+RewriteResponse TheoryArraysRewriter::preRewrite(TNode node)
+{
+ Trace("arrays-prerewrite")
+ << "Arrays::preRewrite start " << node << std::endl;
+ switch (node.getKind())
+ {
+ case kind::SELECT:
+ {
+ TNode store = node[0];
+ TNode index = node[1];
+ Node n;
+ bool val;
+ while (store.getKind() == kind::STORE)
+ {
+ if (index == store[1])
+ {
+ val = true;
+ }
+ else if (index.isConst() && store[1].isConst())
+ {
+ val = false;
+ }
+ else
+ {
+ n = Rewriter::rewrite(mkEqNode(store[1], index));
+ if (n.getKind() != kind::CONST_BOOLEAN)
+ {
+ break;
+ }
+ val = n.getConst<bool>();
+ }
+ if (val)
+ {
+ // select(store(a,i,v),j) = v if i = j
+ Trace("arrays-prerewrite")
+ << "Arrays::preRewrite returning " << store[2] << std::endl;
+ return RewriteResponse(REWRITE_AGAIN, store[2]);
+ }
+ // select(store(a,i,v),j) = select(a,j) if i /= j
+ store = store[0];
+ }
+ if (store.getKind() == kind::STORE_ALL)
+ {
+ // select(store_all(v),i) = v
+ ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
+ n = storeAll.getValue();
+ Trace("arrays-prerewrite")
+ << "Arrays::preRewrite returning " << n << std::endl;
+ Assert(n.isConst());
+ return RewriteResponse(REWRITE_DONE, n);
+ }
+ else if (store != node[0])
+ {
+ n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
+ Trace("arrays-prerewrite")
+ << "Arrays::preRewrite returning " << n << std::endl;
+ return RewriteResponse(REWRITE_DONE, n);
+ }
+ break;
+ }
+ case kind::STORE:
+ {
+ TNode store = node[0];
+ TNode value = node[2];
+ // store(a,i,select(a,i)) = a
+ if (value.getKind() == kind::SELECT && value[0] == store
+ && value[1] == node[1])
+ {
+ Trace("arrays-prerewrite")
+ << "Arrays::preRewrite returning " << store << std::endl;
+ return RewriteResponse(REWRITE_AGAIN, store);
+ }
+ if (store.getKind() == kind::STORE)
+ {
+ // store(store(a,i,v),j,w)
+ TNode index = node[1];
+ bool val;
+ if (index == store[1])
+ {
+ val = true;
+ }
+ else if (index.isConst() && store[1].isConst())
+ {
+ val = false;
+ }
+ else
+ {
+ Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index));
+ if (eqRewritten.getKind() != kind::CONST_BOOLEAN)
+ {
+ break;
+ }
+ val = eqRewritten.getConst<bool>();
+ }
+ NodeManager* nm = NodeManager::currentNM();
+ if (val)
+ {
+ // store(store(a,i,v),i,w) = store(a,i,w)
+ Node newNode = nm->mkNode(kind::STORE, store[0], index, value);
+ Trace("arrays-prerewrite")
+ << "Arrays::preRewrite returning " << newNode << std::endl;
+ return RewriteResponse(REWRITE_DONE, newNode);
+ }
+ }
+ break;
+ }
+ case kind::EQUAL:
+ {
+ if (node[0] == node[1])
+ {
+ Trace("arrays-prerewrite")
+ << "Arrays::preRewrite returning true" << std::endl;
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(true));
+ }
+ break;
+ }
+ default: break;
+ }
+
+ Trace("arrays-prerewrite")
+ << "Arrays::preRewrite returning " << node << std::endl;
+ return RewriteResponse(REWRITE_DONE, node);
+}
+
+TrustNode TheoryArraysRewriter::expandDefinition(Node node)
+{
+ NodeManager* nm = NodeManager::currentNM();
+ Kind kind = node.getKind();
+
+ /* Expand
+ *
+ * (eqrange a b i j)
+ *
+ * to
+ *
+ * forall k . i <= k <= j => a[k] = b[k]
+ *
+ */
+ if (kind == kind::EQ_RANGE)
+ {
+ TNode a = node[0];
+ TNode b = node[1];
+ TNode i = node[2];
+ TNode j = node[3];
+ Node k = nm->mkBoundVar(i.getType());
+ Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, k);
+ TypeNode type = k.getType();
+
+ Kind kle;
+ Node range;
+ if (type.isBitVector())
+ {
+ kle = kind::BITVECTOR_ULE;
+ }
+ else if (type.isFloatingPoint())
+ {
+ kle = kind::FLOATINGPOINT_LEQ;
+ }
+ else if (type.isInteger() || type.isReal())
+ {
+ kle = kind::LEQ;
+ }
+ else
+ {
+ Unimplemented() << "Type " << type << " is not supported for predicate "
+ << kind;
+ }
+
+ range = nm->mkNode(kind::AND, nm->mkNode(kle, i, k), nm->mkNode(kle, k, j));
+
+ Node eq = nm->mkNode(kind::EQUAL,
+ nm->mkNode(kind::SELECT, a, k),
+ nm->mkNode(kind::SELECT, b, k));
+ Node implies = nm->mkNode(kind::IMPLIES, range, eq);
+ Node ret = nm->mkNode(kind::FORALL, bvl, implies);
+ return TrustNode::mkTrustRewrite(node, ret, nullptr);
+ }
+ return TrustNode::null();
+}
+
} // namespace arrays
} // namespace theory
} // namespace cvc5
diff --git a/src/theory/arrays/theory_arrays_rewriter.h b/src/theory/arrays/theory_arrays_rewriter.h
index 0bbfc0846..498266ce3 100644
--- a/src/theory/arrays/theory_arrays_rewriter.h
+++ b/src/theory/arrays/theory_arrays_rewriter.h
@@ -43,459 +43,21 @@ static inline Node mkEqNode(Node a, Node b) {
class TheoryArraysRewriter : public TheoryRewriter
{
- static Node normalizeConstant(TNode node) {
- return normalizeConstant(node, node[1].getType().getCardinality());
- }
+ /**
+ * Puts array constant node into normal form. This is so that array constants
+ * that are distinct nodes are semantically disequal.
+ */
+ static Node normalizeConstant(TNode node);
public:
- //this function is called by printers when using the option "--model-u-dt-enum"
- static Node normalizeConstant(TNode node, Cardinality indexCard) {
- TNode store = node[0];
- TNode index = node[1];
- TNode value = node[2];
+ /** Normalize a constant whose index type has cardinality indexCard */
+ static Node normalizeConstant(TNode node, Cardinality indexCard);
- std::vector<TNode> indices;
- std::vector<TNode> elements;
+ RewriteResponse postRewrite(TNode node) override;
- // Normal form for nested stores is just ordering by index - but also need
- // to check if we are writing to default value
+ RewriteResponse preRewrite(TNode node) override;
- // Go through nested stores looking for where to insert index
- // Also check whether we are replacing an existing store
- TNode replacedValue;
- unsigned depth = 1;
- unsigned valCount = 1;
- while (store.getKind() == kind::STORE) {
- if (index == store[1]) {
- replacedValue = store[2];
- store = store[0];
- break;
- }
- else if (!(index < store[1])) {
- break;
- }
- if (value == store[2]) {
- valCount += 1;
- }
- depth += 1;
- indices.push_back(store[1]);
- elements.push_back(store[2]);
- store = store[0];
- }
- Node n = store;
-
- // Get the default value at the bottom of the nested stores
- while (store.getKind() == kind::STORE) {
- if (value == store[2]) {
- valCount += 1;
- }
- depth += 1;
- store = store[0];
- }
- Assert(store.getKind() == kind::STORE_ALL);
- ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
- Node defaultValue = storeAll.getValue();
- NodeManager* nm = NodeManager::currentNM();
-
- // Check if we are writing to default value - if so the store
- // to index can be ignored
- if (value == defaultValue) {
- if (replacedValue.isNull()) {
- // Quick exit - if writing to default value and nothing was
- // replaced, we can just return node[0]
- return node[0];
- }
- // else rebuild the store without the replaced write and then exit
- }
- else {
- n = nm->mkNode(kind::STORE, n, index, value);
- }
-
- // Build the rest of the store after inserting/deleting
- while (!indices.empty()) {
- n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
- indices.pop_back();
- elements.pop_back();
- }
-
- // Ready to exit if write was to the default value (see previous comment)
- if (value == defaultValue) {
- return n;
- }
-
- if (indexCard.isInfinite()) {
- return n;
- }
-
- // When index sort is finite, we have to check whether there is any value
- // that is written to more than the default value. If so, it must become
- // the new default value
-
- TNode mostFrequentValue;
- unsigned mostFrequentValueCount = 0;
- store = node[0];
- if (store.getKind() == kind::STORE) {
- mostFrequentValue = getMostFrequentValue(store);
- mostFrequentValueCount = getMostFrequentValueCount(store);
- }
-
- // Compute the most frequently written value for n
- if (valCount > mostFrequentValueCount ||
- (valCount == mostFrequentValueCount && value < mostFrequentValue)) {
- mostFrequentValue = value;
- mostFrequentValueCount = valCount;
- }
-
- // Need to make sure the default value count is larger, or the same and the default value is expression-order-less-than nextValue
- Cardinality::CardinalityComparison compare = indexCard.compare(mostFrequentValueCount + depth);
- Assert(compare != Cardinality::UNKNOWN);
- if (compare == Cardinality::GREATER ||
- (compare == Cardinality::EQUAL && (defaultValue < mostFrequentValue))) {
- return n;
- }
-
- // Bad case: have to recompute value counts and/or possibly switch out
- // default value
- store = n;
- std::unordered_set<TNode, TNodeHashFunction> indexSet;
- std::unordered_map<TNode, unsigned, TNodeHashFunction> elementsMap;
- std::unordered_map<TNode, unsigned, TNodeHashFunction>::iterator it;
- unsigned count;
- unsigned max = 0;
- TNode maxValue;
- while (store.getKind() == kind::STORE) {
- indices.push_back(store[1]);
- indexSet.insert(store[1]);
- elements.push_back(store[2]);
- it = elementsMap.find(store[2]);
- if (it != elementsMap.end()) {
- (*it).second = (*it).second + 1;
- count = (*it).second;
- }
- else {
- elementsMap[store[2]] = 1;
- count = 1;
- }
- if (count > max ||
- (count == max && store[2] < maxValue)) {
- max = count;
- maxValue = store[2];
- }
- store = store[0];
- }
-
- Assert(depth == indices.size());
- compare = indexCard.compare(max + depth);
- Assert(compare != Cardinality::UNKNOWN);
- if (compare == Cardinality::GREATER ||
- (compare == Cardinality::EQUAL && (defaultValue < maxValue))) {
- Assert(!replacedValue.isNull() && mostFrequentValue == replacedValue);
- return n;
- }
-
- // Out of luck: have to swap out default value
-
- // Enumerate values from index type into newIndices and sort
- std::vector<Node> newIndices;
- TypeEnumerator te(index.getType());
- bool needToSort = false;
- unsigned numTe = 0;
- while (!te.isFinished() && (!indexCard.isFinite() || numTe<indexCard.getFiniteCardinality().toUnsignedInt())) {
- if (indexSet.find(*te) == indexSet.end()) {
- if (!newIndices.empty() && (!(newIndices.back() < (*te)))) {
- needToSort = true;
- }
- newIndices.push_back(*te);
- }
- ++numTe;
- ++te;
- }
- Assert(indexCard.compare(newIndices.size() + depth) == Cardinality::EQUAL);
- if (needToSort) {
- std::sort(newIndices.begin(), newIndices.end());
- }
-
- n = nm->mkConst(ArrayStoreAll(node.getType(), maxValue));
- std::vector<Node>::iterator itNew = newIndices.begin(), it_end = newIndices.end();
- while (itNew != it_end || !indices.empty()) {
- if (itNew != it_end && (indices.empty() || (*itNew) < indices.back())) {
- n = nm->mkNode(kind::STORE, n, (*itNew), defaultValue);
- ++itNew;
- }
- else if (itNew == it_end || indices.back() < (*itNew)) {
- if (elements.back() != maxValue) {
- n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
- }
- indices.pop_back();
- elements.pop_back();
- }
- }
- return n;
- }
-
- public:
- RewriteResponse postRewrite(TNode node) override
- {
- Trace("arrays-postrewrite") << "Arrays::postRewrite start " << node << std::endl;
- switch (node.getKind()) {
- case kind::SELECT: {
- TNode store = node[0];
- TNode index = node[1];
- Node n;
- bool val;
- while (store.getKind() == kind::STORE) {
- if (index == store[1]) {
- val = true;
- }
- else if (index.isConst() && store[1].isConst()) {
- val = false;
- }
- else {
- n = Rewriter::rewrite(mkEqNode(store[1], index));
- if (n.getKind() != kind::CONST_BOOLEAN) {
- break;
- }
- val = n.getConst<bool>();
- }
- if (val) {
- // select(store(a,i,v),j) = v if i = j
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << store[2] << std::endl;
- return RewriteResponse(REWRITE_DONE, store[2]);
- }
- // select(store(a,i,v),j) = select(a,j) if i /= j
- store = store[0];
- }
- if (store.getKind() == kind::STORE_ALL) {
- // select(store_all(v),i) = v
- ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
- n = storeAll.getValue();
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
- Assert(n.isConst());
- return RewriteResponse(REWRITE_DONE, n);
- }
- else if (store != node[0]) {
- n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
- return RewriteResponse(REWRITE_DONE, n);
- }
- break;
- }
- case kind::STORE: {
- TNode store = node[0];
- TNode value = node[2];
- // store(a,i,select(a,i)) = a
- if (value.getKind() == kind::SELECT &&
- value[0] == store &&
- value[1] == node[1]) {
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << store << std::endl;
- return RewriteResponse(REWRITE_DONE, store);
- }
- TNode index = node[1];
- if (store.isConst() && index.isConst() && value.isConst()) {
- // normalize constant
- Node n = normalizeConstant(node);
- Assert(n.isConst());
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
- return RewriteResponse(REWRITE_DONE, n);
- }
- if (store.getKind() == kind::STORE) {
- // store(store(a,i,v),j,w)
- bool val;
- if (index == store[1]) {
- val = true;
- }
- else if (index.isConst() && store[1].isConst()) {
- val = false;
- }
- else {
- Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index));
- if (eqRewritten.getKind() != kind::CONST_BOOLEAN) {
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << node << std::endl;
- return RewriteResponse(REWRITE_DONE, node);
- }
- val = eqRewritten.getConst<bool>();
- }
- NodeManager* nm = NodeManager::currentNM();
- if (val) {
- // store(store(a,i,v),i,w) = store(a,i,w)
- Node result = nm->mkNode(kind::STORE, store[0], index, value);
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << result << std::endl;
- return RewriteResponse(REWRITE_AGAIN, result);
- }
- else if (index < store[1]) {
- // store(store(a,i,v),j,w) = store(store(a,j,w),i,v)
- // IF i != j and j comes before i in the ordering
- std::vector<TNode> indices;
- std::vector<TNode> elements;
- indices.push_back(store[1]);
- elements.push_back(store[2]);
- store = store[0];
- Node n;
- while (store.getKind() == kind::STORE) {
- if (index == store[1]) {
- val = true;
- }
- else if (index.isConst() && store[1].isConst()) {
- val = false;
- }
- else {
- n = Rewriter::rewrite(mkEqNode(store[1], index));
- if (n.getKind() != kind::CONST_BOOLEAN) {
- break;
- }
- val = n.getConst<bool>();
- }
- if (val) {
- store = store[0];
- break;
- }
- else if (!(index < store[1])) {
- break;
- }
- indices.push_back(store[1]);
- elements.push_back(store[2]);
- store = store[0];
- }
- if (value.getKind() == kind::SELECT &&
- value[0] == store &&
- value[1] == index) {
- n = store;
- }
- else {
- n = nm->mkNode(kind::STORE, store, index, value);
- }
- while (!indices.empty()) {
- n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
- indices.pop_back();
- elements.pop_back();
- }
- Assert(n != node);
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
- return RewriteResponse(REWRITE_AGAIN, n);
- }
- }
- break;
- }
- case kind::EQUAL:{
- if(node[0] == node[1]) {
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning true" << std::endl;
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
- }
- else if (node[0].isConst() && node[1].isConst()) {
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning false" << std::endl;
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(false));
- }
- if (node[0] > node[1]) {
- Node newNode = NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]);
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << newNode << std::endl;
- return RewriteResponse(REWRITE_DONE, newNode);
- }
- break;
- }
- default:
- break;
- }
- Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << node << std::endl;
- return RewriteResponse(REWRITE_DONE, node);
- }
-
- RewriteResponse preRewrite(TNode node) override
- {
- Trace("arrays-prerewrite") << "Arrays::preRewrite start " << node << std::endl;
- switch (node.getKind()) {
- case kind::SELECT: {
- TNode store = node[0];
- TNode index = node[1];
- Node n;
- bool val;
- while (store.getKind() == kind::STORE) {
- if (index == store[1]) {
- val = true;
- }
- else if (index.isConst() && store[1].isConst()) {
- val = false;
- }
- else {
- n = Rewriter::rewrite(mkEqNode(store[1], index));
- if (n.getKind() != kind::CONST_BOOLEAN) {
- break;
- }
- val = n.getConst<bool>();
- }
- if (val) {
- // select(store(a,i,v),j) = v if i = j
- Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store[2] << std::endl;
- return RewriteResponse(REWRITE_AGAIN, store[2]);
- }
- // select(store(a,i,v),j) = select(a,j) if i /= j
- store = store[0];
- }
- if (store.getKind() == kind::STORE_ALL) {
- // select(store_all(v),i) = v
- ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
- n = storeAll.getValue();
- Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl;
- Assert(n.isConst());
- return RewriteResponse(REWRITE_DONE, n);
- }
- else if (store != node[0]) {
- n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
- Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl;
- return RewriteResponse(REWRITE_DONE, n);
- }
- break;
- }
- case kind::STORE: {
- TNode store = node[0];
- TNode value = node[2];
- // store(a,i,select(a,i)) = a
- if (value.getKind() == kind::SELECT &&
- value[0] == store &&
- value[1] == node[1]) {
- Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store << std::endl;
- return RewriteResponse(REWRITE_AGAIN, store);
- }
- if (store.getKind() == kind::STORE) {
- // store(store(a,i,v),j,w)
- TNode index = node[1];
- bool val;
- if (index == store[1]) {
- val = true;
- }
- else if (index.isConst() && store[1].isConst()) {
- val = false;
- }
- else {
- Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index));
- if (eqRewritten.getKind() != kind::CONST_BOOLEAN) {
- break;
- }
- val = eqRewritten.getConst<bool>();
- }
- NodeManager* nm = NodeManager::currentNM();
- if (val) {
- // store(store(a,i,v),i,w) = store(a,i,w)
- Node newNode = nm->mkNode(kind::STORE, store[0], index, value);
- Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << newNode << std::endl;
- return RewriteResponse(REWRITE_DONE, newNode);
- }
- }
- break;
- }
- case kind::EQUAL:{
- if(node[0] == node[1]) {
- Trace("arrays-prerewrite") << "Arrays::preRewrite returning true" << std::endl;
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
- }
- break;
- }
- default:
- break;
- }
-
- Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << node << std::endl;
- return RewriteResponse(REWRITE_DONE, node);
- }
+ TrustNode expandDefinition(Node node) override;
static inline void init() {}
static inline void shutdown() {}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback