diff options
author | Andres Noetzli <andres.noetzli@gmail.com> | 2020-02-10 20:16:12 -0800 |
---|---|---|
committer | Andres Noetzli <andres.noetzli@gmail.com> | 2020-02-10 20:16:12 -0800 |
commit | f6fa60bfb45f8ade2a816f681911a01673968c1a (patch) | |
tree | fd73f4413dd1f08208f64f83aed17915246cd4bd /src/theory/rewriter.cpp | |
parent | aa18f9e6a3ef18071af3636871dc62c8ec0227b2 (diff) |
Introduce tables in the rewriter
This commit adds tables in the rewriter that store which function should
be used to rewrite which kind. We have separate tables for `EQUAL`
because every theory has its own equality rewriter.
Diffstat (limited to 'src/theory/rewriter.cpp')
-rw-r--r-- | src/theory/rewriter.cpp | 64 |
1 files changed, 58 insertions, 6 deletions
diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index 7a99ed2d9..da06b053c 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -81,6 +81,11 @@ struct RewriteStackElement { NodeBuilder<> builder; }; +RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n) +{ + return RewriteResponse(REWRITE_DONE, n); +} + Node Rewriter::rewrite(TNode node) { if (node.getNumChildren() == 0) { @@ -88,8 +93,33 @@ Node Rewriter::rewrite(TNode node) { // eagerly for the sake of efficiency here. return node; } - Rewriter& rewriter = getInstance(); - return rewriter.rewriteTo(theoryOf(node), node); + return getInstance().rewriteTo(theoryOf(node), node); +} + +void Rewriter::registerPreRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_preRewriters[k] = fn; +} + +void Rewriter::registerPostRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_postRewriters[k] = fn; +} + +void Rewriter::registerPreRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_preRewritersEqual[tid] = fn; +} + +void Rewriter::registerPostRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_postRewritersEqual[tid] = fn; } Rewriter& Rewriter::getInstance() @@ -153,8 +183,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { for(;;) { // Perform the pre-rewrite RewriteResponse response = - d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite( - rewriteStackTop.node); + preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.node); // Put the rewritten node to the top of the stack rewriteStackTop.node = response.node; TheoryId newTheory = theoryOf(rewriteStackTop.node); @@ -222,8 +251,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { for(;;) { // Do the post-rewrite RewriteResponse response = - d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite( - rewriteStackTop.node); + postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.node); // We continue with the response we got TheoryId newTheoryId = theoryOf(response.node); if (newTheoryId != rewriteStackTop.getTheoryId() @@ -286,6 +314,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { Unreachable(); }/* Rewriter::rewriteTo() */ +RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n) +{ + Kind k = n.getKind(); + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn = + (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k]; + if (fn == nullptr) + { + return d_theoryRewriters[theoryId]->preRewrite(n); + } + return fn(&d_re, n); +} + +RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n) +{ + Kind k = n.getKind(); + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn = + (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k]; + if (fn == nullptr) + { + return d_theoryRewriters[theoryId]->postRewrite(n); + } + return fn(&d_re, n); +} + void Rewriter::clearCaches() { Rewriter& rewriter = getInstance(); |