summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2021-07-06 17:44:09 -0500
committerGitHub <noreply@github.com>2021-07-06 17:44:09 -0500
commit4ac6c5179265ef9895bc9e939be0e47b3754137e (patch)
treebf2cd175a534e5ca9a383a63cd6a086f1d5f45d0
parentb023494b8914be03d8f8ca26c1b1db332944f3fe (diff)
Integrate learned rewrite preprocessing pass (#6840)
This adds the learned rewrite preprocessing pass, which rewrites the input formula based on (typically theory specific) reasoning about learned literals. The main motivation is for preprocessing ints division/modulus based on bounds.
-rw-r--r--src/CMakeLists.txt2
-rw-r--r--src/options/smt_options.toml8
-rw-r--r--src/preprocessing/passes/learned_rewrite.cpp99
-rw-r--r--src/preprocessing/preprocessing_pass_registry.cpp2
-rw-r--r--src/smt/process_assertions.cpp5
-rw-r--r--src/smt/set_defaults.cpp54
6 files changed, 76 insertions, 94 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 74db7c941..3246df654 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -88,6 +88,8 @@ libcvc5_add_sources(
preprocessing/passes/ite_removal.h
preprocessing/passes/ite_simp.cpp
preprocessing/passes/ite_simp.h
+ preprocessing/passes/learned_rewrite.cpp
+ preprocessing/passes/learned_rewrite.h
preprocessing/passes/miplib_trick.cpp
preprocessing/passes/miplib_trick.h
preprocessing/passes/nl_ext_purify.cpp
diff --git a/src/options/smt_options.toml b/src/options/smt_options.toml
index 4d08aa672..9b5a93486 100644
--- a/src/options/smt_options.toml
+++ b/src/options/smt_options.toml
@@ -48,6 +48,14 @@ name = "SMT Layer"
help = "use static learning (on by default)"
[[option]]
+ name = "learnedRewrite"
+ category = "regular"
+ long = "learned-rewrite"
+ type = "bool"
+ default = "false"
+ help = "rewrite the input based on learned literals"
+
+[[option]]
name = "expandDefinitions"
long = "expand-definitions"
category = "regular"
diff --git a/src/preprocessing/passes/learned_rewrite.cpp b/src/preprocessing/passes/learned_rewrite.cpp
index 4e9aa7bb2..fd3cf832b 100644
--- a/src/preprocessing/passes/learned_rewrite.cpp
+++ b/src/preprocessing/passes/learned_rewrite.cpp
@@ -60,6 +60,7 @@ LearnedRewrite::LearnedRewrite(PreprocessingPassContext* preprocContext)
PreprocessingPassResult LearnedRewrite::applyInternal(
AssertionPipeline* assertionsToPreprocess)
{
+ NodeManager* nm = NodeManager::currentNM();
arith::BoundInference binfer;
std::vector<Node> learnedLits = d_preprocContext->getLearnedLiterals();
std::unordered_set<Node> llrw;
@@ -72,14 +73,29 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
else
{
Trace("learned-rewrite-ll") << "Learned literals:" << std::endl;
+ std::map<Node, Node> originLit;
for (const Node& l : learnedLits)
{
// maybe use the literal for bound inference?
- Kind k = l.getKind();
- Assert(k != LT && k != GT && k != LEQ);
- if (k == EQUAL || k == GEQ)
+ bool pol = l.getKind()!=NOT;
+ TNode atom = pol ? l : l[0];
+ Kind ak = atom.getKind();
+ Assert(ak != LT && ak != GT && ak != LEQ);
+ if ((ak == EQUAL && pol) || ak == GEQ)
{
- binfer.add(l);
+ // provide as < if negated >=
+ Node atomu;
+ if (!pol)
+ {
+ atomu = nm->mkNode(LT, atom[0], atom[1]);
+ originLit[atomu] = l;
+ }
+ else
+ {
+ atomu = l;
+ originLit[l] = l;
+ }
+ binfer.add(atomu);
}
Trace("learned-rewrite-ll") << "- " << l << std::endl;
}
@@ -93,7 +109,8 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
Node origin = i == 0 ? b.second.lower_origin : b.second.upper_origin;
if (!origin.isNull())
{
- llrw.insert(origin);
+ Assert (originLit.find(origin)!=originLit.end());
+ llrw.insert(originLit[origin]);
}
}
}
@@ -139,7 +156,6 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
// unchanged.
if (!llrw.empty())
{
- NodeManager* nm = NodeManager::currentNM();
std::vector<Node> llrvec(llrw.begin(), llrw.end());
Node llc = nm->mkAnd(llrvec);
Trace("learned-rewrite-assert")
@@ -165,7 +181,13 @@ Node LearnedRewrite::rewriteLearnedRec(Node n,
cur = visit.back();
visit.pop_back();
it = visited.find(cur);
-
+ if (lems.find(cur) != lems.end())
+ {
+ // n is a learned literal: replace by true, not considered a rewrite
+ // for statistics
+ visited[cur] = nm->mkConst(true);
+ continue;
+ }
if (it == visited.end())
{
// mark pre-visited with null; will post-visit to construct final node
@@ -210,12 +232,6 @@ Node LearnedRewrite::rewriteLearned(Node n,
std::unordered_set<Node>& lems)
{
NodeManager* nm = NodeManager::currentNM();
- if (lems.find(n) != lems.end())
- {
- // n is a learned literal: replace by true, not considered a rewrite
- // for statistics
- return nm->mkConst(true);
- }
Trace("learned-rewrite-rr-debug") << "Rewrite " << n << std::endl;
Node nr = Rewriter::rewrite(n);
Kind k = nr.getKind();
@@ -384,63 +400,6 @@ Node LearnedRewrite::rewriteLearned(Node n,
return nr;
}
}
- // inferences based on combining div terms
- Node currDen;
- Node currNum;
- std::vector<Node> sum;
- size_t divCount = 0;
- bool divTotal = true;
- for (const std::pair<const Node, Node>& m : msum)
- {
- if (m.first.isNull())
- {
- sum.push_back(m.second);
- continue;
- }
- Kind mk = m.first.getKind();
- if (mk == INTS_DIVISION || mk == INTS_DIVISION_TOTAL)
- {
- Node factor = ArithMSum::mkCoeffTerm(m.second, m.first[0]);
- divTotal = divTotal && mk == INTS_DIVISION_TOTAL;
- divCount++;
- if (currDen.isNull())
- {
- currNum = factor;
- currDen = m.first[1];
- }
- else
- {
- factor = nm->mkNode(MULT, factor, currDen);
- currNum = nm->mkNode(MULT, currNum, m.first[1]);
- currNum = nm->mkNode(PLUS, currNum, factor);
- currDen = nm->mkNode(MULT, currDen, m.first[1]);
- }
- }
- else
- {
- Node factor = ArithMSum::mkCoeffTerm(m.second, m.first);
- sum.push_back(factor);
- }
- }
- if (divCount >= 2)
- {
- SkolemManager* sm = nm->getSkolemManager();
- Node r = sm->mkDummySkolem("r", nm->integerType());
- Node d = nm->mkNode(
- divTotal ? INTS_DIVISION_TOTAL : INTS_DIVISION, currNum, currDen);
- sum.push_back(d);
- sum.push_back(r);
- Node bound =
- nm->mkNode(AND,
- nm->mkNode(LEQ, nm->mkConst(-Rational(divCount - 1)), r),
- nm->mkNode(LEQ, r, nm->mkConst(Rational(divCount - 1))));
- Node sumn = nm->mkNode(PLUS, sum);
- Node lit = nm->mkNode(k, sumn, nm->mkConst(Rational(0)));
- Node lemma = nm->mkNode(IMPLIES, nr, nm->mkNode(AND, lit, bound));
- Trace("learned-rewrite-div")
- << "Div collect lemma: " << lemma << std::endl;
- lems.insert(lemma);
- }
}
}
return nr;
diff --git a/src/preprocessing/preprocessing_pass_registry.cpp b/src/preprocessing/preprocessing_pass_registry.cpp
index 6f846dc74..f0bd5af86 100644
--- a/src/preprocessing/preprocessing_pass_registry.cpp
+++ b/src/preprocessing/preprocessing_pass_registry.cpp
@@ -41,6 +41,7 @@
#include "preprocessing/passes/int_to_bv.h"
#include "preprocessing/passes/ite_removal.h"
#include "preprocessing/passes/ite_simp.h"
+#include "preprocessing/passes/learned_rewrite.h"
#include "preprocessing/passes/miplib_trick.h"
#include "preprocessing/passes/nl_ext_purify.h"
#include "preprocessing/passes/non_clausal_simp.h"
@@ -126,6 +127,7 @@ PreprocessingPassRegistry::PreprocessingPassRegistry()
registerPassInfo("global-negate", callCtor<GlobalNegate>);
registerPassInfo("int-to-bv", callCtor<IntToBV>);
registerPassInfo("bv-to-int", callCtor<BVToInt>);
+ registerPassInfo("learned-rewrite", callCtor<LearnedRewrite>);
registerPassInfo("foreign-theory-rewrite", callCtor<ForeignTheoryRewrite>);
registerPassInfo("synth-rr", callCtor<SynthRewRulesPass>);
registerPassInfo("real-to-int", callCtor<RealToInt>);
diff --git a/src/smt/process_assertions.cpp b/src/smt/process_assertions.cpp
index cf747c360..a9426d5bd 100644
--- a/src/smt/process_assertions.cpp
+++ b/src/smt/process_assertions.cpp
@@ -288,6 +288,11 @@ bool ProcessAssertions::apply(Assertions& as)
}
Debug("smt") << " assertions : " << assertions.size() << endl;
+ if (options::learnedRewrite())
+ {
+ d_passes["learned-rewrite"]->apply(&assertions);
+ }
+
if (options::earlyIteRemoval())
{
d_smtStats.d_numAssertionsPre += assertions.size();
diff --git a/src/smt/set_defaults.cpp b/src/smt/set_defaults.cpp
index ee3701d51..229fdeec5 100644
--- a/src/smt/set_defaults.cpp
+++ b/src/smt/set_defaults.cpp
@@ -489,15 +489,27 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
opts.smt.simplificationMode = options::SimplificationMode::NONE;
}
+ if (options::learnedRewrite())
+ {
+ if (opts.smt.learnedRewriteWasSetByUser)
+ {
+ throw OptionException(
+ "learned rewrites not supported with unsat cores");
+ }
+ Notice() << "SmtEngine: turning off learned rewrites to support "
+ "unsat cores\n";
+ opts.smt.learnedRewrite = false;
+ }
+
if (options::pbRewrites())
{
if (opts.arith.pbRewritesWasSetByUser)
{
throw OptionException(
- "pseudoboolean rewrites not supported with old unsat cores");
+ "pseudoboolean rewrites not supported with unsat cores");
}
Notice() << "SmtEngine: turning off pseudoboolean rewrites to support "
- "old unsat cores\n";
+ "unsat cores\n";
opts.arith.pbRewrites = false;
}
@@ -505,10 +517,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
{
if (opts.smt.sortInferenceWasSetByUser)
{
- throw OptionException(
- "sort inference not supported with old unsat cores");
+ throw OptionException("sort inference not supported with unsat cores");
}
- Notice() << "SmtEngine: turning off sort inference to support old unsat "
+ Notice() << "SmtEngine: turning off sort inference to support unsat "
"cores\n";
opts.smt.sortInference = false;
}
@@ -518,9 +529,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
if (opts.quantifiers.preSkolemQuantWasSetByUser)
{
throw OptionException(
- "pre-skolemization not supported with old unsat cores");
+ "pre-skolemization not supported with unsat cores");
}
- Notice() << "SmtEngine: turning off pre-skolemization to support old "
+ Notice() << "SmtEngine: turning off pre-skolemization to support "
"unsat cores\n";
opts.quantifiers.preSkolemQuant = false;
}
@@ -529,9 +540,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
{
if (opts.bv.bitvectorToBoolWasSetByUser)
{
- throw OptionException("bv-to-bool not supported with old unsat cores");
+ throw OptionException("bv-to-bool not supported with unsat cores");
}
- Notice() << "SmtEngine: turning off bitvector-to-bool to support old "
+ Notice() << "SmtEngine: turning off bitvector-to-bool to support "
"unsat cores\n";
opts.bv.bitvectorToBool = false;
}
@@ -541,10 +552,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
if (opts.bv.boolToBitvectorWasSetByUser)
{
throw OptionException(
- "bool-to-bv != off not supported with old unsat cores");
+ "bool-to-bv != off not supported with unsat cores");
}
- Notice()
- << "SmtEngine: turning off bool-to-bv to support old unsat cores\n";
+ Notice() << "SmtEngine: turning off bool-to-bv to support unsat cores\n";
opts.bv.boolToBitvector = options::BoolToBVMode::OFF;
}
@@ -552,11 +562,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
{
if (opts.bv.bvIntroducePow2WasSetByUser)
{
- throw OptionException(
- "bv-intro-pow2 not supported with old unsat cores");
+ throw OptionException("bv-intro-pow2 not supported with unsat cores");
}
- Notice()
- << "SmtEngine: turning off bv-intro-pow2 to support old unsat cores";
+ Notice() << "SmtEngine: turning off bv-intro-pow2 to support unsat cores";
opts.bv.bvIntroducePow2 = false;
}
@@ -564,10 +572,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
{
if (opts.smt.repeatSimpWasSetByUser)
{
- throw OptionException("repeat-simp not supported with old unsat cores");
+ throw OptionException("repeat-simp not supported with unsat cores");
}
- Notice()
- << "SmtEngine: turning off repeat-simp to support old unsat cores\n";
+ Notice() << "SmtEngine: turning off repeat-simp to support unsat cores\n";
opts.smt.repeatSimp = false;
}
@@ -575,22 +582,21 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
{
if (opts.quantifiers.globalNegateWasSetByUser)
{
- throw OptionException(
- "global-negate not supported with old unsat cores");
+ throw OptionException("global-negate not supported with unsat cores");
}
- Notice() << "SmtEngine: turning off global-negate to support old unsat "
+ Notice() << "SmtEngine: turning off global-negate to support unsat "
"cores\n";
opts.quantifiers.globalNegate = false;
}
if (options::bitvectorAig())
{
- throw OptionException("bitblast-aig not supported with old unsat cores");
+ throw OptionException("bitblast-aig not supported with unsat cores");
}
if (options::doITESimp())
{
- throw OptionException("ITE simp not supported with old unsat cores");
+ throw OptionException("ITE simp not supported with unsat cores");
}
}
else
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback