summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2021-07-06 15:00:38 -0500
committerGitHub <noreply@github.com>2021-07-06 20:00:38 +0000
commitb023494b8914be03d8f8ca26c1b1db332944f3fe (patch)
treea3381aab0b6cad11f2269df7cb9619b52de9d607
parentbb94bfd729be285c80aae5e97f41f813848f40cb (diff)
Add implementation of learned rewrite pass (#6843)
-rw-r--r--src/preprocessing/passes/learned_rewrite.cpp287
1 files changed, 284 insertions, 3 deletions
diff --git a/src/preprocessing/passes/learned_rewrite.cpp b/src/preprocessing/passes/learned_rewrite.cpp
index 785889666..4e9aa7bb2 100644
--- a/src/preprocessing/passes/learned_rewrite.cpp
+++ b/src/preprocessing/passes/learned_rewrite.cpp
@@ -76,7 +76,7 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
{
// maybe use the literal for bound inference?
Kind k = l.getKind();
- Assert (k != LT && k != GT && k != LEQ);
+ Assert(k != LT && k != GT && k != LEQ);
if (k == EQUAL || k == GEQ)
{
binfer.add(l);
@@ -155,14 +155,295 @@ Node LearnedRewrite::rewriteLearnedRec(Node n,
std::unordered_set<Node>& lems,
std::unordered_map<TNode, Node>& visited)
{
- return n;
+ NodeManager* nm = NodeManager::currentNM();
+ std::unordered_map<TNode, Node>::iterator it;
+ std::vector<TNode> visit;
+ TNode cur;
+ visit.push_back(n);
+ do
+ {
+ cur = visit.back();
+ visit.pop_back();
+ it = visited.find(cur);
+
+ if (it == visited.end())
+ {
+ // mark pre-visited with null; will post-visit to construct final node
+ // in the block below.
+ visited[cur] = Node::null();
+ visit.push_back(cur);
+ visit.insert(visit.end(), cur.begin(), cur.end());
+ }
+ else if (it->second.isNull())
+ {
+ Node ret = cur;
+ bool needsRcons = false;
+ std::vector<Node> children;
+ if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
+ {
+ children.push_back(cur.getOperator());
+ }
+ for (const Node& cn : cur)
+ {
+ it = visited.find(cn);
+ Assert(it != visited.end());
+ Assert(!it->second.isNull());
+ needsRcons = needsRcons || cn != it->second;
+ children.push_back(it->second);
+ }
+ if (needsRcons)
+ {
+ ret = nm->mkNode(cur.getKind(), children);
+ }
+ // rewrite here
+ ret = rewriteLearned(ret, binfer, lems);
+ visited[cur] = ret;
+ }
+ } while (!visit.empty());
+ Assert(visited.find(n) != visited.end());
+ Assert(!visited.find(n)->second.isNull());
+ return visited[n];
}
Node LearnedRewrite::rewriteLearned(Node n,
arith::BoundInference& binfer,
std::unordered_set<Node>& lems)
{
- return n;
+ NodeManager* nm = NodeManager::currentNM();
+ if (lems.find(n) != lems.end())
+ {
+ // n is a learned literal: replace by true, not considered a rewrite
+ // for statistics
+ return nm->mkConst(true);
+ }
+ Trace("learned-rewrite-rr-debug") << "Rewrite " << n << std::endl;
+ Node nr = Rewriter::rewrite(n);
+ Kind k = nr.getKind();
+ if (k == INTS_DIVISION || k == INTS_MODULUS || k == DIVISION)
+ {
+ // simpler if we know the divisor is non-zero
+ Node num = n[0];
+ Node den = n[1];
+ bool isNonZeroDen = false;
+ if (den.isConst())
+ {
+ isNonZeroDen = (den.getConst<Rational>().sgn() != 0);
+ }
+ else
+ {
+ arith::Bounds db = binfer.get(den);
+ Trace("learned-rewrite-rr-debug")
+ << "Bounds for " << den << " : " << db.lower_value << " "
+ << db.upper_value << std::endl;
+ if (!db.lower_value.isNull()
+ && db.lower_value.getConst<Rational>().sgn() == 1)
+ {
+ isNonZeroDen = true;
+ }
+ else if (!db.upper_value.isNull()
+ && db.upper_value.getConst<Rational>().sgn() == -1)
+ {
+ isNonZeroDen = true;
+ }
+ }
+ if (isNonZeroDen)
+ {
+ Trace("learned-rewrite-rr-debug")
+ << "...non-zero denominator" << std::endl;
+ Kind nk = k;
+ switch (k)
+ {
+ case INTS_DIVISION: nk = INTS_DIVISION_TOTAL; break;
+ case INTS_MODULUS: nk = INTS_MODULUS_TOTAL; break;
+ case DIVISION: nk = DIVISION_TOTAL; break;
+ default: Assert(false); break;
+ }
+ std::vector<Node> children;
+ children.insert(children.end(), n.begin(), n.end());
+ Node ret = nm->mkNode(nk, children);
+ nr = returnRewriteLearned(nr, ret, LearnedRewriteId::NON_ZERO_DEN);
+ nr = Rewriter::rewrite(nr);
+ k = nr.getKind();
+ }
+ }
+ // constant int mod elimination by bound inference
+ if (k == INTS_MODULUS_TOTAL)
+ {
+ Node num = n[0];
+ Node den = n[1];
+ arith::Bounds db = binfer.get(den);
+ if ((!db.lower_value.isNull()
+ && db.lower_value.getConst<Rational>().sgn() == 1)
+ || (!db.upper_value.isNull()
+ && db.upper_value.getConst<Rational>().sgn() == -1))
+ {
+ Rational bden = db.lower_value.isNull()
+ ? db.lower_value.getConst<Rational>()
+ : db.upper_value.getConst<Rational>().abs();
+ // if 0 <= UB(num) < LB(den) or 0 <= UB(num) < -UB(den)
+ arith::Bounds nb = binfer.get(num);
+ if (!nb.upper_value.isNull())
+ {
+ Rational bnum = nb.upper_value.getConst<Rational>();
+ if (bnum.sgn() != -1 && bnum < bden)
+ {
+ nr = returnRewriteLearned(nr, nr[0], LearnedRewriteId::INT_MOD_RANGE);
+ }
+ }
+ // could also do num + k*den checks
+ }
+ }
+ else if (k == GEQ || (k == EQUAL && nr[0].getType().isReal()))
+ {
+ std::map<Node, Node> msum;
+ if (ArithMSum::getMonomialSumLit(nr, msum))
+ {
+ Rational lb(0);
+ Rational ub(0);
+ bool lbSuccess = true;
+ bool ubSuccess = true;
+ Rational one(1);
+ if (Trace.isOn("learned-rewrite-arith-lit"))
+ {
+ Trace("learned-rewrite-arith-lit")
+ << "Arithmetic lit: " << nr << std::endl;
+ for (const std::pair<const Node, Node>& m : msum)
+ {
+ Trace("learned-rewrite-arith-lit")
+ << " " << m.first << ", " << m.second << std::endl;
+ }
+ }
+ for (const std::pair<const Node, Node>& m : msum)
+ {
+ bool isOneCoeff = m.second.isNull();
+ Assert(isOneCoeff || m.second.isConst());
+ if (m.first.isNull())
+ {
+ lb = lb + (isOneCoeff ? one : m.second.getConst<Rational>());
+ ub = ub + (isOneCoeff ? one : m.second.getConst<Rational>());
+ }
+ else
+ {
+ arith::Bounds b = binfer.get(m.first);
+ bool isNeg = !isOneCoeff && m.second.getConst<Rational>().sgn() == -1;
+ // flip lower/upper if negative coefficient
+ TNode l = isNeg ? b.upper_value : b.lower_value;
+ TNode u = isNeg ? b.lower_value : b.upper_value;
+ if (lbSuccess && !l.isNull())
+ {
+ Rational lc = l.getConst<Rational>();
+ lb = lb
+ + (isOneCoeff ? lc
+ : Rational(lc * m.second.getConst<Rational>()));
+ }
+ else
+ {
+ lbSuccess = false;
+ }
+ if (ubSuccess && !u.isNull())
+ {
+ Rational uc = u.getConst<Rational>();
+ ub = ub
+ + (isOneCoeff ? uc
+ : Rational(uc * m.second.getConst<Rational>()));
+ }
+ else
+ {
+ ubSuccess = false;
+ }
+ if (!lbSuccess && !ubSuccess)
+ {
+ break;
+ }
+ }
+ }
+ if (lbSuccess)
+ {
+ if (lb.sgn() == 1)
+ {
+ // if positive lower bound, then GEQ is true, EQUAL is false
+ Node ret = nm->mkConst(k == GEQ);
+ nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_POS_LB);
+ return nr;
+ }
+ else if (lb.sgn() == 0 && k == GEQ)
+ {
+ // zero lower bound, GEQ is true
+ Node ret = nm->mkConst(true);
+ nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_ZERO_LB);
+ return nr;
+ }
+ }
+ else if (ubSuccess)
+ {
+ if (ub.sgn() == -1)
+ {
+ // if negative upper bound, then GEQ and EQUAL are false
+ Node ret = nm->mkConst(false);
+ nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_NEG_UB);
+ return nr;
+ }
+ }
+ // inferences based on combining div terms
+ Node currDen;
+ Node currNum;
+ std::vector<Node> sum;
+ size_t divCount = 0;
+ bool divTotal = true;
+ for (const std::pair<const Node, Node>& m : msum)
+ {
+ if (m.first.isNull())
+ {
+ sum.push_back(m.second);
+ continue;
+ }
+ Kind mk = m.first.getKind();
+ if (mk == INTS_DIVISION || mk == INTS_DIVISION_TOTAL)
+ {
+ Node factor = ArithMSum::mkCoeffTerm(m.second, m.first[0]);
+ divTotal = divTotal && mk == INTS_DIVISION_TOTAL;
+ divCount++;
+ if (currDen.isNull())
+ {
+ currNum = factor;
+ currDen = m.first[1];
+ }
+ else
+ {
+ factor = nm->mkNode(MULT, factor, currDen);
+ currNum = nm->mkNode(MULT, currNum, m.first[1]);
+ currNum = nm->mkNode(PLUS, currNum, factor);
+ currDen = nm->mkNode(MULT, currDen, m.first[1]);
+ }
+ }
+ else
+ {
+ Node factor = ArithMSum::mkCoeffTerm(m.second, m.first);
+ sum.push_back(factor);
+ }
+ }
+ if (divCount >= 2)
+ {
+ SkolemManager* sm = nm->getSkolemManager();
+ Node r = sm->mkDummySkolem("r", nm->integerType());
+ Node d = nm->mkNode(
+ divTotal ? INTS_DIVISION_TOTAL : INTS_DIVISION, currNum, currDen);
+ sum.push_back(d);
+ sum.push_back(r);
+ Node bound =
+ nm->mkNode(AND,
+ nm->mkNode(LEQ, nm->mkConst(-Rational(divCount - 1)), r),
+ nm->mkNode(LEQ, r, nm->mkConst(Rational(divCount - 1))));
+ Node sumn = nm->mkNode(PLUS, sum);
+ Node lit = nm->mkNode(k, sumn, nm->mkConst(Rational(0)));
+ Node lemma = nm->mkNode(IMPLIES, nr, nm->mkNode(AND, lit, bound));
+ Trace("learned-rewrite-div")
+ << "Div collect lemma: " << lemma << std::endl;
+ lems.insert(lemma);
+ }
+ }
+ }
+ return nr;
}
Node LearnedRewrite::returnRewriteLearned(Node n, Node nr, LearnedRewriteId id)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback