summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2021-07-06 12:23:45 -0500
committerGitHub <noreply@github.com>2021-07-06 17:23:45 +0000
commitbb94bfd729be285c80aae5e97f41f813848f40cb (patch)
tree1bb1c291f82efa719d3997dacfbf5e6de5097013 /src
parentc05fe825c6370a3f6bfe8c8264634d11b398567f (diff)
Add learned rewrite preprocessing pass (#6842)
Adds the basic skeleton of the pass.
Diffstat (limited to 'src')
-rw-r--r--src/preprocessing/passes/learned_rewrite.cpp181
-rw-r--r--src/preprocessing/passes/learned_rewrite.h108
2 files changed, 289 insertions, 0 deletions
diff --git a/src/preprocessing/passes/learned_rewrite.cpp b/src/preprocessing/passes/learned_rewrite.cpp
new file mode 100644
index 000000000..785889666
--- /dev/null
+++ b/src/preprocessing/passes/learned_rewrite.cpp
@@ -0,0 +1,181 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ * Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved. See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Rewriting based on learned literals
+ */
+
+#include "preprocessing/passes/learned_rewrite.h"
+
+#include "expr/skolem_manager.h"
+#include "expr/term_context_stack.h"
+#include "preprocessing/assertion_pipeline.h"
+#include "smt/smt_statistics_registry.h"
+#include "theory/arith/arith_msum.h"
+#include "theory/rewriter.h"
+#include "util/rational.h"
+
+using namespace cvc5::theory;
+using namespace cvc5::kind;
+
+namespace cvc5 {
+namespace preprocessing {
+namespace passes {
+
+const char* toString(LearnedRewriteId i)
+{
+ switch (i)
+ {
+ case LearnedRewriteId::NON_ZERO_DEN: return "NON_ZERO_DEN";
+ case LearnedRewriteId::INT_MOD_RANGE: return "INT_MOD_RANGE";
+ case LearnedRewriteId::PRED_POS_LB: return "PRED_POS_LB";
+ case LearnedRewriteId::PRED_ZERO_LB: return "PRED_ZERO_LB";
+ case LearnedRewriteId::PRED_NEG_UB: return "PRED_NEG_UB";
+ case LearnedRewriteId::NONE: return "NONE";
+ default: return "?LearnedRewriteId?";
+ }
+}
+
+std::ostream& operator<<(std::ostream& out, LearnedRewriteId i)
+{
+ out << toString(i);
+ return out;
+}
+
+LearnedRewrite::LearnedRewrite(PreprocessingPassContext* preprocContext)
+ : PreprocessingPass(preprocContext, "learned-rewrite"),
+ d_lrewCount(smtStatisticsRegistry().registerHistogram<LearnedRewriteId>(
+ "LearnedRewrite::lrewCount"))
+{
+}
+
+PreprocessingPassResult LearnedRewrite::applyInternal(
+ AssertionPipeline* assertionsToPreprocess)
+{
+ arith::BoundInference binfer;
+ std::vector<Node> learnedLits = d_preprocContext->getLearnedLiterals();
+ std::unordered_set<Node> llrw;
+ std::unordered_map<TNode, Node> visited;
+ if (learnedLits.empty())
+ {
+ Trace("learned-rewrite-ll") << "No learned literals" << std::endl;
+ return PreprocessingPassResult::NO_CONFLICT;
+ }
+ else
+ {
+ Trace("learned-rewrite-ll") << "Learned literals:" << std::endl;
+ for (const Node& l : learnedLits)
+ {
+ // maybe use the literal for bound inference?
+ Kind k = l.getKind();
+ Assert (k != LT && k != GT && k != LEQ);
+ if (k == EQUAL || k == GEQ)
+ {
+ binfer.add(l);
+ }
+ Trace("learned-rewrite-ll") << "- " << l << std::endl;
+ }
+ const std::map<Node, arith::Bounds>& bs = binfer.get();
+ // get the literals that were critical, i.e. used in the derivation of a
+ // bound
+ for (const std::pair<const Node, arith::Bounds>& b : bs)
+ {
+ for (size_t i = 0; i < 2; i++)
+ {
+ Node origin = i == 0 ? b.second.lower_origin : b.second.upper_origin;
+ if (!origin.isNull())
+ {
+ llrw.insert(origin);
+ }
+ }
+ }
+ // rewrite the non-critical learned literals, some may be redundant
+ for (const Node& l : learnedLits)
+ {
+ if (llrw.find(l) != llrw.end())
+ {
+ continue;
+ }
+ Node e = rewriteLearnedRec(l, binfer, llrw, visited);
+ if (e.isConst())
+ {
+ // ignore true
+ if (e.getConst<bool>())
+ {
+ continue;
+ }
+ // conflict, we are done
+ assertionsToPreprocess->push_back(e);
+ return PreprocessingPassResult::CONFLICT;
+ }
+ llrw.insert(e);
+ }
+ Trace("learned-rewrite-ll") << "end" << std::endl;
+ }
+ size_t size = assertionsToPreprocess->size();
+ for (size_t i = 0; i < size; ++i)
+ {
+ Node prev = (*assertionsToPreprocess)[i];
+ Trace("learned-rewrite-assert")
+ << "LearnedRewrite: assert: " << prev << std::endl;
+ Node e = rewriteLearnedRec(prev, binfer, llrw, visited);
+ if (e != prev)
+ {
+ Trace("learned-rewrite-assert")
+ << ".......................: " << e << std::endl;
+ assertionsToPreprocess->replace(i, e);
+ }
+ }
+ // Add the conjunction of learned literals back to assertions. Notice that
+ // in some cases we may add top-level assertions back to the assertion list
+ // unchanged.
+ if (!llrw.empty())
+ {
+ NodeManager* nm = NodeManager::currentNM();
+ std::vector<Node> llrvec(llrw.begin(), llrw.end());
+ Node llc = nm->mkAnd(llrvec);
+ Trace("learned-rewrite-assert")
+ << "Re-add rewritten learned conjunction: " << llc << std::endl;
+ assertionsToPreprocess->push_back(llc);
+ }
+
+ return PreprocessingPassResult::NO_CONFLICT;
+}
+
+Node LearnedRewrite::rewriteLearnedRec(Node n,
+ arith::BoundInference& binfer,
+ std::unordered_set<Node>& lems,
+ std::unordered_map<TNode, Node>& visited)
+{
+ return n;
+}
+
+Node LearnedRewrite::rewriteLearned(Node n,
+ arith::BoundInference& binfer,
+ std::unordered_set<Node>& lems)
+{
+ return n;
+}
+
+Node LearnedRewrite::returnRewriteLearned(Node n, Node nr, LearnedRewriteId id)
+{
+ if (Trace.isOn("learned-rewrite"))
+ {
+ Trace("learned-rewrite") << "LearnedRewrite::Rewrite: (" << id << ") " << n
+ << " == " << nr << std::endl;
+ }
+ d_lrewCount << id;
+ return nr;
+}
+
+} // namespace passes
+} // namespace preprocessing
+} // namespace cvc5
diff --git a/src/preprocessing/passes/learned_rewrite.h b/src/preprocessing/passes/learned_rewrite.h
new file mode 100644
index 000000000..4f3a51d58
--- /dev/null
+++ b/src/preprocessing/passes/learned_rewrite.h
@@ -0,0 +1,108 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ * Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved. See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Rewriting based on learned literals
+ */
+#include "cvc5_private.h"
+
+#ifndef CVC5__PREPROCESSING__PASSES__LEARNED_REWRITE_H
+#define CVC5__PREPROCESSING__PASSES__LEARNED_REWRITE_H
+
+#include "preprocessing/preprocessing_pass.h"
+#include "preprocessing/preprocessing_pass_context.h"
+#include "theory/arith/bound_inference.h"
+#include "util/statistics_stats.h"
+
+#include <iosfwd>
+
+namespace cvc5 {
+namespace preprocessing {
+namespace passes {
+
+/**
+ * Learned rewrites in the pass below.
+ */
+enum class LearnedRewriteId
+{
+ // Elimination of division, int division, int modulus due to non-zero
+ // denominator. e.g. (not (= y 0)) => (div x y) ---> (div_total x y)
+ NON_ZERO_DEN,
+ // Elimination of int modulus due to range.
+ // e.g. (and (<= 0 x) (< x n)) => (mod x n) ---> x
+ INT_MOD_RANGE,
+ // e.g. (>= c 0) => (>= p 0) ---> true where c is inferred const lower bound
+ PRED_POS_LB,
+ // e.g. (= c 0) => (>= p 0) ---> true where c is inferred const lower bound
+ PRED_ZERO_LB,
+ // e.g. (> c 0) => (>= p 0) ---> false where c is inferred const upper bound
+ PRED_NEG_UB,
+
+ //-------------------------------------- NONE
+ NONE
+};
+
+/**
+ * Converts an learned rewrite id to a string.
+ *
+ * @param i The learned rewrite identifier
+ * @return The name of the learned rewrite identifier
+ */
+const char* toString(LearnedRewriteId i);
+
+/**
+ * Writes an learned rewrite identifier to a stream.
+ *
+ * @param out The stream to write to
+ * @param i The learned rewrite identifier to write to the stream
+ * @return The stream
+ */
+std::ostream& operator<<(std::ostream& out, LearnedRewriteId i);
+
+/**
+ * Applies learned rewriting. This rewrites the input based on learned literals.
+ * This in particular does rewriting that goes beyond what is done in
+ * non-clausal simplification, where equality substitutions + constant
+ * propagations are performed. In particular, this pass applies rewriting
+ * based on e.g. bound inference for arithmetic.
+ */
+class LearnedRewrite : public PreprocessingPass
+{
+ public:
+ LearnedRewrite(PreprocessingPassContext* preprocContext);
+
+ protected:
+ PreprocessingPassResult applyInternal(
+ AssertionPipeline* assertionsToPreprocess) override;
+ /**
+ * Apply rewrite with learned literals, traverses n.
+ */
+ Node rewriteLearnedRec(Node n,
+ theory::arith::BoundInference& binfer,
+ std::unordered_set<Node>& lems,
+ std::unordered_map<TNode, Node>& visited);
+ /**
+ * Learned rewrite to n, single step.
+ */
+ Node rewriteLearned(Node n,
+ theory::arith::BoundInference& binfer,
+ std::unordered_set<Node>& lems);
+ /** Return learned rewrite */
+ Node returnRewriteLearned(Node n, Node nr, LearnedRewriteId id);
+ /** Counts number of applications of learned rewrites */
+ HistogramStat<LearnedRewriteId> d_lrewCount;
+};
+
+} // namespace passes
+} // namespace preprocessing
+} // namespace cvc5
+
+#endif /* CVC5__PREPROCESSING__PASSES__LEARNED_REWRITE_H */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback