diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2021-07-06 15:00:38 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-06 20:00:38 +0000 |
commit | b023494b8914be03d8f8ca26c1b1db332944f3fe (patch) | |
tree | a3381aab0b6cad11f2269df7cb9619b52de9d607 | |
parent | bb94bfd729be285c80aae5e97f41f813848f40cb (diff) |
Add implementation of learned rewrite pass (#6843)
-rw-r--r-- | src/preprocessing/passes/learned_rewrite.cpp | 287 |
1 files changed, 284 insertions, 3 deletions
diff --git a/src/preprocessing/passes/learned_rewrite.cpp b/src/preprocessing/passes/learned_rewrite.cpp index 785889666..4e9aa7bb2 100644 --- a/src/preprocessing/passes/learned_rewrite.cpp +++ b/src/preprocessing/passes/learned_rewrite.cpp @@ -76,7 +76,7 @@ PreprocessingPassResult LearnedRewrite::applyInternal( { // maybe use the literal for bound inference? Kind k = l.getKind(); - Assert (k != LT && k != GT && k != LEQ); + Assert(k != LT && k != GT && k != LEQ); if (k == EQUAL || k == GEQ) { binfer.add(l); @@ -155,14 +155,295 @@ Node LearnedRewrite::rewriteLearnedRec(Node n, std::unordered_set<Node>& lems, std::unordered_map<TNode, Node>& visited) { - return n; + NodeManager* nm = NodeManager::currentNM(); + std::unordered_map<TNode, Node>::iterator it; + std::vector<TNode> visit; + TNode cur; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + + if (it == visited.end()) + { + // mark pre-visited with null; will post-visit to construct final node + // in the block below. + visited[cur] = Node::null(); + visit.push_back(cur); + visit.insert(visit.end(), cur.begin(), cur.end()); + } + else if (it->second.isNull()) + { + Node ret = cur; + bool needsRcons = false; + std::vector<Node> children; + if (cur.getMetaKind() == kind::metakind::PARAMETERIZED) + { + children.push_back(cur.getOperator()); + } + for (const Node& cn : cur) + { + it = visited.find(cn); + Assert(it != visited.end()); + Assert(!it->second.isNull()); + needsRcons = needsRcons || cn != it->second; + children.push_back(it->second); + } + if (needsRcons) + { + ret = nm->mkNode(cur.getKind(), children); + } + // rewrite here + ret = rewriteLearned(ret, binfer, lems); + visited[cur] = ret; + } + } while (!visit.empty()); + Assert(visited.find(n) != visited.end()); + Assert(!visited.find(n)->second.isNull()); + return visited[n]; } Node LearnedRewrite::rewriteLearned(Node n, arith::BoundInference& binfer, std::unordered_set<Node>& lems) { - return n; + NodeManager* nm = NodeManager::currentNM(); + if (lems.find(n) != lems.end()) + { + // n is a learned literal: replace by true, not considered a rewrite + // for statistics + return nm->mkConst(true); + } + Trace("learned-rewrite-rr-debug") << "Rewrite " << n << std::endl; + Node nr = Rewriter::rewrite(n); + Kind k = nr.getKind(); + if (k == INTS_DIVISION || k == INTS_MODULUS || k == DIVISION) + { + // simpler if we know the divisor is non-zero + Node num = n[0]; + Node den = n[1]; + bool isNonZeroDen = false; + if (den.isConst()) + { + isNonZeroDen = (den.getConst<Rational>().sgn() != 0); + } + else + { + arith::Bounds db = binfer.get(den); + Trace("learned-rewrite-rr-debug") + << "Bounds for " << den << " : " << db.lower_value << " " + << db.upper_value << std::endl; + if (!db.lower_value.isNull() + && db.lower_value.getConst<Rational>().sgn() == 1) + { + isNonZeroDen = true; + } + else if (!db.upper_value.isNull() + && db.upper_value.getConst<Rational>().sgn() == -1) + { + isNonZeroDen = true; + } + } + if (isNonZeroDen) + { + Trace("learned-rewrite-rr-debug") + << "...non-zero denominator" << std::endl; + Kind nk = k; + switch (k) + { + case INTS_DIVISION: nk = INTS_DIVISION_TOTAL; break; + case INTS_MODULUS: nk = INTS_MODULUS_TOTAL; break; + case DIVISION: nk = DIVISION_TOTAL; break; + default: Assert(false); break; + } + std::vector<Node> children; + children.insert(children.end(), n.begin(), n.end()); + Node ret = nm->mkNode(nk, children); + nr = returnRewriteLearned(nr, ret, LearnedRewriteId::NON_ZERO_DEN); + nr = Rewriter::rewrite(nr); + k = nr.getKind(); + } + } + // constant int mod elimination by bound inference + if (k == INTS_MODULUS_TOTAL) + { + Node num = n[0]; + Node den = n[1]; + arith::Bounds db = binfer.get(den); + if ((!db.lower_value.isNull() + && db.lower_value.getConst<Rational>().sgn() == 1) + || (!db.upper_value.isNull() + && db.upper_value.getConst<Rational>().sgn() == -1)) + { + Rational bden = db.lower_value.isNull() + ? db.lower_value.getConst<Rational>() + : db.upper_value.getConst<Rational>().abs(); + // if 0 <= UB(num) < LB(den) or 0 <= UB(num) < -UB(den) + arith::Bounds nb = binfer.get(num); + if (!nb.upper_value.isNull()) + { + Rational bnum = nb.upper_value.getConst<Rational>(); + if (bnum.sgn() != -1 && bnum < bden) + { + nr = returnRewriteLearned(nr, nr[0], LearnedRewriteId::INT_MOD_RANGE); + } + } + // could also do num + k*den checks + } + } + else if (k == GEQ || (k == EQUAL && nr[0].getType().isReal())) + { + std::map<Node, Node> msum; + if (ArithMSum::getMonomialSumLit(nr, msum)) + { + Rational lb(0); + Rational ub(0); + bool lbSuccess = true; + bool ubSuccess = true; + Rational one(1); + if (Trace.isOn("learned-rewrite-arith-lit")) + { + Trace("learned-rewrite-arith-lit") + << "Arithmetic lit: " << nr << std::endl; + for (const std::pair<const Node, Node>& m : msum) + { + Trace("learned-rewrite-arith-lit") + << " " << m.first << ", " << m.second << std::endl; + } + } + for (const std::pair<const Node, Node>& m : msum) + { + bool isOneCoeff = m.second.isNull(); + Assert(isOneCoeff || m.second.isConst()); + if (m.first.isNull()) + { + lb = lb + (isOneCoeff ? one : m.second.getConst<Rational>()); + ub = ub + (isOneCoeff ? one : m.second.getConst<Rational>()); + } + else + { + arith::Bounds b = binfer.get(m.first); + bool isNeg = !isOneCoeff && m.second.getConst<Rational>().sgn() == -1; + // flip lower/upper if negative coefficient + TNode l = isNeg ? b.upper_value : b.lower_value; + TNode u = isNeg ? b.lower_value : b.upper_value; + if (lbSuccess && !l.isNull()) + { + Rational lc = l.getConst<Rational>(); + lb = lb + + (isOneCoeff ? lc + : Rational(lc * m.second.getConst<Rational>())); + } + else + { + lbSuccess = false; + } + if (ubSuccess && !u.isNull()) + { + Rational uc = u.getConst<Rational>(); + ub = ub + + (isOneCoeff ? uc + : Rational(uc * m.second.getConst<Rational>())); + } + else + { + ubSuccess = false; + } + if (!lbSuccess && !ubSuccess) + { + break; + } + } + } + if (lbSuccess) + { + if (lb.sgn() == 1) + { + // if positive lower bound, then GEQ is true, EQUAL is false + Node ret = nm->mkConst(k == GEQ); + nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_POS_LB); + return nr; + } + else if (lb.sgn() == 0 && k == GEQ) + { + // zero lower bound, GEQ is true + Node ret = nm->mkConst(true); + nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_ZERO_LB); + return nr; + } + } + else if (ubSuccess) + { + if (ub.sgn() == -1) + { + // if negative upper bound, then GEQ and EQUAL are false + Node ret = nm->mkConst(false); + nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_NEG_UB); + return nr; + } + } + // inferences based on combining div terms + Node currDen; + Node currNum; + std::vector<Node> sum; + size_t divCount = 0; + bool divTotal = true; + for (const std::pair<const Node, Node>& m : msum) + { + if (m.first.isNull()) + { + sum.push_back(m.second); + continue; + } + Kind mk = m.first.getKind(); + if (mk == INTS_DIVISION || mk == INTS_DIVISION_TOTAL) + { + Node factor = ArithMSum::mkCoeffTerm(m.second, m.first[0]); + divTotal = divTotal && mk == INTS_DIVISION_TOTAL; + divCount++; + if (currDen.isNull()) + { + currNum = factor; + currDen = m.first[1]; + } + else + { + factor = nm->mkNode(MULT, factor, currDen); + currNum = nm->mkNode(MULT, currNum, m.first[1]); + currNum = nm->mkNode(PLUS, currNum, factor); + currDen = nm->mkNode(MULT, currDen, m.first[1]); + } + } + else + { + Node factor = ArithMSum::mkCoeffTerm(m.second, m.first); + sum.push_back(factor); + } + } + if (divCount >= 2) + { + SkolemManager* sm = nm->getSkolemManager(); + Node r = sm->mkDummySkolem("r", nm->integerType()); + Node d = nm->mkNode( + divTotal ? INTS_DIVISION_TOTAL : INTS_DIVISION, currNum, currDen); + sum.push_back(d); + sum.push_back(r); + Node bound = + nm->mkNode(AND, + nm->mkNode(LEQ, nm->mkConst(-Rational(divCount - 1)), r), + nm->mkNode(LEQ, r, nm->mkConst(Rational(divCount - 1)))); + Node sumn = nm->mkNode(PLUS, sum); + Node lit = nm->mkNode(k, sumn, nm->mkConst(Rational(0))); + Node lemma = nm->mkNode(IMPLIES, nr, nm->mkNode(AND, lit, bound)); + Trace("learned-rewrite-div") + << "Div collect lemma: " << lemma << std::endl; + lems.insert(lemma); + } + } + } + return nr; } Node LearnedRewrite::returnRewriteLearned(Node n, Node nr, LearnedRewriteId id) |