diff options
Diffstat (limited to 'src/preprocessing/passes/learned_rewrite.cpp')
-rw-r--r-- | src/preprocessing/passes/learned_rewrite.cpp | 99 |
1 files changed, 29 insertions, 70 deletions
diff --git a/src/preprocessing/passes/learned_rewrite.cpp b/src/preprocessing/passes/learned_rewrite.cpp index 4e9aa7bb2..fd3cf832b 100644 --- a/src/preprocessing/passes/learned_rewrite.cpp +++ b/src/preprocessing/passes/learned_rewrite.cpp @@ -60,6 +60,7 @@ LearnedRewrite::LearnedRewrite(PreprocessingPassContext* preprocContext) PreprocessingPassResult LearnedRewrite::applyInternal( AssertionPipeline* assertionsToPreprocess) { + NodeManager* nm = NodeManager::currentNM(); arith::BoundInference binfer; std::vector<Node> learnedLits = d_preprocContext->getLearnedLiterals(); std::unordered_set<Node> llrw; @@ -72,14 +73,29 @@ PreprocessingPassResult LearnedRewrite::applyInternal( else { Trace("learned-rewrite-ll") << "Learned literals:" << std::endl; + std::map<Node, Node> originLit; for (const Node& l : learnedLits) { // maybe use the literal for bound inference? - Kind k = l.getKind(); - Assert(k != LT && k != GT && k != LEQ); - if (k == EQUAL || k == GEQ) + bool pol = l.getKind()!=NOT; + TNode atom = pol ? l : l[0]; + Kind ak = atom.getKind(); + Assert(ak != LT && ak != GT && ak != LEQ); + if ((ak == EQUAL && pol) || ak == GEQ) { - binfer.add(l); + // provide as < if negated >= + Node atomu; + if (!pol) + { + atomu = nm->mkNode(LT, atom[0], atom[1]); + originLit[atomu] = l; + } + else + { + atomu = l; + originLit[l] = l; + } + binfer.add(atomu); } Trace("learned-rewrite-ll") << "- " << l << std::endl; } @@ -93,7 +109,8 @@ PreprocessingPassResult LearnedRewrite::applyInternal( Node origin = i == 0 ? b.second.lower_origin : b.second.upper_origin; if (!origin.isNull()) { - llrw.insert(origin); + Assert (originLit.find(origin)!=originLit.end()); + llrw.insert(originLit[origin]); } } } @@ -139,7 +156,6 @@ PreprocessingPassResult LearnedRewrite::applyInternal( // unchanged. if (!llrw.empty()) { - NodeManager* nm = NodeManager::currentNM(); std::vector<Node> llrvec(llrw.begin(), llrw.end()); Node llc = nm->mkAnd(llrvec); Trace("learned-rewrite-assert") @@ -165,7 +181,13 @@ Node LearnedRewrite::rewriteLearnedRec(Node n, cur = visit.back(); visit.pop_back(); it = visited.find(cur); - + if (lems.find(cur) != lems.end()) + { + // n is a learned literal: replace by true, not considered a rewrite + // for statistics + visited[cur] = nm->mkConst(true); + continue; + } if (it == visited.end()) { // mark pre-visited with null; will post-visit to construct final node @@ -210,12 +232,6 @@ Node LearnedRewrite::rewriteLearned(Node n, std::unordered_set<Node>& lems) { NodeManager* nm = NodeManager::currentNM(); - if (lems.find(n) != lems.end()) - { - // n is a learned literal: replace by true, not considered a rewrite - // for statistics - return nm->mkConst(true); - } Trace("learned-rewrite-rr-debug") << "Rewrite " << n << std::endl; Node nr = Rewriter::rewrite(n); Kind k = nr.getKind(); @@ -384,63 +400,6 @@ Node LearnedRewrite::rewriteLearned(Node n, return nr; } } - // inferences based on combining div terms - Node currDen; - Node currNum; - std::vector<Node> sum; - size_t divCount = 0; - bool divTotal = true; - for (const std::pair<const Node, Node>& m : msum) - { - if (m.first.isNull()) - { - sum.push_back(m.second); - continue; - } - Kind mk = m.first.getKind(); - if (mk == INTS_DIVISION || mk == INTS_DIVISION_TOTAL) - { - Node factor = ArithMSum::mkCoeffTerm(m.second, m.first[0]); - divTotal = divTotal && mk == INTS_DIVISION_TOTAL; - divCount++; - if (currDen.isNull()) - { - currNum = factor; - currDen = m.first[1]; - } - else - { - factor = nm->mkNode(MULT, factor, currDen); - currNum = nm->mkNode(MULT, currNum, m.first[1]); - currNum = nm->mkNode(PLUS, currNum, factor); - currDen = nm->mkNode(MULT, currDen, m.first[1]); - } - } - else - { - Node factor = ArithMSum::mkCoeffTerm(m.second, m.first); - sum.push_back(factor); - } - } - if (divCount >= 2) - { - SkolemManager* sm = nm->getSkolemManager(); - Node r = sm->mkDummySkolem("r", nm->integerType()); - Node d = nm->mkNode( - divTotal ? INTS_DIVISION_TOTAL : INTS_DIVISION, currNum, currDen); - sum.push_back(d); - sum.push_back(r); - Node bound = - nm->mkNode(AND, - nm->mkNode(LEQ, nm->mkConst(-Rational(divCount - 1)), r), - nm->mkNode(LEQ, r, nm->mkConst(Rational(divCount - 1)))); - Node sumn = nm->mkNode(PLUS, sum); - Node lit = nm->mkNode(k, sumn, nm->mkConst(Rational(0))); - Node lemma = nm->mkNode(IMPLIES, nr, nm->mkNode(AND, lit, bound)); - Trace("learned-rewrite-div") - << "Div collect lemma: " << lemma << std::endl; - lems.insert(lemma); - } } } return nr; |