diff options
Diffstat (limited to 'src/theory/rewriter.cpp')
-rw-r--r-- | src/theory/rewriter.cpp | 103 |
1 files changed, 47 insertions, 56 deletions
diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index bcd095265..5c4cc5536 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -21,6 +21,8 @@ #include "smt/smt_engine_scope.h" #include "smt/smt_statistics_registry.h" #include "theory/builtin/proof_checker.h" +#include "theory/evaluator.h" +#include "theory/quantifiers/extended_rewrite.h" #include "theory/rewriter_tables.h" #include "theory/theory.h" #include "util/resource_manager.h" @@ -150,34 +152,6 @@ void Rewriter::registerTheoryRewriter(theory::TheoryId tid, d_theoryRewriters[tid] = trew; } -void Rewriter::registerPreRewrite( - Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) -{ - Assert(k != kind::EQUAL) << "Register pre-rewrites for EQUAL with registerPreRewriteEqual."; - d_preRewriters[k] = fn; -} - -void Rewriter::registerPostRewrite( - Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) -{ - Assert(k != kind::EQUAL) << "Register post-rewrites for EQUAL with registerPostRewriteEqual."; - d_postRewriters[k] = fn; -} - -void Rewriter::registerPreRewriteEqual( - theory::TheoryId tid, - std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) -{ - d_preRewritersEqual[tid] = fn; -} - -void Rewriter::registerPostRewriteEqual( - theory::TheoryId tid, - std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) -{ - d_postRewritersEqual[tid] = fn; -} - TheoryRewriter* Rewriter::getTheoryRewriter(theory::TheoryId theoryId) { return d_theoryRewriters[theoryId]; @@ -428,44 +402,30 @@ RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n, TConvProofGenerator* tcpg) { - Kind k = n.getKind(); - std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn = - (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k]; - if (fn == nullptr) + if (tcpg != nullptr) { - if (tcpg != nullptr) - { - // call the trust rewrite response interface - TrustRewriteResponse tresponse = - d_theoryRewriters[theoryId]->preRewriteWithProof(n); - // process the trust rewrite response: store the proof step into - // tcpg if necessary and then convert to rewrite response. - return processTrustRewriteResponse(theoryId, tresponse, true, tcpg); - } - return d_theoryRewriters[theoryId]->preRewrite(n); + // call the trust rewrite response interface + TrustRewriteResponse tresponse = + d_theoryRewriters[theoryId]->preRewriteWithProof(n); + // process the trust rewrite response: store the proof step into + // tcpg if necessary and then convert to rewrite response. + return processTrustRewriteResponse(theoryId, tresponse, true, tcpg); } - return fn(&d_re, n); + return d_theoryRewriters[theoryId]->preRewrite(n); } RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n, TConvProofGenerator* tcpg) { - Kind k = n.getKind(); - std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn = - (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k]; - if (fn == nullptr) + if (tcpg != nullptr) { - if (tcpg != nullptr) - { - // same as above, for post-rewrite - TrustRewriteResponse tresponse = - d_theoryRewriters[theoryId]->postRewriteWithProof(n); - return processTrustRewriteResponse(theoryId, tresponse, false, tcpg); - } - return d_theoryRewriters[theoryId]->postRewrite(n); + // same as above, for post-rewrite + TrustRewriteResponse tresponse = + d_theoryRewriters[theoryId]->postRewriteWithProof(n); + return processTrustRewriteResponse(theoryId, tresponse, false, tcpg); } - return fn(&d_re, n); + return d_theoryRewriters[theoryId]->postRewrite(n); } RewriteResponse Rewriter::processTrustRewriteResponse( @@ -512,5 +472,36 @@ void Rewriter::clearCaches() clearCachesInternal(); } +Node Rewriter::rewriteViaMethod(TNode n, MethodId idr) +{ + if (idr == MethodId::RW_REWRITE) + { + return rewrite(n); + } + if (idr == MethodId::RW_EXT_REWRITE) + { + quantifiers::ExtendedRewriter er; + return er.extendedRewrite(n); + } + if (idr == MethodId::RW_REWRITE_EQ_EXT) + { + return rewriteEqualityExt(n); + } + if (idr == MethodId::RW_EVALUATE) + { + Evaluator eval; + return eval.eval(n, {}, {}, false); + } + if (idr == MethodId::RW_IDENTITY) + { + // does nothing + return n; + } + // unknown rewriter + Unhandled() << "Rewriter::rewriteViaMethod: no rewriter for " << idr + << std::endl; + return n; +} + } // namespace theory } // namespace cvc5 |