diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2018-02-12 18:16:59 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-12 18:16:59 -0600 |
commit | 1b0aa1c39ff7abe15bbd9305d376d10b007d69d0 (patch) | |
tree | 987d3afa7231f9a21b22bd20afcdb62ed5c87743 /src/theory/quantifiers/extended_rewrite.cpp | |
parent | 04114df7dd58bd7391704a94fe98e2935b39130d (diff) |
Option to use extended rewriter as a preprocessing pass (#1600)
Diffstat (limited to 'src/theory/quantifiers/extended_rewrite.cpp')
-rw-r--r-- | src/theory/quantifiers/extended_rewrite.cpp | 315 |
1 files changed, 155 insertions, 160 deletions
diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index ba0860d38..dd4fc86ba 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -26,7 +26,7 @@ namespace CVC4 { namespace theory { namespace quantifiers { -ExtendedRewriter::ExtendedRewriter() +ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr) { d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); @@ -76,8 +76,7 @@ Node ExtendedRewriter::extendedRewritePullIte(Node n) Trace("q-ext-rewrite") << "sygus-extr : " << n << " rewrites to " << nc << " by simple ITE pulling." << std::endl; - // recurse - return extendedRewrite(nc); + return nc; } } } @@ -87,205 +86,201 @@ Node ExtendedRewriter::extendedRewritePullIte(Node n) Node ExtendedRewriter::extendedRewrite(Node n) { + n = Rewriter::rewrite(n); std::unordered_map<Node, Node, NodeHashFunction>::iterator it = d_ext_rewrite_cache.find(n); - if (it == d_ext_rewrite_cache.end()) + if (it != d_ext_rewrite_cache.end()) { - Node ret = n; - if (n.getNumChildren() > 0) + return it->second; + } + Node ret = n; + if (n.getNumChildren() > 0) + { + std::vector<Node> children; + if (n.getMetaKind() == kind::metakind::PARAMETERIZED) { - std::vector<Node> children; - if (n.getMetaKind() == kind::metakind::PARAMETERIZED) - { - children.push_back(n.getOperator()); - } - bool childChanged = false; - for (unsigned i = 0; i < n.getNumChildren(); i++) - { - Node nc = extendedRewrite(n[i]); - childChanged = nc != n[i] || childChanged; - children.push_back(nc); - } - // Some commutative operators have rewriters that are agnostic to order, - // thus, we sort here. - if (TermUtil::isComm(n.getKind())) - { - childChanged = true; - std::sort(children.begin(), children.end()); - } - if (childChanged) - { - ret = NodeManager::currentNM()->mkNode(n.getKind(), children); - } + children.push_back(n.getOperator()); } - ret = Rewriter::rewrite(ret); - Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret - << " (from " << n << ")" << std::endl; + bool childChanged = false; + for (unsigned i = 0; i < n.getNumChildren(); i++) + { + Node nc = extendedRewrite(n[i]); + childChanged = nc != n[i] || childChanged; + children.push_back(nc); + } + // Some commutative operators have rewriters that are agnostic to order, + // thus, we sort here. + if (TermUtil::isComm(n.getKind()) && (d_aggr || children.size() <= 5)) + { + childChanged = true; + std::sort(children.begin(), children.end()); + } + if (childChanged) + { + ret = NodeManager::currentNM()->mkNode(n.getKind(), children); + } + } + ret = Rewriter::rewrite(ret); + Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret + << " (from " << n << ")" << std::endl; - Node new_ret; - if (ret.getKind() == kind::EQUAL) + Node new_ret; + if (ret.getKind() == kind::EQUAL) + { + if (new_ret.isNull()) { - if (new_ret.isNull()) - { - // simple ITE pulling - new_ret = extendedRewritePullIte(ret); - } + // simple ITE pulling + new_ret = extendedRewritePullIte(ret); + } + } + else if (ret.getKind() == kind::ITE) + { + Assert(ret[1] != ret[2]); + if (ret[0].getKind() == NOT) + { + ret = NodeManager::currentNM()->mkNode( + kind::ITE, ret[0][0], ret[2], ret[1]); } - else if (ret.getKind() == kind::ITE) + if (ret[0].getKind() == kind::EQUAL) { - Assert(ret[1] != ret[2]); - if (ret[0].getKind() == NOT) + // simple invariant ITE + for (unsigned i = 0; i < 2; i++) { - ret = NodeManager::currentNM()->mkNode( - kind::ITE, ret[0][0], ret[2], ret[1]); + if (ret[1] == ret[0][i] && ret[2] == ret[0][1 - i]) + { + Trace("q-ext-rewrite") + << "sygus-extr : " << ret << " rewrites to " << ret[2] + << " due to simple invariant ITE." << std::endl; + new_ret = ret[2]; + break; + } } - if (ret[0].getKind() == kind::EQUAL) + // notice this is strictly more general than the above + if (new_ret.isNull()) { - // simple invariant ITE + // simple substitution for (unsigned i = 0; i < 2; i++) { - if (ret[1] == ret[0][i] && ret[2] == ret[0][1 - i]) + TNode r1 = ret[0][i]; + TNode r2 = ret[0][1 - i]; + if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst())) { - Trace("q-ext-rewrite") << "sygus-extr : " << ret << " rewrites to " - << ret[2] << " due to simple invariant ITE." - << std::endl; - new_ret = ret[2]; - break; - } - } - // notice this is strictly more general than the above - if (new_ret.isNull()) - { - // simple substitution - for (unsigned i = 0; i < 2; i++) - { - TNode r1 = ret[0][i]; - TNode r2 = ret[0][1 - i]; - if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst())) + Node retn = ret[1].substitute(r1, r2); + if (retn != ret[1]) { - Node retn = ret[1].substitute(r1, r2); - if (retn != ret[1]) - { - new_ret = NodeManager::currentNM()->mkNode( - kind::ITE, ret[0], retn, ret[2]); - Trace("q-ext-rewrite") - << "sygus-extr : " << ret << " rewrites to " << new_ret - << " due to simple ITE substitution." << std::endl; - } + new_ret = NodeManager::currentNM()->mkNode( + kind::ITE, ret[0], retn, ret[2]); + Trace("q-ext-rewrite") + << "sygus-extr : " << ret << " rewrites to " << new_ret + << " due to simple ITE substitution." << std::endl; } } } } } - else if (ret.getKind() == DIVISION || ret.getKind() == INTS_DIVISION - || ret.getKind() == INTS_MODULUS) + } + else if (ret.getKind() == DIVISION || ret.getKind() == INTS_DIVISION + || ret.getKind() == INTS_MODULUS) + { + // rewrite as though total + std::vector<Node> children; + bool all_const = true; + for (unsigned i = 0; i < ret.getNumChildren(); i++) { - // rewrite as though total - std::vector<Node> children; - bool all_const = true; - for (unsigned i = 0; i < ret.getNumChildren(); i++) + if (ret[i].isConst()) { - if (ret[i].isConst()) - { - children.push_back(ret[i]); - } - else - { - all_const = false; - break; - } + children.push_back(ret[i]); } - if (all_const) + else { - Kind new_k = - (ret.getKind() == DIVISION - ? DIVISION_TOTAL - : (ret.getKind() == INTS_DIVISION ? INTS_DIVISION_TOTAL - : INTS_MODULUS_TOTAL)); - new_ret = NodeManager::currentNM()->mkNode(new_k, children); - Trace("q-ext-rewrite") << "sygus-extr : " << ret << " rewrites to " - << new_ret << " due to total interpretation." - << std::endl; + all_const = false; + break; } } - // more expensive rewrites - if (new_ret.isNull()) + if (all_const) { - Trace("q-ext-rewrite-debug2") << "Do expensive rewrites on " << ret - << std::endl; - bool polarity = ret.getKind() != NOT; - Node ret_atom = ret.getKind() == NOT ? ret[0] : ret; - if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal()) - || ret_atom.getKind() == GEQ) + Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL + : (ret.getKind() == INTS_DIVISION + ? INTS_DIVISION_TOTAL + : INTS_MODULUS_TOTAL)); + new_ret = NodeManager::currentNM()->mkNode(new_k, children); + Trace("q-ext-rewrite") + << "sygus-extr : " << ret << " rewrites to " << new_ret + << " due to total interpretation." << std::endl; + } + } + // more expensive rewrites + if (new_ret.isNull() && d_aggr) + { + new_ret = extendedRewriteAggr(ret); + } + + d_ext_rewrite_cache[n] = ret; + if (!new_ret.isNull()) + { + ret = extendedRewrite(new_ret); + } + d_ext_rewrite_cache[n] = ret; + return ret; +} + +Node ExtendedRewriter::extendedRewriteAggr(Node n) +{ + Node new_ret; + Trace("q-ext-rewrite-debug2") + << "Do aggressive rewrites on " << n << std::endl; + bool polarity = n.getKind() != NOT; + Node ret_atom = n.getKind() == NOT ? n[0] : n; + if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal()) + || ret_atom.getKind() == GEQ) + { + Trace("q-ext-rewrite-debug2") + << "Compute monomial sum " << ret_atom << std::endl; + // compute monomial sum + std::map<Node, Node> msum; + if (ArithMSum::getMonomialSumLit(ret_atom, msum)) + { + for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end(); + ++itm) { - Trace("q-ext-rewrite-debug2") << "Compute monomial sum " << ret_atom - << std::endl; - // compute monomial sum - std::map<Node, Node> msum; - if (ArithMSum::getMonomialSumLit(ret_atom, msum)) + Node v = itm->first; + Trace("q-ext-rewrite-debug2") + << itm->first << " * " << itm->second << std::endl; + if (v.getKind() == ITE) { - for (std::map<Node, Node>::iterator itm = msum.begin(); - itm != msum.end(); - ++itm) + Node veq; + int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind()); + if (res != 0) { - Node v = itm->first; - Trace("q-ext-rewrite-debug2") << itm->first << " * " << itm->second - << std::endl; - if (v.getKind() == ITE) + Trace("q-ext-rewrite-debug") + << " have ITE relation, solved form : " << veq << std::endl; + // try pulling ITE + new_ret = extendedRewritePullIte(veq); + if (!new_ret.isNull()) { - Node veq; - int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind()); - if (res != 0) + if (!polarity) { - Trace("q-ext-rewrite-debug") - << " have ITE relation, solved form : " << veq - << std::endl; - // try pulling ITE - new_ret = extendedRewritePullIte(veq); - if (!new_ret.isNull()) - { - if (!polarity) - { - new_ret = new_ret.negate(); - } - break; - } - } - else - { - Trace("q-ext-rewrite-debug") << " failed to isolate " << v - << " in " << ret << std::endl; + new_ret = new_ret.negate(); } + break; } } - } - else - { - Trace("q-ext-rewrite-debug") << " failed to get monomial sum of " - << ret << std::endl; + else + { + Trace("q-ext-rewrite-debug") + << " failed to isolate " << v << " in " << n << std::endl; + } } } - else if (ret_atom.getKind() == ITE) - { - // TODO : conditional rewriting - } - else if (ret.getKind() == kind::AND || ret.getKind() == kind::OR) - { - // TODO condition merging - } } - - if (!new_ret.isNull()) + else { - ret = Rewriter::rewrite(new_ret); + Trace("q-ext-rewrite-debug") + << " failed to get monomial sum of " << n << std::endl; } - d_ext_rewrite_cache[n] = ret; - return ret; - } - else - { - return it->second; } + // TODO (#1599) : conditional rewriting, condition merging + return new_ret; } } /* CVC4::theory::quantifiers namespace */ |