diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2018-09-24 21:28:31 -0500 |
---|---|---|
committer | Andres Noetzli <andres.noetzli@gmail.com> | 2018-09-24 19:28:31 -0700 |
commit | 510788587866b16d9ba49bb36a492278ac5fd144 (patch) | |
tree | 99b77a465e235d1a8790d0f2eb48cd49b401b0cc | |
parent | 86ec94126885dc0756abe2e00ecbe71288c13410 (diff) |
Refactor strings equality rewriting (#2513)
This moves the extended rewrites for string equality into the main strings rewriter as a function, rewriteEqualityExt, and calls this function on every equality that our rewriter generates from a non-equality.
-rw-r--r-- | src/theory/quantifiers/extended_rewrite.cpp | 202 | ||||
-rw-r--r-- | src/theory/strings/theory_strings_rewriter.cpp | 229 | ||||
-rw-r--r-- | src/theory/strings/theory_strings_rewriter.h | 25 |
3 files changed, 224 insertions, 232 deletions
diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index df82e0750..e64e1b7b2 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -1671,208 +1671,6 @@ Node ExtendedRewriter::extendedRewriteStrings(Node ret) Node new_ret; Trace("q-ext-rewrite-debug") << "Extended rewrite strings : " << ret << std::endl; - NodeManager* nm = NodeManager::currentNM(); - if (ret.getKind() == EQUAL) - { - if (ret[0].getType().isString()) - { - std::vector<Node> c[2]; - for (unsigned i = 0; i < 2; i++) - { - strings::TheoryStringsRewriter::getConcat(ret[i], c[i]); - } - - // ------- equality unification - bool changed = false; - for (unsigned i = 0; i < 2; i++) - { - while (!c[0].empty() && !c[1].empty() && c[0].back() == c[1].back()) - { - c[0].pop_back(); - c[1].pop_back(); - changed = true; - } - // splice constants - if (!c[0].empty() && !c[1].empty() && c[0].back().isConst() - && c[1].back().isConst()) - { - String cs[2]; - for (unsigned j = 0; j < 2; j++) - { - cs[j] = c[j].back().getConst<String>(); - } - unsigned larger = cs[0].size() > cs[1].size() ? 0 : 1; - unsigned smallerSize = cs[1 - larger].size(); - if (cs[1 - larger] - == (i == 0 ? cs[larger].suffix(smallerSize) - : cs[larger].prefix(smallerSize))) - { - unsigned sizeDiff = cs[larger].size() - smallerSize; - c[larger][c[larger].size() - 1] = - nm->mkConst(i == 0 ? cs[larger].prefix(sizeDiff) - : cs[larger].suffix(sizeDiff)); - c[1 - larger].pop_back(); - changed = true; - } - } - for (unsigned j = 0; j < 2; j++) - { - std::reverse(c[j].begin(), c[j].end()); - } - } - if (changed) - { - // e.g. 
x++y = x++z ---> y = z, "AB" ++ x = "A" ++ y --> "B" ++ x = y - Node s1 = strings::TheoryStringsRewriter::mkConcat(STRING_CONCAT, c[0]); - Node s2 = strings::TheoryStringsRewriter::mkConcat(STRING_CONCAT, c[1]); - new_ret = s1.eqNode(s2); - debugExtendedRewrite(ret, new_ret, "string-eq-unify"); - return new_ret; - } - - // ------- using the contains rewriter to reduce equalities - Node tcontains[2]; - bool tcontainsOneTrue = false; - unsigned tcontainsTrueIndex = 0; - for (unsigned i = 0; i < 2; i++) - { - Node tc = nm->mkNode(STRING_STRCTN, ret[i], ret[1 - i]); - tcontains[i] = Rewriter::rewrite(tc); - if (tcontains[i].isConst()) - { - if (tcontains[i].getConst<bool>()) - { - tcontainsOneTrue = true; - tcontainsTrueIndex = i; - } - else - { - new_ret = tcontains[i]; - // if str.contains( x, y ) ---> false then x = y ---> false - // Notice we may not catch this in the rewriter for strings - // equality, since it only calls the specific rewriter for - // contains and not the full rewriter. 
- debugExtendedRewrite(ret, new_ret, "eq-contains-one-false"); - return new_ret; - } - } - } - if (tcontainsOneTrue) - { - // if str.contains( x, y ) ---> true - // then x = y ---> contains( y, x ) - new_ret = tcontains[1 - tcontainsTrueIndex]; - debugExtendedRewrite(ret, new_ret, "eq-contains-one-true"); - return new_ret; - } - else if (tcontains[0] == tcontains[1] && tcontains[0] != ret) - { - // if str.contains( x, y ) ---> t and str.contains( y, x ) ---> t, - // then x = y ---> t - new_ret = tcontains[0]; - debugExtendedRewrite(ret, new_ret, "eq-dual-contains-eq"); - return new_ret; - } - - // ------- homogeneous constants - if (d_aggr) - { - for (unsigned i = 0; i < 2; i++) - { - if (ret[i].isConst()) - { - bool isHomogeneous = true; - unsigned hchar = 0; - String lhss = ret[i].getConst<String>(); - std::vector<unsigned> vec = lhss.getVec(); - if (vec.size() > 1) - { - hchar = vec[0]; - for (unsigned j = 1, size = vec.size(); j < size; j++) - { - if (vec[j] != hchar) - { - isHomogeneous = false; - break; - } - } - } - if (isHomogeneous) - { - std::sort(c[1 - i].begin(), c[1 - i].end()); - std::vector<Node> trimmed; - unsigned rmChar = 0; - for (unsigned j = 0, size = c[1 - i].size(); j < size; j++) - { - if (c[1 - i][j].isConst()) - { - // process the constant : either we have a conflict, or we - // drop an equal number of constants on the LHS - std::vector<unsigned> vecj = - c[1 - i][j].getConst<String>().getVec(); - for (unsigned k = 0, sizev = vecj.size(); k < sizev; k++) - { - bool conflict = false; - if (vec.empty()) - { - // e.g. "" = x ++ "A" ---> false - conflict = true; - } - else if (vecj[k] != hchar) - { - // e.g. "AA" = x ++ "B" ---> false - conflict = true; - } - else - { - rmChar++; - if (rmChar > lhss.size()) - { - // e.g. 
"AA" = x ++ "AAA" ---> false - conflict = true; - } - } - if (conflict) - { - // The three conflict cases should mostly should be taken - // care of by multiset reasoning in the strings rewriter, - // but we recognize this conflict just in case. - new_ret = nm->mkConst(false); - debugExtendedRewrite( - ret, new_ret, "string-eq-const-conflict"); - return new_ret; - } - } - } - else - { - trimmed.push_back(c[1 - i][j]); - } - } - Node lhs = ret[i]; - if (rmChar > 0) - { - Assert(lhss.size() >= rmChar); - // we trimmed - lhs = nm->mkConst(lhss.substr(0, lhss.size() - rmChar)); - } - Node ss = strings::TheoryStringsRewriter::mkConcat(STRING_CONCAT, - trimmed); - if (lhs != ret[i] || ss != ret[1 - i]) - { - // e.g. - // "AA" = y ++ x ---> "AA" = x ++ y if x < y - // "AAA" = y ++ "A" ++ z ---> "AA" = y ++ z - new_ret = lhs.eqNode(ss); - debugExtendedRewrite(ret, new_ret, "string-eq-homog-const"); - return new_ret; - } - } - } - } - } - } - } return new_ret; } diff --git a/src/theory/strings/theory_strings_rewriter.cpp b/src/theory/strings/theory_strings_rewriter.cpp index 48b288ea3..f8bbeecf5 100644 --- a/src/theory/strings/theory_strings_rewriter.cpp +++ b/src/theory/strings/theory_strings_rewriter.cpp @@ -315,10 +315,179 @@ Node TheoryStringsRewriter::rewriteEquality(Node node) { return NodeManager::currentNM()->mkNode(kind::EQUAL, node[1], node[0]); } - else + return node; +} + +Node TheoryStringsRewriter::rewriteEqualityExt(Node node) +{ + Assert(node.getKind() == EQUAL); + if (!node[0].getType().isString()) { return node; } + NodeManager* nm = NodeManager::currentNM(); + std::vector<Node> c[2]; + Node new_ret; + for (unsigned i = 0; i < 2; i++) + { + getConcat(node[i], c[i]); + } + // ------- equality unification + bool changed = false; + for (unsigned i = 0; i < 2; i++) + { + while (!c[0].empty() && !c[1].empty() && c[0].back() == c[1].back()) + { + c[0].pop_back(); + c[1].pop_back(); + changed = true; + } + // splice constants + if (!c[0].empty() && !c[1].empty() 
&& c[0].back().isConst() + && c[1].back().isConst()) + { + String cs[2]; + for (unsigned j = 0; j < 2; j++) + { + cs[j] = c[j].back().getConst<String>(); + } + unsigned larger = cs[0].size() > cs[1].size() ? 0 : 1; + unsigned smallerSize = cs[1 - larger].size(); + if (cs[1 - larger] + == (i == 0 ? cs[larger].suffix(smallerSize) + : cs[larger].prefix(smallerSize))) + { + unsigned sizeDiff = cs[larger].size() - smallerSize; + c[larger][c[larger].size() - 1] = nm->mkConst( + i == 0 ? cs[larger].prefix(sizeDiff) : cs[larger].suffix(sizeDiff)); + c[1 - larger].pop_back(); + changed = true; + } + } + for (unsigned j = 0; j < 2; j++) + { + std::reverse(c[j].begin(), c[j].end()); + } + } + if (changed) + { + // e.g. x++y = x++z ---> y = z, "AB" ++ x = "A" ++ y --> "B" ++ x = y + Node s1 = mkConcat(STRING_CONCAT, c[0]); + Node s2 = mkConcat(STRING_CONCAT, c[1]); + new_ret = s1.eqNode(s2); + node = returnRewrite(node, new_ret, "str-eq-unify"); + } + + // ------- homogeneous constants + for (unsigned i = 0; i < 2; i++) + { + if (node[i].isConst()) + { + bool isHomogeneous = true; + unsigned hchar = 0; + String lhss = node[i].getConst<String>(); + std::vector<unsigned> vec = lhss.getVec(); + if (vec.size() > 1) + { + hchar = vec[0]; + for (unsigned j = 1, size = vec.size(); j < size; j++) + { + if (vec[j] != hchar) + { + isHomogeneous = false; + break; + } + } + } + if (isHomogeneous) + { + std::sort(c[1 - i].begin(), c[1 - i].end()); + std::vector<Node> trimmed; + unsigned rmChar = 0; + for (unsigned j = 0, size = c[1 - i].size(); j < size; j++) + { + if (c[1 - i][j].isConst()) + { + // process the constant : either we have a conflict, or we + // drop an equal number of constants on the LHS + std::vector<unsigned> vecj = + c[1 - i][j].getConst<String>().getVec(); + for (unsigned k = 0, sizev = vecj.size(); k < sizev; k++) + { + bool conflict = false; + if (vec.empty()) + { + // e.g. "" = x ++ "A" ---> false + conflict = true; + } + else if (vecj[k] != hchar) + { + // e.g. 
"AA" = x ++ "B" ---> false + conflict = true; + } + else + { + rmChar++; + if (rmChar > lhss.size()) + { + // e.g. "AA" = x ++ "AAA" ---> false + conflict = true; + } + } + if (conflict) + { + // The three conflict cases should mostly should be taken + // care of by multiset reasoning in the strings rewriter, + // but we recognize this conflict just in case. + new_ret = nm->mkConst(false); + return returnRewrite(node, new_ret, "string-eq-const-conflict"); + } + } + } + else + { + trimmed.push_back(c[1 - i][j]); + } + } + Node lhs = node[i]; + if (rmChar > 0) + { + Assert(lhss.size() >= rmChar); + // we trimmed + lhs = nm->mkConst(lhss.substr(0, lhss.size() - rmChar)); + } + Node ss = mkConcat(STRING_CONCAT, trimmed); + if (lhs != node[i] || ss != node[1 - i]) + { + // e.g. + // "AA" = y ++ x ---> "AA" = x ++ y if x < y + // "AAA" = y ++ "A" ++ z ---> "AA" = y ++ z + new_ret = lhs.eqNode(ss); + node = returnRewrite(node, new_ret, "str-eq-homog-const"); + } + } + } + } + + Assert(node.getKind() == EQUAL); + + // Try to rewrite (= x y) into a conjunction of equalities based on length + // entailment. + // + // (<= (str.len x) (str.++ y1 ... yn)) AND (= x (str.++ y1 ... yn)) ---> + // (and (= x (str.++ y1' ... ym')) (= y1'' "") ... (= yk'' "")) + // + // where yi' and yi'' correspond to some yj and + // (<= (str.len x) (str.++ y1' ... 
ym')) + for (unsigned i = 0; i < 2; i++) + { + new_ret = inferEqsFromContains(node[i], node[1 - i]); + if (!new_ret.isNull()) + { + return returnRewrite(node, new_ret, "str-eq-conj-len-entail"); + } + } + return node; } // TODO (#1180) add rewrite @@ -1710,22 +1879,11 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) { // TODO (#1180): abstract interpretation with multi-set domain // to show first argument is a strict subset of second argument - // Try to rewrite (str.contains x y) into an equality or a conjunction of - // equalities: - // - // (str.contains x y) ---> (= x y) if (<= (str.len x) (str.len y)) - // - // or more generally: - // - // (str.contains x (str.++ y1 ... yn)) ---> - // (and (= x (str.++ y1' ... ym')) (= y1'' "") ... (= yk'' "")) - // - // where yi' and yi'' correspond to some yj and - // (<= (str.len x) (str.++ y1' ... ym')) - Node eqs = inferEqsFromContains(node[0], node[1]); - if (!eqs.isNull()) + if (checkEntailArith(len_n2, len_n1, false)) { - return returnRewrite(node, eqs, "ctn-to-eqs"); + // len( n2 ) >= len( n1 ) => contains( n1, n2 ) ---> n1 = n2 + Node ret = node[0].eqNode(node[1]); + return returnRewrite(node, ret, "ctn-len-ineq-nstrict"); } // splitting @@ -2574,14 +2732,7 @@ Node TheoryStringsRewriter::rewritePrefixSuffix(Node n) // general reduction to equality + substr Node retNode = n[0].eqNode( NodeManager::currentNM()->mkNode(kind::STRING_SUBSTR, n[1], val, lens)); - // add length constraint if it cannot be shown by simple entailment check - if (!checkEntailArith(lent, lens)) - { - retNode = NodeManager::currentNM()->mkNode( - kind::AND, - retNode, - NodeManager::currentNM()->mkNode(kind::GEQ, lent, lens)); - } + return retNode; } @@ -3913,5 +4064,35 @@ Node TheoryStringsRewriter::returnRewrite(Node node, Node ret, const char* c) { Trace("strings-rewrite") << "Rewrite " << node << " to " << ret << " by " << c << "." << std::endl; + // standard post-processing + // We rewrite (string) equalities immediately here. 
This allows us to forego + // the standard invariant on equality rewrites (that s=t must rewrite to one + // of { s=t, t=s, true, false } ). + Kind retk = ret.getKind(); + if (retk == OR || retk == AND) + { + std::vector<Node> children; + bool childChanged = false; + for (const Node& cret : ret) + { + Node creter = cret; + if (cret.getKind() == EQUAL) + { + creter = rewriteEqualityExt(cret); + } + childChanged = childChanged || cret != creter; + children.push_back(creter); + } + if (childChanged) + { + ret = NodeManager::currentNM()->mkNode(retk, children); + } + } + else if (retk == EQUAL && node.getKind() != EQUAL) + { + Trace("strings-rewrite") + << "Apply extended equality rewrite on " << ret << std::endl; + ret = rewriteEqualityExt(ret); + } return ret; } diff --git a/src/theory/strings/theory_strings_rewriter.h b/src/theory/strings/theory_strings_rewriter.h index 70c573d9e..c0aa91360 100644 --- a/src/theory/strings/theory_strings_rewriter.h +++ b/src/theory/strings/theory_strings_rewriter.h @@ -98,12 +98,16 @@ class TheoryStringsRewriter { * a is in rewritten form. */ static bool checkEntailArithInternal(Node a); - /** return rewrite + /** * Called when node rewrites to ret. - * The string c indicates the justification - * for the rewrite, which is printed by this - * function for debugging. - * This function returns ret. + * + * The string c indicates the justification for the rewrite, which is printed + * by this function for debugging. + * + * If node is not an equality and ret is an equality, this method applies + * an additional rewrite step (rewriteEqualityExt) that performs + * additional rewrites on ret, after which we return the result of this call. + * Otherwise, this method simply returns ret. */ static Node returnRewrite(Node node, Node ret, const char* c); @@ -118,9 +122,18 @@ class TheoryStringsRewriter { /** rewrite equality * * This method returns a formula that is equivalent to the equality between - * two strings, given by node. 
+ * two strings s = t, given by node. The result of rewrite is one of + * { s = t, t = s, true, false }. */ static Node rewriteEquality(Node node); + /** rewrite equality extended + * + * This method returns a formula that is equivalent to the equality between + * two strings s = t, given by node. Specifically, this function performs + * rewrites whose conclusion is not necessarily one of + * { s = t, t = s, true, false }. + */ + static Node rewriteEqualityExt(Node node); /** rewrite concat * This is the entry point for post-rewriting terms node of the form * str.++( t1, .., tn ) |