From bb94bfd729be285c80aae5e97f41f813848f40cb Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Tue, 6 Jul 2021 12:23:45 -0500 Subject: Add learned rewrite preprocessing pass (#6842) Adds the basic skeleton of the pass. --- src/preprocessing/passes/learned_rewrite.cpp | 181 +++++++++++++++++++++++++++ src/preprocessing/passes/learned_rewrite.h | 108 ++++++++++++++++ 2 files changed, 289 insertions(+) create mode 100644 src/preprocessing/passes/learned_rewrite.cpp create mode 100644 src/preprocessing/passes/learned_rewrite.h diff --git a/src/preprocessing/passes/learned_rewrite.cpp b/src/preprocessing/passes/learned_rewrite.cpp new file mode 100644 index 000000000..785889666 --- /dev/null +++ b/src/preprocessing/passes/learned_rewrite.cpp @@ -0,0 +1,181 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Rewriting based on learned literals + */ + +#include "preprocessing/passes/learned_rewrite.h" + +#include "expr/skolem_manager.h" +#include "expr/term_context_stack.h" +#include "preprocessing/assertion_pipeline.h" +#include "smt/smt_statistics_registry.h" +#include "theory/arith/arith_msum.h" +#include "theory/rewriter.h" +#include "util/rational.h" + +using namespace cvc5::theory; +using namespace cvc5::kind; + +namespace cvc5 { +namespace preprocessing { +namespace passes { + +const char* toString(LearnedRewriteId i) +{ + switch (i) + { + case LearnedRewriteId::NON_ZERO_DEN: return "NON_ZERO_DEN"; + case LearnedRewriteId::INT_MOD_RANGE: return "INT_MOD_RANGE"; + case LearnedRewriteId::PRED_POS_LB: return "PRED_POS_LB"; + case LearnedRewriteId::PRED_ZERO_LB: return "PRED_ZERO_LB"; + case LearnedRewriteId::PRED_NEG_UB: return "PRED_NEG_UB"; + case LearnedRewriteId::NONE: return "NONE"; + default: return "?LearnedRewriteId?"; + } +} + +std::ostream& operator<<(std::ostream& out, LearnedRewriteId i) +{ + out << toString(i); + return out; +} + +LearnedRewrite::LearnedRewrite(PreprocessingPassContext* preprocContext) + : PreprocessingPass(preprocContext, "learned-rewrite"), + d_lrewCount(smtStatisticsRegistry().registerHistogram( + "LearnedRewrite::lrewCount")) +{ +} + +PreprocessingPassResult LearnedRewrite::applyInternal( + AssertionPipeline* assertionsToPreprocess) +{ + arith::BoundInference binfer; + std::vector learnedLits = d_preprocContext->getLearnedLiterals(); + std::unordered_set llrw; + std::unordered_map visited; + if (learnedLits.empty()) + { + Trace("learned-rewrite-ll") << "No learned literals" << std::endl; + return PreprocessingPassResult::NO_CONFLICT; + } + else + { + Trace("learned-rewrite-ll") << "Learned literals:" << std::endl; + for (const Node& l : learnedLits) + { + // maybe use the literal for bound inference? + Kind k = l.getKind(); + Assert (k != LT && k != GT && k != LEQ); + if (k == EQUAL || k == GEQ) + { + binfer.add(l); + } + Trace("learned-rewrite-ll") << "- " << l << std::endl; + } + const std::map& bs = binfer.get(); + // get the literals that were critical, i.e. used in the derivation of a + // bound + for (const std::pair& b : bs) + { + for (size_t i = 0; i < 2; i++) + { + Node origin = i == 0 ? b.second.lower_origin : b.second.upper_origin; + if (!origin.isNull()) + { + llrw.insert(origin); + } + } + } + // rewrite the non-critical learned literals, some may be redundant + for (const Node& l : learnedLits) + { + if (llrw.find(l) != llrw.end()) + { + continue; + } + Node e = rewriteLearnedRec(l, binfer, llrw, visited); + if (e.isConst()) + { + // ignore true + if (e.getConst()) + { + continue; + } + // conflict, we are done + assertionsToPreprocess->push_back(e); + return PreprocessingPassResult::CONFLICT; + } + llrw.insert(e); + } + Trace("learned-rewrite-ll") << "end" << std::endl; + } + size_t size = assertionsToPreprocess->size(); + for (size_t i = 0; i < size; ++i) + { + Node prev = (*assertionsToPreprocess)[i]; + Trace("learned-rewrite-assert") + << "LearnedRewrite: assert: " << prev << std::endl; + Node e = rewriteLearnedRec(prev, binfer, llrw, visited); + if (e != prev) + { + Trace("learned-rewrite-assert") + << ".......................: " << e << std::endl; + assertionsToPreprocess->replace(i, e); + } + } + // Add the conjunction of learned literals back to assertions. Notice that + // in some cases we may add top-level assertions back to the assertion list + // unchanged. + if (!llrw.empty()) + { + NodeManager* nm = NodeManager::currentNM(); + std::vector llrvec(llrw.begin(), llrw.end()); + Node llc = nm->mkAnd(llrvec); + Trace("learned-rewrite-assert") + << "Re-add rewritten learned conjunction: " << llc << std::endl; + assertionsToPreprocess->push_back(llc); + } + + return PreprocessingPassResult::NO_CONFLICT; +} + +Node LearnedRewrite::rewriteLearnedRec(Node n, + arith::BoundInference& binfer, + std::unordered_set& lems, + std::unordered_map& visited) +{ + return n; +} + +Node LearnedRewrite::rewriteLearned(Node n, + arith::BoundInference& binfer, + std::unordered_set& lems) +{ + return n; +} + +Node LearnedRewrite::returnRewriteLearned(Node n, Node nr, LearnedRewriteId id) +{ + if (Trace.isOn("learned-rewrite")) + { + Trace("learned-rewrite") << "LearnedRewrite::Rewrite: (" << id << ") " << n + << " == " << nr << std::endl; + } + d_lrewCount << id; + return nr; +} + +} // namespace passes +} // namespace preprocessing +} // namespace cvc5 diff --git a/src/preprocessing/passes/learned_rewrite.h b/src/preprocessing/passes/learned_rewrite.h new file mode 100644 index 000000000..4f3a51d58 --- /dev/null +++ b/src/preprocessing/passes/learned_rewrite.h @@ -0,0 +1,108 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Rewriting based on learned literals + */ +#include "cvc5_private.h" + +#ifndef CVC5__PREPROCESSING__PASSES__LEARNED_REWRITE_H +#define CVC5__PREPROCESSING__PASSES__LEARNED_REWRITE_H + +#include "preprocessing/preprocessing_pass.h" +#include "preprocessing/preprocessing_pass_context.h" +#include "theory/arith/bound_inference.h" +#include "util/statistics_stats.h" + +#include + +namespace cvc5 { +namespace preprocessing { +namespace passes { + +/** + * Learned rewrites in the pass below. + */ +enum class LearnedRewriteId +{ + // Elimination of division, int division, int modulus due to non-zero + // denominator. e.g. (not (= y 0)) => (div x y) ---> (div_total x y) + NON_ZERO_DEN, + // Elimination of int modulus due to range. + // e.g. (and (<= 0 x) (< x n)) => (mod x n) ---> x + INT_MOD_RANGE, + // e.g. (>= c 0) => (>= p 0) ---> true where c is inferred const lower bound + PRED_POS_LB, + // e.g. (= c 0) => (>= p 0) ---> true where c is inferred const lower bound + PRED_ZERO_LB, + // e.g. (> c 0) => (>= p 0) ---> false where c is inferred const upper bound + PRED_NEG_UB, + + //-------------------------------------- NONE + NONE +}; + +/** + * Converts an learned rewrite id to a string. + * + * @param i The learned rewrite identifier + * @return The name of the learned rewrite identifier + */ +const char* toString(LearnedRewriteId i); + +/** + * Writes an learned rewrite identifier to a stream. + * + * @param out The stream to write to + * @param i The learned rewrite identifier to write to the stream + * @return The stream + */ +std::ostream& operator<<(std::ostream& out, LearnedRewriteId i); + +/** + * Applies learned rewriting. This rewrites the input based on learned literals. + * This in particular does rewriting that goes beyond what is done in + * non-clausal simplification, where equality substitutions + constant + * propagations are performed. In particular, this pass applies rewriting + * based on e.g. bound inference for arithmetic. + */ +class LearnedRewrite : public PreprocessingPass +{ + public: + LearnedRewrite(PreprocessingPassContext* preprocContext); + + protected: + PreprocessingPassResult applyInternal( + AssertionPipeline* assertionsToPreprocess) override; + /** + * Apply rewrite with learned literals, traverses n. + */ + Node rewriteLearnedRec(Node n, + theory::arith::BoundInference& binfer, + std::unordered_set& lems, + std::unordered_map& visited); + /** + * Learned rewrite to n, single step. + */ + Node rewriteLearned(Node n, + theory::arith::BoundInference& binfer, + std::unordered_set& lems); + /** Return learned rewrite */ + Node returnRewriteLearned(Node n, Node nr, LearnedRewriteId id); + /** Counts number of applications of learned rewrites */ + HistogramStat d_lrewCount; +}; + +} // namespace passes +} // namespace preprocessing +} // namespace cvc5 + +#endif /* CVC5__PREPROCESSING__PASSES__LEARNED_REWRITE_H */ -- cgit v1.2.3