diff options
author | Andres Noetzli <andres.noetzli@gmail.com> | 2020-03-11 06:54:50 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-11 06:54:50 -0700 |
commit | 2b355305ef635ddfaad7fe75c29221cb2f744a62 (patch) | |
tree | 1667cc362fcf2f770bf7a47b81d887f648c8a857 /src/theory/rewriter.cpp | |
parent | edcc81b08b4d6c67da81b7ba2fcefbab3286f02c (diff) |
Introduce tables in the rewriter (#3742)
This commit adds tables in the rewriter that store which function should
be used to rewrite which kind. We have separate tables for `EQUAL`
because every theory has its own equality rewriter.
Diffstat (limited to 'src/theory/rewriter.cpp')
-rw-r--r-- | src/theory/rewriter.cpp | 66 |
1 files changed, 60 insertions, 6 deletions
diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index 765c2b4c8..b3f1e23d7 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -81,6 +81,11 @@ struct RewriteStackElement { NodeBuilder<> d_builder; }; +RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n) +{ + return RewriteResponse(REWRITE_DONE, n); +} + Node Rewriter::rewrite(TNode node) { if (node.getNumChildren() == 0) { @@ -88,8 +93,35 @@ Node Rewriter::rewrite(TNode node) { // eagerly for the sake of efficiency here. return node; } - Rewriter& rewriter = getInstance(); - return rewriter.rewriteTo(theoryOf(node), node); + return getInstance().rewriteTo(theoryOf(node), node); +} + +void Rewriter::registerPreRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + Assert(k != kind::EQUAL) << "Register pre-rewrites for EQUAL with registerPreRewriteEqual."; + d_preRewriters[k] = fn; +} + +void Rewriter::registerPostRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + Assert(k != kind::EQUAL) << "Register post-rewrites for EQUAL with registerPostRewriteEqual."; + d_postRewriters[k] = fn; +} + +void Rewriter::registerPreRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_preRewritersEqual[tid] = fn; +} + +void Rewriter::registerPostRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_postRewritersEqual[tid] = fn; } Rewriter& Rewriter::getInstance() @@ -153,8 +185,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { for(;;) { // Perform the pre-rewrite RewriteResponse response = - d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite( - rewriteStackTop.d_node); + preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node); // Put the rewritten node to the top of the stack rewriteStackTop.d_node = response.d_node; TheoryId newTheory = theoryOf(rewriteStackTop.d_node); @@ -225,8 +256,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { for(;;) { // Do the post-rewrite RewriteResponse response = - d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite( - rewriteStackTop.d_node); + postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node); // We continue with the response we got TheoryId newTheoryId = theoryOf(response.d_node); if (newTheoryId != rewriteStackTop.getTheoryId() @@ -290,6 +320,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { Unreachable(); }/* Rewriter::rewriteTo() */ +RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n) +{ + Kind k = n.getKind(); + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn = + (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k]; + if (fn == nullptr) + { + return d_theoryRewriters[theoryId]->preRewrite(n); + } + return fn(&d_re, n); +} + +RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n) +{ + Kind k = n.getKind(); + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn = + (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k]; + if (fn == nullptr) + { + return d_theoryRewriters[theoryId]->postRewrite(n); + } + return fn(&d_re, n); +} + void Rewriter::clearCaches() { Rewriter& rewriter = getInstance(); |