diff options
author | Andres Noetzli <andres.noetzli@gmail.com> | 2020-02-10 20:16:12 -0800 |
---|---|---|
committer | Andres Noetzli <andres.noetzli@gmail.com> | 2020-02-10 20:16:12 -0800 |
commit | f6fa60bfb45f8ade2a816f681911a01673968c1a (patch) | |
tree | fd73f4413dd1f08208f64f83aed17915246cd4bd /src | |
parent | aa18f9e6a3ef18071af3636871dc62c8ec0227b2 (diff) |
Introduce tables in the rewriter
This commit adds tables in the rewriter that store which function should
be used to rewrite which kind. We have separate tables for `EQUAL`
because every theory has its own equality rewriter.
Diffstat (limited to 'src')
-rw-r--r-- | src/theory/rewriter.cpp | 64 | ||||
-rw-r--r-- | src/theory/rewriter.h | 80 | ||||
-rw-r--r-- | src/theory/rewriter_tables_template.h | 13 | ||||
-rw-r--r-- | src/theory/theory_rewriter.h | 9 |
4 files changed, 158 insertions, 8 deletions
diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index 7a99ed2d9..da06b053c 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -81,6 +81,11 @@ struct RewriteStackElement { NodeBuilder<> builder; }; +RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n) +{ + return RewriteResponse(REWRITE_DONE, n); +} + Node Rewriter::rewrite(TNode node) { if (node.getNumChildren() == 0) { @@ -88,8 +93,33 @@ Node Rewriter::rewrite(TNode node) { // eagerly for the sake of efficiency here. return node; } - Rewriter& rewriter = getInstance(); - return rewriter.rewriteTo(theoryOf(node), node); + return getInstance().rewriteTo(theoryOf(node), node); +} + +void Rewriter::registerPreRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_preRewriters[k] = fn; +} + +void Rewriter::registerPostRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_postRewriters[k] = fn; +} + +void Rewriter::registerPreRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_preRewritersEqual[tid] = fn; +} + +void Rewriter::registerPostRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_postRewritersEqual[tid] = fn; } Rewriter& Rewriter::getInstance() @@ -153,8 +183,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { for(;;) { // Perform the pre-rewrite RewriteResponse response = - d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite( - rewriteStackTop.node); + preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.node); // Put the rewritten node to the top of the stack rewriteStackTop.node = response.node; TheoryId newTheory = theoryOf(rewriteStackTop.node); @@ -222,8 +251,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { for(;;) { // Do the post-rewrite RewriteResponse response = - d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite( - rewriteStackTop.node); + postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.node); // We continue with the response we got TheoryId newTheoryId = theoryOf(response.node); if (newTheoryId != rewriteStackTop.getTheoryId() @@ -286,6 +314,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { Unreachable(); }/* Rewriter::rewriteTo() */ +RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n) +{ + Kind k = n.getKind(); + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn = + (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k]; + if (fn == nullptr) + { + return d_theoryRewriters[theoryId]->preRewrite(n); + } + return fn(&d_re, n); +} + +RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n) +{ + Kind k = n.getKind(); + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn = + (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k]; + if (fn == nullptr) + { + return d_theoryRewriters[theoryId]->postRewrite(n); + } + return fn(&d_re, n); +} + void Rewriter::clearCaches() { Rewriter& rewriter = getInstance(); diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h index e55ca5d1c..f7298e1fb 100644 --- a/src/theory/rewriter.h +++ b/src/theory/rewriter.h @@ -28,6 +28,23 @@ namespace theory { class RewriterInitializer; /** + * The rewrite environment holds everything that the individual rewrites have + * access to. + */ +class RewriteEnvironment +{ +}; + +/** + * The identity rewrite just returns the original node. + * + * @param re The rewrite environment + * @param n The node to rewrite + * @return The original node + */ +RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n); + +/** * The main rewriter class. */ class Rewriter { @@ -45,6 +62,44 @@ class Rewriter { */ static void clearCaches(); + /** + * Register a prerewrite for a given kind. + * + * @param k The kind to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPreRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn); + + /** + * Register a postrewrite for a given kind. + * + * @param k The kind to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPostRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn); + + /** + * Register a prerewrite for equalities belonging to a given theory. + * + * @param tid The theory to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPreRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn); + + /** + * Register a postrewrite for equalities belonging to a given theory. + * + * @param tid The theory to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPostRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn); + private: /** * Get the (singleton) instance of the rewriter. @@ -71,10 +126,10 @@ class Rewriter { Node rewriteTo(theory::TheoryId theoryId, Node node); /** Calls the pre-rewriter for the given theory */ - RewriteResponse callPreRewrite(theory::TheoryId theoryId, TNode node); + RewriteResponse preRewrite(theory::TheoryId theoryId, TNode n); /** Calls the post-rewriter for the given theory */ - RewriteResponse callPostRewrite(theory::TheoryId theoryId, TNode node); + RewriteResponse postRewrite(theory::TheoryId theoryId, TNode n); /** * Calls the equality-rewriter for the given theory. @@ -88,6 +143,27 @@ class Rewriter { unsigned long d_iterationCount = 0; + /** Rewriter table for prewrites. Maps kinds to rewriter function. */ + std::function<RewriteResponse(RewriteEnvironment*, TNode)> + d_preRewriters[kind::LAST_KIND]; + /** Rewriter table for postrewrites. Maps kinds to rewriter function. */ + std::function<RewriteResponse(RewriteEnvironment*, TNode)> + d_postRewriters[kind::LAST_KIND]; + /** + * Rewriter table for prerewrites of equalities. Maps theory to rewriter + * function. + */ + std::function<RewriteResponse(RewriteEnvironment*, TNode)> + d_preRewritersEqual[theory::THEORY_LAST]; + /** + * Rewriter table for postrewrites of equalities. Maps theory to rewriter + * function. + */ + std::function<RewriteResponse(RewriteEnvironment*, TNode)> + d_postRewritersEqual[theory::THEORY_LAST]; + + RewriteEnvironment d_re; + #ifdef CVC4_ASSERTIONS std::unique_ptr<std::unordered_set<Node, NodeHashFunction>> d_rewriteStack = nullptr; diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h index e1be6355b..1bb03e253 100644 --- a/src/theory/rewriter_tables_template.h +++ b/src/theory/rewriter_tables_template.h @@ -64,6 +64,19 @@ ${post_rewrite_set_cache} Rewriter::Rewriter() { ${rewrite_init} + +for (size_t i = 0; i < kind::LAST_KIND; ++i) +{ + d_preRewriters[i] = nullptr; + d_postRewriters[i] = nullptr; +} + +for (size_t i = 0; i < theory::THEORY_LAST; ++i) +{ + d_preRewritersEqual[i] = nullptr; + d_postRewritersEqual[i] = nullptr; + d_theoryRewriters[i]->registerRewrites(this); +} } void Rewriter::clearCachesInternal() { diff --git a/src/theory/theory_rewriter.h b/src/theory/theory_rewriter.h index 61f0fc27a..93e03123b 100644 --- a/src/theory/theory_rewriter.h +++ b/src/theory/theory_rewriter.h @@ -24,6 +24,8 @@ namespace CVC4 { namespace theory { +class Rewriter; + /** * Theory rewriters signal whether more rewriting is needed (or not) * by using a member of this enumeration. See RewriteResponse, below. @@ -66,6 +68,13 @@ class TheoryRewriter virtual ~TheoryRewriter() = default; /** + * Registers the rewrites of a given theory with the rewriter. + * + * @param rewriter The rewriter to register the rewrites with. + */ + virtual void registerRewrites(Rewriter* rewriter) {} + + /** * Performs a pre-rewrite step. * * @param node The node to rewrite |