summaryrefslogtreecommitdiff
path: root/src/preprocessing/passes/learned_rewrite.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/preprocessing/passes/learned_rewrite.cpp')
-rw-r--r--src/preprocessing/passes/learned_rewrite.cpp99
1 files changed, 29 insertions, 70 deletions
diff --git a/src/preprocessing/passes/learned_rewrite.cpp b/src/preprocessing/passes/learned_rewrite.cpp
index 4e9aa7bb2..fd3cf832b 100644
--- a/src/preprocessing/passes/learned_rewrite.cpp
+++ b/src/preprocessing/passes/learned_rewrite.cpp
@@ -60,6 +60,7 @@ LearnedRewrite::LearnedRewrite(PreprocessingPassContext* preprocContext)
PreprocessingPassResult LearnedRewrite::applyInternal(
AssertionPipeline* assertionsToPreprocess)
{
+ NodeManager* nm = NodeManager::currentNM();
arith::BoundInference binfer;
std::vector<Node> learnedLits = d_preprocContext->getLearnedLiterals();
std::unordered_set<Node> llrw;
@@ -72,14 +73,29 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
else
{
Trace("learned-rewrite-ll") << "Learned literals:" << std::endl;
+ std::map<Node, Node> originLit;
for (const Node& l : learnedLits)
{
// maybe use the literal for bound inference?
- Kind k = l.getKind();
- Assert(k != LT && k != GT && k != LEQ);
- if (k == EQUAL || k == GEQ)
+ bool pol = l.getKind()!=NOT;
+ TNode atom = pol ? l : l[0];
+ Kind ak = atom.getKind();
+ Assert(ak != LT && ak != GT && ak != LEQ);
+ if ((ak == EQUAL && pol) || ak == GEQ)
{
- binfer.add(l);
+ // provide as < if negated >=
+ Node atomu;
+ if (!pol)
+ {
+ atomu = nm->mkNode(LT, atom[0], atom[1]);
+ originLit[atomu] = l;
+ }
+ else
+ {
+ atomu = l;
+ originLit[l] = l;
+ }
+ binfer.add(atomu);
}
Trace("learned-rewrite-ll") << "- " << l << std::endl;
}
@@ -93,7 +109,8 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
Node origin = i == 0 ? b.second.lower_origin : b.second.upper_origin;
if (!origin.isNull())
{
- llrw.insert(origin);
+ Assert (originLit.find(origin)!=originLit.end());
+ llrw.insert(originLit[origin]);
}
}
}
@@ -139,7 +156,6 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
// unchanged.
if (!llrw.empty())
{
- NodeManager* nm = NodeManager::currentNM();
std::vector<Node> llrvec(llrw.begin(), llrw.end());
Node llc = nm->mkAnd(llrvec);
Trace("learned-rewrite-assert")
@@ -165,7 +181,13 @@ Node LearnedRewrite::rewriteLearnedRec(Node n,
cur = visit.back();
visit.pop_back();
it = visited.find(cur);
-
+ if (lems.find(cur) != lems.end())
+ {
+ // n is a learned literal: replace by true, not considered a rewrite
+ // for statistics
+ visited[cur] = nm->mkConst(true);
+ continue;
+ }
if (it == visited.end())
{
// mark pre-visited with null; will post-visit to construct final node
@@ -210,12 +232,6 @@ Node LearnedRewrite::rewriteLearned(Node n,
std::unordered_set<Node>& lems)
{
NodeManager* nm = NodeManager::currentNM();
- if (lems.find(n) != lems.end())
- {
- // n is a learned literal: replace by true, not considered a rewrite
- // for statistics
- return nm->mkConst(true);
- }
Trace("learned-rewrite-rr-debug") << "Rewrite " << n << std::endl;
Node nr = Rewriter::rewrite(n);
Kind k = nr.getKind();
@@ -384,63 +400,6 @@ Node LearnedRewrite::rewriteLearned(Node n,
return nr;
}
}
- // inferences based on combining div terms
- Node currDen;
- Node currNum;
- std::vector<Node> sum;
- size_t divCount = 0;
- bool divTotal = true;
- for (const std::pair<const Node, Node>& m : msum)
- {
- if (m.first.isNull())
- {
- sum.push_back(m.second);
- continue;
- }
- Kind mk = m.first.getKind();
- if (mk == INTS_DIVISION || mk == INTS_DIVISION_TOTAL)
- {
- Node factor = ArithMSum::mkCoeffTerm(m.second, m.first[0]);
- divTotal = divTotal && mk == INTS_DIVISION_TOTAL;
- divCount++;
- if (currDen.isNull())
- {
- currNum = factor;
- currDen = m.first[1];
- }
- else
- {
- factor = nm->mkNode(MULT, factor, currDen);
- currNum = nm->mkNode(MULT, currNum, m.first[1]);
- currNum = nm->mkNode(PLUS, currNum, factor);
- currDen = nm->mkNode(MULT, currDen, m.first[1]);
- }
- }
- else
- {
- Node factor = ArithMSum::mkCoeffTerm(m.second, m.first);
- sum.push_back(factor);
- }
- }
- if (divCount >= 2)
- {
- SkolemManager* sm = nm->getSkolemManager();
- Node r = sm->mkDummySkolem("r", nm->integerType());
- Node d = nm->mkNode(
- divTotal ? INTS_DIVISION_TOTAL : INTS_DIVISION, currNum, currDen);
- sum.push_back(d);
- sum.push_back(r);
- Node bound =
- nm->mkNode(AND,
- nm->mkNode(LEQ, nm->mkConst(-Rational(divCount - 1)), r),
- nm->mkNode(LEQ, r, nm->mkConst(Rational(divCount - 1))));
- Node sumn = nm->mkNode(PLUS, sum);
- Node lit = nm->mkNode(k, sumn, nm->mkConst(Rational(0)));
- Node lemma = nm->mkNode(IMPLIES, nr, nm->mkNode(AND, lit, bound));
- Trace("learned-rewrite-div")
- << "Div collect lemma: " << lemma << std::endl;
- lems.insert(lemma);
- }
}
}
return nr;
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback