From 8b53a48ce6041b98e3761c2a341f727bcaaf2686 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 3 Sep 2021 13:42:18 -0500 Subject: Standardize Rewriter::rewriteViaMethod call (#7119) This moves the standard method for rewrites in proofs from TheoryBuiltinProofRuleChecker to Rewriter. The motivation for this change is to make various kinds of rewrite methods (standard rewrite, extended rewrite, extended equality rewrite, evaluate) accessible throughout the code in a standard way. After this PR, it is possible to know variants of the REWRITE proof rule application by having access to the rewriter, instead of having to get the builtin proof rule checker. Note that TheoryBuiltinProofRuleChecker::applyRewrite *cannot* be static since access to the rewriter is not longer permitted to be static. It also removes some unused infrastructure from Rewriter. Followup PRs will remove applyRewrite for TheoryBuiltinProofRuleChecker in favor of calling the rewriter directly. --- src/theory/builtin/proof_checker.cpp | 31 +--------- src/theory/builtin/proof_checker.h | 2 +- src/theory/rewriter.cpp | 103 ++++++++++++++++------------------ src/theory/rewriter.h | 65 ++++----------------- src/theory/rewriter_tables_template.h | 10 ---- 5 files changed, 60 insertions(+), 151 deletions(-) diff --git a/src/theory/builtin/proof_checker.cpp b/src/theory/builtin/proof_checker.cpp index bb0f9a413..e51db4ce3 100644 --- a/src/theory/builtin/proof_checker.cpp +++ b/src/theory/builtin/proof_checker.cpp @@ -70,36 +70,9 @@ Node BuiltinProofRuleChecker::applySubstitutionRewrite( return applyRewrite(nks, idr); } -Node BuiltinProofRuleChecker::applyRewrite(Node n, MethodId idr) +Node BuiltinProofRuleChecker::applyRewrite(TNode n, MethodId idr) { - Trace("builtin-pfcheck-debug") - << "applyRewrite (" << idr << "): " << n << std::endl; - if (idr == MethodId::RW_REWRITE) - { - return Rewriter::rewrite(n); - } - if (idr == MethodId::RW_EXT_REWRITE) - { - return d_ext_rewriter.extendedRewrite(n); - } - if (idr == MethodId::RW_REWRITE_EQ_EXT) - { - return d_env.getRewriter()->rewriteEqualityExt(n); - } - if (idr == MethodId::RW_EVALUATE) - { - Evaluator eval; - return eval.eval(n, {}, {}, false); - } - if (idr == MethodId::RW_IDENTITY) - { - // does nothing - return n; - } - // unknown rewriter - Assert(false) << "BuiltinProofRuleChecker::applyRewrite: no rewriter for " - << idr << std::endl; - return n; + return d_env.getRewriter()->rewriteViaMethod(n, idr); } bool BuiltinProofRuleChecker::getSubstitutionForLit(Node exp, diff --git a/src/theory/builtin/proof_checker.h b/src/theory/builtin/proof_checker.h index d7edd2c53..bb746e467 100644 --- a/src/theory/builtin/proof_checker.h +++ b/src/theory/builtin/proof_checker.h @@ -48,7 +48,7 @@ class BuiltinProofRuleChecker : public ProofRuleChecker * specifying a call to Rewriter::rewrite. * @return The rewritten form of n. */ - Node applyRewrite(Node n, MethodId idr = MethodId::RW_REWRITE); + Node applyRewrite(TNode n, MethodId idr = MethodId::RW_REWRITE); /** * Get substitution for literal exp. Updates vars/subs to the substitution * specified by exp for the substitution method ids. diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index bcd095265..5c4cc5536 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -21,6 +21,8 @@ #include "smt/smt_engine_scope.h" #include "smt/smt_statistics_registry.h" #include "theory/builtin/proof_checker.h" +#include "theory/evaluator.h" +#include "theory/quantifiers/extended_rewrite.h" #include "theory/rewriter_tables.h" #include "theory/theory.h" #include "util/resource_manager.h" @@ -150,34 +152,6 @@ void Rewriter::registerTheoryRewriter(theory::TheoryId tid, d_theoryRewriters[tid] = trew; } -void Rewriter::registerPreRewrite( - Kind k, std::function fn) -{ - Assert(k != kind::EQUAL) << "Register pre-rewrites for EQUAL with registerPreRewriteEqual."; - d_preRewriters[k] = fn; -} - -void Rewriter::registerPostRewrite( - Kind k, std::function fn) -{ - Assert(k != kind::EQUAL) << "Register post-rewrites for EQUAL with registerPostRewriteEqual."; - d_postRewriters[k] = fn; -} - -void Rewriter::registerPreRewriteEqual( - theory::TheoryId tid, - std::function fn) -{ - d_preRewritersEqual[tid] = fn; -} - -void Rewriter::registerPostRewriteEqual( - theory::TheoryId tid, - std::function fn) -{ - d_postRewritersEqual[tid] = fn; -} - TheoryRewriter* Rewriter::getTheoryRewriter(theory::TheoryId theoryId) { return d_theoryRewriters[theoryId]; @@ -428,44 +402,30 @@ RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n, TConvProofGenerator* tcpg) { - Kind k = n.getKind(); - std::function fn = - (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k]; - if (fn == nullptr) + if (tcpg != nullptr) { - if (tcpg != nullptr) - { - // call the trust rewrite response interface - TrustRewriteResponse tresponse = - d_theoryRewriters[theoryId]->preRewriteWithProof(n); - // process the trust rewrite response: store the proof step into - // tcpg if necessary and then convert to rewrite response. - return processTrustRewriteResponse(theoryId, tresponse, true, tcpg); - } - return d_theoryRewriters[theoryId]->preRewrite(n); + // call the trust rewrite response interface + TrustRewriteResponse tresponse = + d_theoryRewriters[theoryId]->preRewriteWithProof(n); + // process the trust rewrite response: store the proof step into + // tcpg if necessary and then convert to rewrite response. + return processTrustRewriteResponse(theoryId, tresponse, true, tcpg); } - return fn(&d_re, n); + return d_theoryRewriters[theoryId]->preRewrite(n); } RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n, TConvProofGenerator* tcpg) { - Kind k = n.getKind(); - std::function fn = - (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k]; - if (fn == nullptr) + if (tcpg != nullptr) { - if (tcpg != nullptr) - { - // same as above, for post-rewrite - TrustRewriteResponse tresponse = - d_theoryRewriters[theoryId]->postRewriteWithProof(n); - return processTrustRewriteResponse(theoryId, tresponse, false, tcpg); - } - return d_theoryRewriters[theoryId]->postRewrite(n); + // same as above, for post-rewrite + TrustRewriteResponse tresponse = + d_theoryRewriters[theoryId]->postRewriteWithProof(n); + return processTrustRewriteResponse(theoryId, tresponse, false, tcpg); } - return fn(&d_re, n); + return d_theoryRewriters[theoryId]->postRewrite(n); } RewriteResponse Rewriter::processTrustRewriteResponse( @@ -512,5 +472,36 @@ void Rewriter::clearCaches() clearCachesInternal(); } +Node Rewriter::rewriteViaMethod(TNode n, MethodId idr) +{ + if (idr == MethodId::RW_REWRITE) + { + return rewrite(n); + } + if (idr == MethodId::RW_EXT_REWRITE) + { + quantifiers::ExtendedRewriter er; + return er.extendedRewrite(n); + } + if (idr == MethodId::RW_REWRITE_EQ_EXT) + { + return rewriteEqualityExt(n); + } + if (idr == MethodId::RW_EVALUATE) + { + Evaluator eval; + return eval.eval(n, {}, {}, false); + } + if (idr == MethodId::RW_IDENTITY) + { + // does nothing + return n; + } + // unknown rewriter + Unhandled() << "Rewriter::rewriteViaMethod: no rewriter for " << idr + << std::endl; + return n; +} + } // namespace theory } // namespace cvc5 diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h index 23a9914bd..63628b0af 100644 --- a/src/theory/rewriter.h +++ b/src/theory/rewriter.h @@ -18,6 +18,7 @@ #pragma once #include "expr/node.h" +#include "proof/method_id.h" #include "theory/theory_rewriter.h" namespace cvc5 { @@ -105,46 +106,19 @@ class Rewriter { */ void registerTheoryRewriter(theory::TheoryId tid, TheoryRewriter* trew); - /** - * Register a prerewrite for a given kind. - * - * @param k The kind to register a rewrite for. - * @param fn The function that performs the rewrite. - */ - void registerPreRewrite( - Kind k, std::function fn); - - /** - * Register a postrewrite for a given kind. - * - * @param k The kind to register a rewrite for. - * @param fn The function that performs the rewrite. - */ - void registerPostRewrite( - Kind k, std::function fn); - - /** - * Register a prerewrite for equalities belonging to a given theory. - * - * @param tid The theory to register a rewrite for. - * @param fn The function that performs the rewrite. - */ - void registerPreRewriteEqual( - theory::TheoryId tid, - std::function fn); + /** Get the theory rewriter for the given id */ + TheoryRewriter* getTheoryRewriter(theory::TheoryId theoryId); /** - * Register a postrewrite for equalities belonging to a given theory. + * Apply rewrite on n via the rewrite method identifier idr (see method_id.h). + * This encapsulates the exact behavior of a REWRITE step in a proof. * - * @param tid The theory to register a rewrite for. - * @param fn The function that performs the rewrite. + * @param n The node to rewrite, + * @param idr The method identifier of the rewriter, by default RW_REWRITE + * specifying a call to rewrite. + * @return The rewritten form of n. */ - void registerPostRewriteEqual( - theory::TheoryId tid, - std::function fn); - - /** Get the theory rewriter for the given id */ - TheoryRewriter* getTheoryRewriter(theory::TheoryId theoryId); + Node rewriteViaMethod(TNode n, MethodId idr = MethodId::RW_REWRITE); private: /** @@ -200,25 +174,6 @@ class Rewriter { /** Theory rewriters used by this rewriter instance */ TheoryRewriter* d_theoryRewriters[theory::THEORY_LAST]; - /** Rewriter table for prewrites. Maps kinds to rewriter function. */ - std::function - d_preRewriters[kind::LAST_KIND]; - /** Rewriter table for postrewrites. Maps kinds to rewriter function. */ - std::function - d_postRewriters[kind::LAST_KIND]; - /** - * Rewriter table for prerewrites of equalities. Maps theory to rewriter - * function. - */ - std::function - d_preRewritersEqual[theory::THEORY_LAST]; - /** - * Rewriter table for postrewrites of equalities. Maps theory to rewriter - * function. - */ - std::function - d_postRewritersEqual[theory::THEORY_LAST]; - RewriteEnvironment d_re; /** The proof generator */ diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h index c549f8cfb..36d320fb7 100644 --- a/src/theory/rewriter_tables_template.h +++ b/src/theory/rewriter_tables_template.h @@ -82,17 +82,7 @@ ${post_rewrite_set_cache} Rewriter::Rewriter() : d_tpg(nullptr) { - for (size_t i = 0; i < kind::LAST_KIND; ++i) - { - d_preRewriters[i] = nullptr; - d_postRewriters[i] = nullptr; - } - for (size_t i = 0; i < theory::THEORY_LAST; ++i) - { - d_preRewritersEqual[i] = nullptr; - d_postRewritersEqual[i] = nullptr; - } } void Rewriter::clearCachesInternal() -- cgit v1.2.3