summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/options/quantifiers_options.toml8
-rw-r--r--src/theory/quantifiers/extended_rewrite.cpp1030
-rw-r--r--src/theory/quantifiers/extended_rewrite.h143
-rw-r--r--src/theory/quantifiers/term_util.cpp22
-rw-r--r--src/theory/quantifiers/term_util.h12
5 files changed, 1055 insertions, 160 deletions
diff --git a/src/options/quantifiers_options.toml b/src/options/quantifiers_options.toml
index 28a9e58a7..f877143a2 100644
--- a/src/options/quantifiers_options.toml
+++ b/src/options/quantifiers_options.toml
@@ -1086,6 +1086,14 @@ header = "options/quantifiers_options.h"
help = "enumerate a stream of solutions instead of terminating after the first one"
[[option]]
+ name = "sygusExtRew"
+ category = "regular"
+ long = "sygus-ext-rew"
+ type = "bool"
+ default = "true"
+ help = "use extended rewriter for sygus"
+
+[[option]]
name = "cegisSample"
category = "regular"
long = "cegis-sample=MODE"
diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp
index dd4fc86ba..756413b54 100644
--- a/src/theory/quantifiers/extended_rewrite.cpp
+++ b/src/theory/quantifiers/extended_rewrite.cpp
@@ -14,7 +14,9 @@
#include "theory/quantifiers/extended_rewrite.h"
+#include "options/quantifiers_options.h"
#include "theory/arith/arith_msum.h"
+#include "theory/bv/theory_bv_utils.h"
#include "theory/datatypes/datatypes_rewriter.h"
#include "theory/quantifiers/term_util.h"
#include "theory/rewriter.h"
@@ -26,201 +28,176 @@ namespace CVC4 {
namespace theory {
namespace quantifiers {
+struct ExtRewriteAttributeId
+{
+};
+typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
+
ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr)
{
- d_true = NodeManager::currentNM()->mkConst(true);
- d_false = NodeManager::currentNM()->mkConst(false);
}
-
-Node ExtendedRewriter::extendedRewritePullIte(Node n)
+void ExtendedRewriter::setCache(Node n, Node ret)
{
- // generalize this?
- Assert(n.getNumChildren() == 2);
- Assert(n.getType().isBoolean());
- Assert(n.getMetaKind() != kind::metakind::PARAMETERIZED);
- std::vector<Node> children;
- for (unsigned i = 0; i < n.getNumChildren(); i++)
- {
- children.push_back(n[i]);
- }
- for (unsigned i = 0; i < 2; i++)
- {
- if (n[i].getKind() == kind::ITE)
- {
- for (unsigned j = 0; j < 2; j++)
- {
- children[i] = n[i][j + 1];
- Node eqr = extendedRewrite(
- NodeManager::currentNM()->mkNode(n.getKind(), children));
- children[i] = n[i];
- if (eqr.isConst())
- {
- std::vector<Node> new_children;
- Kind new_k;
- if (eqr == d_true)
- {
- new_k = kind::OR;
- new_children.push_back(j == 0 ? n[i][0] : n[i][0].negate());
- }
- else
- {
- Assert(eqr == d_false);
- new_k = kind::AND;
- new_children.push_back(j == 0 ? n[i][0].negate() : n[i][0]);
- }
- children[i] = n[i][2 - j];
- Node rem_eq = NodeManager::currentNM()->mkNode(n.getKind(), children);
- children[i] = n[i];
- new_children.push_back(rem_eq);
- Node nc = NodeManager::currentNM()->mkNode(new_k, new_children);
- Trace("q-ext-rewrite") << "sygus-extr : " << n << " rewrites to "
- << nc << " by simple ITE pulling."
- << std::endl;
- return nc;
- }
- }
- }
- }
- return Node::null();
+ ExtRewriteAttribute era;
+ n.setAttribute(era, ret);
}
Node ExtendedRewriter::extendedRewrite(Node n)
{
n = Rewriter::rewrite(n);
- std::unordered_map<Node, Node, NodeHashFunction>::iterator it =
- d_ext_rewrite_cache.find(n);
- if (it != d_ext_rewrite_cache.end())
+ if (!options::sygusExtRew())
+ {
+ return n;
+ }
+
+ // has it already been computed?
+ if (n.hasAttribute(ExtRewriteAttribute()))
{
- return it->second;
+ return n.getAttribute(ExtRewriteAttribute());
}
+
Node ret = n;
+ NodeManager* nm = NodeManager::currentNM();
+
+ //--------------------pre-rewrite
+ Node pre_new_ret;
+ if (ret.getKind() == IMPLIES)
+ {
+ pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
+ debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
+ }
+ else if (ret.getKind() == XOR)
+ {
+ pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
+ debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
+ }
+ else if (ret.getKind() == NOT)
+ {
+ pre_new_ret = extendedRewriteNnf(ret);
+ debugExtendedRewrite(ret, pre_new_ret, "NNF");
+ }
+ if (!pre_new_ret.isNull())
+ {
+ ret = extendedRewrite(pre_new_ret);
+ Trace("q-ext-rewrite-debug") << "...ext-pre-rewrite : " << n << " -> "
+ << pre_new_ret << std::endl;
+ setCache(n, ret);
+ return ret;
+ }
+ //--------------------end pre-rewrite
+
+ //--------------------rewrite children
if (n.getNumChildren() > 0)
{
std::vector<Node> children;
- if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
+ if (n.getMetaKind() == metakind::PARAMETERIZED)
{
children.push_back(n.getOperator());
}
+ Kind k = n.getKind();
bool childChanged = false;
+ bool isNonAdditive = TermUtil::isNonAdditive(k);
for (unsigned i = 0; i < n.getNumChildren(); i++)
{
Node nc = extendedRewrite(n[i]);
childChanged = nc != n[i] || childChanged;
- children.push_back(nc);
+ // If the operator is non-additive, do not consider duplicates
+ if (isNonAdditive
+ && std::find(children.begin(), children.end(), nc) != children.end())
+ {
+ childChanged = true;
+ }
+ else
+ {
+ children.push_back(nc);
+ }
}
+ Assert(!children.empty());
// Some commutative operators have rewriters that are agnostic to order,
// thus, we sort here.
- if (TermUtil::isComm(n.getKind()) && (d_aggr || children.size() <= 5))
+ if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
{
childChanged = true;
std::sort(children.begin(), children.end());
}
if (childChanged)
{
- ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
+ if (isNonAdditive && children.size() == 1)
+ {
+ // we may have subsumed children down to one
+ ret = children[0];
+ }
+ else
+ {
+ ret = nm->mkNode(k, children);
+ }
}
}
ret = Rewriter::rewrite(ret);
+ //--------------------end rewrite children
+
+ // now, do extended rewrite
Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
<< " (from " << n << ")" << std::endl;
-
Node new_ret;
- if (ret.getKind() == kind::EQUAL)
+
+ //---------------------- theory-independent post-rewriting
+ if (ret.getKind() == ITE)
{
- if (new_ret.isNull())
- {
- // simple ITE pulling
- new_ret = extendedRewritePullIte(ret);
- }
+ new_ret = extendedRewriteIte(ITE, ret);
}
- else if (ret.getKind() == kind::ITE)
+ else if (ret.getKind() == AND || ret.getKind() == OR)
{
- Assert(ret[1] != ret[2]);
- if (ret[0].getKind() == NOT)
- {
- ret = NodeManager::currentNM()->mkNode(
- kind::ITE, ret[0][0], ret[2], ret[1]);
- }
- if (ret[0].getKind() == kind::EQUAL)
- {
- // simple invariant ITE
- for (unsigned i = 0; i < 2; i++)
- {
- if (ret[1] == ret[0][i] && ret[2] == ret[0][1 - i])
- {
- Trace("q-ext-rewrite")
- << "sygus-extr : " << ret << " rewrites to " << ret[2]
- << " due to simple invariant ITE." << std::endl;
- new_ret = ret[2];
- break;
- }
- }
- // notice this is strictly more general than the above
- if (new_ret.isNull())
- {
- // simple substitution
- for (unsigned i = 0; i < 2; i++)
- {
- TNode r1 = ret[0][i];
- TNode r2 = ret[0][1 - i];
- if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
- {
- Node retn = ret[1].substitute(r1, r2);
- if (retn != ret[1])
- {
- new_ret = NodeManager::currentNM()->mkNode(
- kind::ITE, ret[0], retn, ret[2]);
- Trace("q-ext-rewrite")
- << "sygus-extr : " << ret << " rewrites to " << new_ret
- << " due to simple ITE substitution." << std::endl;
- }
- }
- }
- }
- }
+ // all kinds are legal to substitute over : hence we give the empty map
+ std::map<Kind, bool> bcp_kinds;
+ new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, ret);
+ debugExtendedRewrite(ret, new_ret, "Bool bcp");
}
- else if (ret.getKind() == DIVISION || ret.getKind() == INTS_DIVISION
- || ret.getKind() == INTS_MODULUS)
+ else if (ret.getKind() == EQUAL)
{
- // rewrite as though total
- std::vector<Node> children;
- bool all_const = true;
- for (unsigned i = 0; i < ret.getNumChildren(); i++)
+ new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
+ debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
+ }
+ if (new_ret.isNull() && ret.getKind() != ITE)
+ {
+ // simple ITE pulling
+ new_ret = extendedRewritePullIte(ITE, ret);
+ }
+ //----------------------end theory-independent post-rewriting
+
+ //----------------------theory-specific post-rewriting
+ if (new_ret.isNull())
+ {
+ Node atom = ret.getKind() == NOT ? ret[0] : ret;
+ bool pol = ret.getKind() != NOT;
+ TheoryId tid = Theory::theoryOf(atom);
+ if (tid == THEORY_ARITH)
{
- if (ret[i].isConst())
- {
- children.push_back(ret[i]);
- }
- else
- {
- all_const = false;
- break;
- }
+ new_ret = extendedRewriteArith(atom, pol);
}
- if (all_const)
+ // add back negation if not processed
+ if (!pol && !new_ret.isNull())
{
- Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL
- : (ret.getKind() == INTS_DIVISION
- ? INTS_DIVISION_TOTAL
- : INTS_MODULUS_TOTAL));
- new_ret = NodeManager::currentNM()->mkNode(new_k, children);
- Trace("q-ext-rewrite")
- << "sygus-extr : " << ret << " rewrites to " << new_ret
- << " due to total interpretation." << std::endl;
+ new_ret = new_ret.negate();
}
}
- // more expensive rewrites
+ //----------------------end theory-specific post-rewriting
+
+ //----------------------aggressive rewrites
if (new_ret.isNull() && d_aggr)
{
new_ret = extendedRewriteAggr(ret);
}
+ //----------------------end aggressive rewrites
- d_ext_rewrite_cache[n] = ret;
+ setCache(n, ret);
if (!new_ret.isNull())
{
ret = extendedRewrite(new_ret);
}
- d_ext_rewrite_cache[n] = ret;
+ Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
+ << std::endl;
+ setCache(n, ret);
return ret;
}
@@ -234,6 +211,8 @@ Node ExtendedRewriter::extendedRewriteAggr(Node n)
if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
|| ret_atom.getKind() == GEQ)
{
+ // ITE term removal in polynomials
+ // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 )
Trace("q-ext-rewrite-debug2")
<< "Compute monomial sum " << ret_atom << std::endl;
// compute monomial sum
@@ -255,7 +234,7 @@ Node ExtendedRewriter::extendedRewriteAggr(Node n)
Trace("q-ext-rewrite-debug")
<< " have ITE relation, solved form : " << veq << std::endl;
// try pulling ITE
- new_ret = extendedRewritePullIte(veq);
+ new_ret = extendedRewritePullIte(ITE, veq);
if (!new_ret.isNull())
{
if (!polarity)
@@ -279,10 +258,781 @@ Node ExtendedRewriter::extendedRewriteAggr(Node n)
<< " failed to get monomial sum of " << n << std::endl;
}
}
- // TODO (#1599) : conditional rewriting, condition merging
+ // TODO (#1706) : conditional rewriting, condition merging
+ return new_ret;
+}
+
+Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
+{
+ Assert(n.getKind() == itek);
+ Assert(n[1] != n[2]);
+
+ NodeManager* nm = NodeManager::currentNM();
+
+ Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;
+
+ Node flip_cond;
+ if (n[0].getKind() == NOT)
+ {
+ flip_cond = n[0][0];
+ }
+ else if (n[0].getKind() == OR)
+ {
+ // a | b ---> ~( ~a & ~b )
+ flip_cond = TermUtil::simpleNegate(n[0]);
+ }
+ if (!flip_cond.isNull())
+ {
+ Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
+ // only print debug trace if full=true
+ if (full)
+ {
+ debugExtendedRewrite(n, new_ret, "ITE flip");
+ }
+ return new_ret;
+ }
+
+ // get entailed equalities in the condition
+ std::vector<Node> eq_conds;
+ Kind ck = n[0].getKind();
+ if (ck == EQUAL)
+ {
+ eq_conds.push_back(n[0]);
+ }
+ else if (ck == AND)
+ {
+ for (const Node& cn : n[0])
+ {
+ if (cn.getKind() == EQUAL)
+ {
+ eq_conds.push_back(cn);
+ }
+ }
+ }
+
+ Node new_ret;
+ Node b;
+ Node e;
+ Node t1 = n[1];
+ Node t2 = n[2];
+ std::stringstream ss_reason;
+
+ for (const Node& eq : eq_conds)
+ {
+ // simple invariant ITE
+ for (unsigned i = 0; i <= 1; i++)
+ {
+ // ite( x = y ^ C, y, x ) ---> x
+ // this is subsumed by the rewrites below
+ if (t2 == eq[i] && t1 == eq[1 - i])
+ {
+ new_ret = t2;
+ ss_reason << "ITE simple rev subs";
+ break;
+ }
+ }
+ if (!new_ret.isNull())
+ {
+ break;
+ }
+ }
+
+ if (new_ret.isNull() && d_aggr)
+ {
+ // If x is less than t based on an ordering, then we use { x -> t } as a
+ // substitution to the children of ite( x = t ^ C, s, t ) below.
+ std::vector<Node> vars;
+ std::vector<Node> subs;
+ for (const Node& eq : eq_conds)
+ {
+ inferSubstitution(eq, vars, subs);
+ }
+
+ if (!vars.empty())
+ {
+ // reverse substitution to opposite child
+ // r{ x -> t } = s implies ite( x=t ^ C, s, r ) ---> r
+ Node nn =
+ t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+ if (nn != t2)
+ {
+ nn = Rewriter::rewrite(nn);
+ if (nn == t1)
+ {
+ new_ret = t2;
+ ss_reason << "ITE rev subs";
+ }
+ }
+
+ // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
+ nn = t1.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+ if (nn != t1)
+ {
+ // If full=false, then we've duplicated a term u in the children of n.
+ // For example, when ITE pulling, we have n is of the form:
+ // ite( C, f( u, t1 ), f( u, t2 ) )
+ // We must show that at least one copy of u dissappears in this case.
+ nn = Rewriter::rewrite(nn);
+ if (nn == t2)
+ {
+ new_ret = nn;
+ ss_reason << "ITE subs invariant";
+ }
+ else if (full || nn.isConst())
+ {
+ new_ret = nm->mkNode(itek, n[0], nn, t2);
+ ss_reason << "ITE subs";
+ }
+ }
+ }
+ }
+
+ // only print debug trace if full=true
+ if (!new_ret.isNull() && full)
+ {
+ debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
+ }
+
+ return new_ret;
+}
+
+Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
+{
+ NodeManager* nm = NodeManager::currentNM();
+ TypeNode tn = n.getType();
+ std::vector<Node> children;
+ bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
+ if (hasOp)
+ {
+ children.push_back(n.getOperator());
+ }
+ unsigned nchildren = n.getNumChildren();
+ for (unsigned i = 0; i < nchildren; i++)
+ {
+ children.push_back(n[i]);
+ }
+ std::map<unsigned, std::map<unsigned, Node> > ite_c;
+ for (unsigned i = 0; i < nchildren; i++)
+ {
+ if (n[i].getKind() == itek)
+ {
+ unsigned ii = hasOp ? i + 1 : i;
+ for (unsigned j = 0; j < 2; j++)
+ {
+ children[ii] = n[i][j + 1];
+ Node pull = nm->mkNode(n.getKind(), children);
+ Node pullr = Rewriter::rewrite(pull);
+ children[ii] = n[i];
+ ite_c[i][j] = pullr;
+ }
+ if (ite_c[i][0] == ite_c[i][1])
+ {
+ // ITE dual invariance
+ // f( t1..s1..tn ) ---> t and f( t1..s2..tn ) ---> t implies
+ // f( t1..ite( A, s1, s2 )..tn ) ---> t
+ debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
+ return ite_c[i][0];
+ }
+ else if (d_aggr)
+ {
+ for (unsigned j = 0; j < 2; j++)
+ {
+ Node pullr = ite_c[i][j];
+ if (pullr.isConst() || pullr == n[i][j + 1])
+ {
+ // ITE single child elimination
+ // f( t1..s1..tn ) ---> t where t is a constant or s1 itself
+ // implies
+ // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
+ Node new_ret;
+ if (tn.isBoolean())
+ {
+ // remove false/true child immediately
+ bool pol = pullr.getConst<bool>();
+ std::vector<Node> new_children;
+ new_children.push_back((j == 0) == pol ? n[i][0]
+ : n[i][0].negate());
+ new_children.push_back(ite_c[i][1 - j]);
+ new_ret = nm->mkNode(pol ? OR : AND, new_children);
+ debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
+ }
+ else
+ {
+ new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
+ debugExtendedRewrite(n, new_ret, "ITE single elim");
+ }
+ return new_ret;
+ }
+ }
+ }
+ }
+ }
+
+ for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
+ {
+ Node nite = n[ip.first];
+ Assert(nite.getKind() == itek);
+ // now, simply pull the ITE and try ITE rewrites
+ Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
+ pull_ite = Rewriter::rewrite(pull_ite);
+ if (pull_ite.getKind() == ITE)
+ {
+ Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
+ if (!new_pull_ite.isNull())
+ {
+ debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
+ return new_pull_ite;
+ }
+ }
+ else
+ {
+ // A general rewrite could eliminate the ITE by pulling.
+ // An example is:
+ // ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
+ // ite( C, ~~x, ite( C, y, x ) ) --->
+ // x
+ // where ~ is bitvector negation.
+ debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
+ return pull_ite;
+ }
+ }
+
+ return Node::null();
+}
+
+Node ExtendedRewriter::extendedRewriteNnf(Node ret)
+{
+ Assert(ret.getKind() == NOT);
+
+ Kind nk = ret[0].getKind();
+ bool neg_ch = false;
+ bool neg_ch_1 = false;
+ if (nk == AND || nk == OR)
+ {
+ neg_ch = true;
+ nk = nk == AND ? OR : AND;
+ }
+ else if (nk == IMPLIES)
+ {
+ neg_ch = true;
+ neg_ch_1 = true;
+ nk = AND;
+ }
+ else if (nk == ITE)
+ {
+ neg_ch = true;
+ neg_ch_1 = true;
+ }
+ else if (nk == XOR)
+ {
+ nk = EQUAL;
+ }
+ else if (nk == EQUAL && ret[0][0].getType().isBoolean())
+ {
+ neg_ch_1 = true;
+ }
+ else
+ {
+ return Node::null();
+ }
+
+ std::vector<Node> new_children;
+ for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++)
+ {
+ Node c = ret[0][i];
+ c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c;
+ new_children.push_back(c);
+ }
+ return NodeManager::currentNM()->mkNode(nk, new_children);
+}
+
+Node ExtendedRewriter::extendedRewriteBcp(
+ Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node ret)
+{
+ Kind k = ret.getKind();
+ Assert(k == andk || k == ork);
+ Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;
+
+ NodeManager* nm = NodeManager::currentNM();
+
+ TypeNode tn = ret.getType();
+ Node truen = TermUtil::mkTypeMaxValue(tn);
+ Node falsen = TermUtil::mkTypeValue(tn, 0);
+
+ // terms to process
+ std::vector<Node> to_process;
+ for (const Node& cn : ret)
+ {
+ to_process.push_back(cn);
+ }
+ // the processing terms
+ std::vector<Node> clauses;
+ // the terms we have propagated information to
+ std::unordered_set<Node, NodeHashFunction> prop_clauses;
+ // the assignment
+ std::map<Node, Node> assign;
+ std::vector<Node> avars;
+ std::vector<Node> asubs;
+
+ Kind ok = k == andk ? ork : andk;
+ // global polarity : when k=ork, everything is negated
+ bool gpol = k == andk;
+
+ do
+ {
+ // process the current nodes
+ while (!to_process.empty())
+ {
+ std::vector<Node> new_to_process;
+ for (const Node& cn : to_process)
+ {
+ Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
+ Kind cnk = cn.getKind();
+ bool pol = cnk != notk;
+ Node cln = cnk == notk ? cn[0] : cn;
+ Assert(cln.getKind() != notk);
+ if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
+ {
+ // flatten
+ for (const Node& ccln : cln)
+ {
+ Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
+ new_to_process.push_back(lccln);
+ }
+ }
+ else
+ {
+ // add it to the assignment
+ Node val = gpol == pol ? truen : falsen;
+ std::map<Node, Node>::iterator it = assign.find(cln);
+ Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
+ << std::endl;
+ if (it != assign.end())
+ {
+ if (val != it->second)
+ {
+ Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
+ // a conflicting assignment: we are done
+ return gpol ? falsen : truen;
+ }
+ }
+ else
+ {
+ assign[cln] = val;
+ avars.push_back(cln);
+ asubs.push_back(val);
+ }
+
+ // also, treat it as clause if possible
+ if (cln.getNumChildren() > 0
+ & (bcp_kinds.empty()
+ || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
+ {
+ if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
+ && prop_clauses.find(cn) == prop_clauses.end())
+ {
+ Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
+ clauses.push_back(cn);
+ }
+ }
+ }
+ }
+ to_process.clear();
+ to_process.insert(
+ to_process.end(), new_to_process.begin(), new_to_process.end());
+ }
+
+ // apply substitution to all subterms of clauses
+ std::vector<Node> new_clauses;
+ for (const Node& c : clauses)
+ {
+ bool cpol = c.getKind() != notk;
+ Node ca = c.getKind() == notk ? c[0] : c;
+ bool childChanged = false;
+ std::vector<Node> ccs_children;
+ for (const Node& cc : ca)
+ {
+ Node ccs = cc;
+ if (bcp_kinds.empty())
+ {
+ Trace("ext-rew-bcp-debug") << "...do ordinary substitute"
+ << std::endl;
+ ccs = cc.substitute(
+ avars.begin(), avars.end(), asubs.begin(), asubs.end());
+ }
+ else
+ {
+ Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
+ // substitution is only applicable to compatible kinds
+ ccs = partialSubstitute(ccs, assign, bcp_kinds);
+ }
+ childChanged = childChanged || ccs != cc;
+ ccs_children.push_back(ccs);
+ }
+ if (childChanged)
+ {
+ if (ca.getMetaKind() == metakind::PARAMETERIZED)
+ {
+ ccs_children.insert(ccs_children.begin(), ca.getOperator());
+ }
+ Node ccs = nm->mkNode(ca.getKind(), ccs_children);
+ ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
+ Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
+ << std::endl;
+ ccs = Rewriter::rewrite(ccs);
+ Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
+ to_process.push_back(ccs);
+ // store this as a node that propagation touched. This marks c so that
+ // it will not be included in the final construction.
+ prop_clauses.insert(ca);
+ }
+ else
+ {
+ new_clauses.push_back(c);
+ }
+ }
+ clauses.clear();
+ clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
+ } while (!to_process.empty());
+
+ // remake the node
+ if (!prop_clauses.empty())
+ {
+ std::vector<Node> children;
+ for (std::pair<const Node, Node>& l : assign)
+ {
+ Node a = l.first;
+ // if propagation did not touch a
+ if (prop_clauses.find(a) == prop_clauses.end())
+ {
+ Assert(l.second == truen || l.second == falsen);
+ Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
+ children.push_back(ln);
+ }
+ }
+ Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
+ Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
+ return new_ret;
+ }
+
+ return Node::null();
+}
+
+Node ExtendedRewriter::extendedRewriteEqChain(
+ Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor)
+{
+ Assert(ret.getKind() == eqk);
+
+ NodeManager* nm = NodeManager::currentNM();
+
+ TypeNode tn = ret[0].getType();
+
+ // sort/cancelling for Boolean EQUAL/XOR-chains
+ Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;
+
+ // get the children on either side
+ bool gpol = true;
+ std::vector<Node> children;
+ for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++)
+ {
+ Node curr = ret[r];
+ // assume, if necessary, right associative
+ while (curr.getKind() == eqk && curr[0].getType() == tn)
+ {
+ children.push_back(curr[0]);
+ curr = curr[1];
+ }
+ children.push_back(curr);
+ }
+
+ std::map<Node, bool> cstatus;
+ // add children to status
+ for (const Node& c : children)
+ {
+ Node a = c;
+ if (a.getKind() == notk)
+ {
+ gpol = !gpol;
+ a = a[0];
+ }
+ Trace("ext-rew-eqchain") << "...child : " << a << std::endl;
+ std::map<Node, bool>::iterator itc = cstatus.find(a);
+ if (itc == cstatus.end())
+ {
+ cstatus[a] = true;
+ }
+ else
+ {
+ // cancels
+ cstatus.erase(a);
+ if (isXor)
+ {
+ gpol = !gpol;
+ }
+ }
+ }
+ Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;
+
+ if (cstatus.empty())
+ {
+ return TermUtil::mkTypeConst(tn, gpol);
+ }
+
+ children.clear();
+
+ // cancel AND/OR children if possible
+ for (std::pair<const Node, bool>& cp : cstatus)
+ {
+ if (cp.second)
+ {
+ Node c = cp.first;
+ Kind ck = c.getKind();
+ if (ck == andk || ck == ork)
+ {
+ for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++)
+ {
+ Node cl = c[j];
+ Node ca = cl.getKind() == notk ? cl[0] : cl;
+ bool capol = cl.getKind() != notk;
+ // if this already exists as a child of the equality chain
+ std::map<Node, bool>::iterator itc = cstatus.find(ca);
+ if (itc != cstatus.end() && itc->second)
+ {
+ // cancel it
+ cstatus[ca] = false;
+ cstatus[c] = false;
+ // make new child
+ // x = ( y | ~x ) ---> y & x
+ // x = ( y | x ) ---> ~y | x
+ // x = ( y & x ) ---> y | ~x
+ // x = ( y & ~x ) ---> ~y & ~x
+ std::vector<Node> new_children;
+ for (unsigned k = 0, nchild = c.getNumChildren(); k < nchild; k++)
+ {
+ if (j != k)
+ {
+ new_children.push_back(c[k]);
+ }
+ }
+ Node nc[2];
+ nc[0] = c[j];
+ nc[1] = new_children.size() == 1 ? new_children[0]
+ : nm->mkNode(ck, new_children);
+ // negate the proper child
+ unsigned nindex = (ck == andk) == capol ? 0 : 1;
+ nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]);
+ Kind nk = capol ? ork : andk;
+ // store as new child
+ children.push_back(nm->mkNode(nk, nc[0], nc[1]));
+ if (isXor)
+ {
+ gpol = !gpol;
+ }
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ // sorted right associative chain
+ bool has_const = false;
+ unsigned const_index = 0;
+ for (std::pair<const Node, bool>& cp : cstatus)
+ {
+ if (cp.second)
+ {
+ if (cp.first.isConst())
+ {
+ has_const = true;
+ const_index = children.size();
+ }
+ children.push_back(cp.first);
+ }
+ }
+ std::sort(children.begin(), children.end());
+
+ Node new_ret;
+ if (!gpol)
+ {
+ // negate the constant child if it exists
+ unsigned nindex = has_const ? const_index : 0;
+ children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
+ }
+ new_ret = children.back();
+ unsigned index = children.size() - 1;
+ while (index > 0)
+ {
+ index--;
+ new_ret = nm->mkNode(eqk, children[index], new_ret);
+ }
+ new_ret = Rewriter::rewrite(new_ret);
+ if (new_ret != ret)
+ {
+ return new_ret;
+ }
+ return Node::null();
+}
+
+Node ExtendedRewriter::partialSubstitute(Node n,
+ std::map<Node, Node>& assign,
+ std::map<Kind, bool>& rkinds)
+{
+ std::unordered_map<TNode, Node, TNodeHashFunction> visited;
+ std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
+ std::vector<TNode> visit;
+ TNode cur;
+ visit.push_back(n);
+ do
+ {
+ cur = visit.back();
+ visit.pop_back();
+ it = visited.find(cur);
+
+ if (it == visited.end())
+ {
+ std::map<Node, Node>::iterator it = assign.find(cur);
+ if (it != assign.end())
+ {
+ visited[cur] = it->second;
+ }
+ else
+ {
+ // can only recurse on these kinds
+ Kind k = cur.getKind();
+ if (rkinds.find(k) != rkinds.end())
+ {
+ visited[cur] = Node::null();
+ visit.push_back(cur);
+ for (const Node& cn : cur)
+ {
+ visit.push_back(cn);
+ }
+ }
+ else
+ {
+ visited[cur] = cur;
+ }
+ }
+ }
+ else if (it->second.isNull())
+ {
+ Node ret = cur;
+ bool childChanged = false;
+ std::vector<Node> children;
+ if (cur.getMetaKind() == metakind::PARAMETERIZED)
+ {
+ children.push_back(cur.getOperator());
+ }
+ for (const Node& cn : cur)
+ {
+ it = visited.find(cn);
+ Assert(it != visited.end());
+ Assert(!it->second.isNull());
+ childChanged = childChanged || cn != it->second;
+ children.push_back(it->second);
+ }
+ if (childChanged)
+ {
+ ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
+ }
+ visited[cur] = ret;
+ }
+ } while (!visit.empty());
+ Assert(visited.find(n) != visited.end());
+ Assert(!visited.find(n)->second.isNull());
+ return visited[n];
+}
+
+Node ExtendedRewriter::solveEquality(Node n)
+{
+ // TODO (#1706) : implement
+ Assert(n.getKind() == EQUAL);
+
+ return Node::null();
+}
+
+bool ExtendedRewriter::inferSubstitution(Node n,
+ std::vector<Node>& vars,
+ std::vector<Node>& subs)
+{
+ if (n.getKind() == EQUAL)
+ {
+ // see if it can be put into form x = y
+ Node slv_eq = solveEquality(n);
+ if (!slv_eq.isNull())
+ {
+ n = slv_eq;
+ }
+ for (unsigned i = 0; i < 2; i++)
+ {
+ TNode r1 = n[i];
+ TNode r2 = n[1 - i];
+ if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
+ {
+ // TODO (#1706) : union find
+ if (std::find(vars.begin(), vars.end(), r1) == vars.end())
+ {
+ vars.push_back(r1);
+ subs.push_back(r2);
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+Node ExtendedRewriter::extendedRewriteArith(Node ret, bool& pol)
+{
+ Kind k = ret.getKind();
+ NodeManager* nm = NodeManager::currentNM();
+ Node new_ret;
+ if (k == DIVISION || k == INTS_DIVISION || k == INTS_MODULUS)
+ {
+ // rewrite as though total
+ std::vector<Node> children;
+ bool all_const = true;
+ for (unsigned i = 0, size = ret.getNumChildren(); i < size; i++)
+ {
+ if (ret[i].isConst())
+ {
+ children.push_back(ret[i]);
+ }
+ else
+ {
+ all_const = false;
+ break;
+ }
+ }
+ if (all_const)
+ {
+ Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL
+ : (ret.getKind() == INTS_DIVISION
+ ? INTS_DIVISION_TOTAL
+ : INTS_MODULUS_TOTAL));
+ new_ret = nm->mkNode(new_k, children);
+ debugExtendedRewrite(ret, new_ret, "total-interpretation");
+ }
+ }
return new_ret;
}
+void ExtendedRewriter::debugExtendedRewrite(Node n,
+ Node ret,
+ const char* c) const
+{
+ if (Trace.isOn("q-ext-rewrite"))
+ {
+ if (!ret.isNull())
+ {
+ Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
+ Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
+ << " rewrites to " << ret << std::endl;
+ }
+ }
+}
+
} /* CVC4::theory::quantifiers namespace */
} /* CVC4::theory namespace */
} /* CVC4 namespace */
diff --git a/src/theory/quantifiers/extended_rewrite.h b/src/theory/quantifiers/extended_rewrite.h
index 25d710a6b..2daa42b18 100644
--- a/src/theory/quantifiers/extended_rewrite.h
+++ b/src/theory/quantifiers/extended_rewrite.h
@@ -35,10 +35,14 @@ namespace quantifiers {
*
* This class extended the standard techniques for rewriting
* with techniques, including but not limited to:
- * - ITE branch merging,
+ * - Redundant child elimination,
+ * - Sorting children of commutative operators,
+ * - Boolean constraint propagation,
+ * - Equality chain normalization,
+ * - Negation normal form,
+ * - Simple ITE pulling,
* - ITE conditional variable elimination,
- * - ITE condition subsumption, and
- * - Aggressive rewriting for string equalities.
+ * - ITE condition subsumption.
*/
class ExtendedRewriter
{
@@ -60,21 +64,128 @@ class ExtendedRewriter
* may be applied as a preprocessing step.
*/
bool d_aggr;
- /** true and false nodes */
- Node d_true;
- Node d_false;
- /** cache for extendedRewrite */
- std::unordered_map<Node, Node, NodeHashFunction> d_ext_rewrite_cache;
- /** pull ITE
- * Do simple ITE pulling, e.g.:
- * C2 --->^E false
- * implies:
- * ite( C, C1, C2 ) --->^E C ^ C1
- * where ---->^E denotes extended rewriting.
+ /** cache that the extended rewritten form of n is ret */
+ void setCache(Node n, Node ret);
+
+ //--------------------------------------generic utilities
+ /** Rewrite ITE, for example:
+ *
+ * ite( ~C, s, t ) ---> ite( C, t, s )
+ * ite( A or B, s, t ) ---> ite( ~A and ~B, t, s )
+ * ite( x = c, x, t ) --> ite( x = c, c, t )
+ * t * { x -> c } = s => ite( x = c, s, t ) ---> t
+ *
+ * The parameter "full" indicates an effort level that this rewrite will
+ * take. If full is false, then we do only perform rewrites that
+ * strictly decrease the term size of n.
+ */
+ Node extendedRewriteIte(Kind itek, Node n, bool full = true);
+ /** Pull ITE, for example:
+ *
+ * D=C2 ---> false
+ * implies
+ * D=ite( C, C1, C2 ) ---> C ^ D=C1
+ *
+ * f(t,t1) --> s and f(t,t2)---> s
+ * implies
+ * f(t,ite(C,t1,t2)) ---> s
+ *
+ * If this function returns a non-null node ret, then n ---> ret.
+ */
+ Node extendedRewritePullIte(Kind itek, Node n);
+ /** Negation Normal Form (NNF), for example:
+ *
+ * ~( A & B ) ---> ( ~ A | ~B )
+ * ~( ite( A, B, C ) ---> ite( A, ~B, ~C )
+ *
+ * If this function returns a non-null node ret, then n ---> ret.
+ */
+ Node extendedRewriteNnf(Node n);
+ /** (type-independent) Boolean constraint propagation, for example:
+ *
+ * ~A & ( B V A ) ---> ~A & B
+ * A & ( B = ( A V C ) ) ---> A & B
+ *
+ * This function takes as arguments the kinds that specify AND, OR, and NOT.
+ * It additionally takes as argument a map bcp_kinds. If this map is
+ * non-empty, then all terms that have a Kind that is *not* in this map should
+ * be treated as immutable. This is for instance to prevent propagation
+ * beneath illegal terms. As an example:
+ * (bvand A (bvor A B)) is equivalent to (bvand A (bvor 1...1 B)), but
+ * (bvand A (bvplus A B)) is not equivalent to (bvand A (bvplus 1..1 B)),
+ * hence, when using this function to do BCP for bit-vectors, we have that
+ * BITVECTOR_AND is a bcp_kind, but BITVECTOR_PLUS is not.
+ *
+ * If this function returns a non-null node ret, then n ---> ret.
+ */
+ Node extendedRewriteBcp(
+ Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node n);
+ /** (type-independent) Equality chain rewriting, for example:
+ *
+ * A = ( A = B ) ---> B
+ * ( A = D ) = ( C = B ) ---> A = ( B = ( C = D ) )
+ * A = ( A & B ) ---> ~A | B
+ *
+ * If this function returns a non-null node ret, then n ---> ret.
+ * This function takes as arguments the kinds that specify EQUAL, AND, OR,
+ * and NOT. If the flag isXor is true, the eqk is treated as XOR.
+ */
+ Node extendedRewriteEqChain(
+ Kind eqk, Kind andk, Kind ork, Kind notk, Node n, bool isXor = false);
+ /** extended rewrite aggressive
+ *
+ * All aggressive rewriting techniques (those that should be prioritized
+ * at a lower level) go in this function.
*/
- Node extendedRewritePullIte(Node n);
- /** extended rewrite aggressive */
Node extendedRewriteAggr(Node n);
+ /** Decompose right associative chain
+ *
+ * For term f( ... f( f( base, tn ), t{n-1} ) ... t1 ), returns term base, and
+ * appends t1...tn to children.
+ */
+ Node decomposeRightAssocChain(Kind k, Node n, std::vector<Node>& children);
+ /** Make right associative chain
+ *
+ * Sorts children to obtain list { tn...t1 }, and returns the term
+ * f( ... f( f( base, tn ), t{n-1} ) ... t1 ).
+ */
+ Node mkRightAssocChain(Kind k, Node base, std::vector<Node>& children);
+ /** Partial substitute
+ *
+ * Applies the substitution specified by assign to n, recursing only beneath
+ * terms whose Kind appears in rec_kinds.
+ */
+ Node partialSubstitute(Node n,
+ std::map<Node, Node>& assign,
+ std::map<Kind, bool>& rkinds);
+ /** solve equality
+ *
+ * If this function returns a non-null node n', then n' is equivalent to n
+ * and is of the form that can be used by inferSubstitution below.
+ */
+ Node solveEquality(Node n);
+ /** infer substitution
+ *
+ * If n is an equality of the form x = t, where t is either:
+ * (1) a constant, or
+ * (2) a variable y such that x < y based on an ordering,
+ * then this method adds x to vars and y to subs and return true, otherwise
+ * it returns false.
+ */
+ bool inferSubstitution(Node n,
+ std::vector<Node>& vars,
+ std::vector<Node>& subs);
+ /** extended rewrite
+ *
+ * Prints debug information, indicating the rewrite n ---> ret was found.
+ */
+ inline void debugExtendedRewrite(Node n, Node ret, const char* c) const;
+ //--------------------------------------end generic utilities
+
+ //--------------------------------------theory-specific top-level calls
+ /** extended rewrite arith */
+ Node extendedRewriteArith(Node ret, bool& pol);
+ //--------------------------------------end theory-specific top-level calls
};
} /* CVC4::theory::quantifiers namespace */
diff --git a/src/theory/quantifiers/term_util.cpp b/src/theory/quantifiers/term_util.cpp
index 3b8d03399..5965906cb 100644
--- a/src/theory/quantifiers/term_util.cpp
+++ b/src/theory/quantifiers/term_util.cpp
@@ -773,13 +773,22 @@ bool TermUtil::containsUninterpretedConstant( Node n ) {
Node TermUtil::simpleNegate( Node n ){
if( n.getKind()==OR || n.getKind()==AND ){
std::vector< Node > children;
- for( unsigned i=0; i<n.getNumChildren(); i++ ){
- children.push_back( simpleNegate( n[i] ) );
+ for (const Node& cn : n)
+ {
+ children.push_back(simpleNegate(cn));
}
return NodeManager::currentNM()->mkNode( n.getKind()==OR ? AND : OR, children );
- }else{
- return n.negate();
}
+ return n.negate();
+}
+
+Node TermUtil::mkNegate(Kind notk, Node n)
+{
+ if (n.getKind() == notk)
+ {
+ return n[0];
+ }
+ return NodeManager::currentNM()->mkNode(notk, n);
}
bool TermUtil::isAssoc( Kind k ) {
@@ -912,6 +921,11 @@ Node TermUtil::getTypeValueOffset(TypeNode tn,
return it->second;
}
+Node TermUtil::mkTypeConst(TypeNode tn, bool pol)
+{
+ return pol ? mkTypeValue(tn, 0) : mkTypeMaxValue(tn);
+}
+
bool TermUtil::isAntisymmetric(Kind k, Kind& dk)
{
if (k == GT)
diff --git a/src/theory/quantifiers/term_util.h b/src/theory/quantifiers/term_util.h
index 8ec2fc8e2..97f4edcd5 100644
--- a/src/theory/quantifiers/term_util.h
+++ b/src/theory/quantifiers/term_util.h
@@ -289,6 +289,11 @@ public:
static int getTermDepth( Node n );
/** simple negate */
static Node simpleNegate( Node n );
+ /**
+ * Make negated term, returns the negation of n wrt Kind notk, eliminating
+ * double negation if applicable, e.g. mkNegate( ~, ~x ) ---> x.
+ */
+ static Node mkNegate(Kind notk, Node n);
/** is assoc */
static bool isAssoc( Kind k );
/** is k commutative? */
@@ -364,6 +369,13 @@ public:
static Node mkTypeValueOffset(TypeNode tn, Node val, int offset, int& status);
/** make max value, static version of get max value */
static Node mkTypeMaxValue(TypeNode tn);
+ /**
+ * Make const, returns pol ? mkTypeValue(tn,0) : mkTypeMaxValue(tn).
+ * In other words, this returns either the minimum element of tn if pol is
+ * true, and the maximum element in pol is false. The type tn should have
+ * minimum and maximum elements, for example tn is Bool or BitVector.
+ */
+ static Node mkTypeConst(TypeNode tn, bool pol);
// for higher-order
private:
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback