summaryrefslogtreecommitdiff
path: root/src/theory/rewriter.cpp
diff options
context:
space:
mode:
authorAndres Noetzli <andres.noetzli@gmail.com>2020-02-10 20:16:12 -0800
committerAndres Noetzli <andres.noetzli@gmail.com>2020-02-10 20:16:12 -0800
commitf6fa60bfb45f8ade2a816f681911a01673968c1a (patch)
treefd73f4413dd1f08208f64f83aed17915246cd4bd /src/theory/rewriter.cpp
parentaa18f9e6a3ef18071af3636871dc62c8ec0227b2 (diff)
Introduce tables in the rewriter
This commit adds tables in the rewriter that store which function should be used to rewrite which kind. We have separate tables for `EQUAL` because every theory has its own equality rewriter.
Diffstat (limited to 'src/theory/rewriter.cpp')
-rw-r--r--src/theory/rewriter.cpp64
1 files changed, 58 insertions, 6 deletions
diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp
index 7a99ed2d9..da06b053c 100644
--- a/src/theory/rewriter.cpp
+++ b/src/theory/rewriter.cpp
@@ -81,6 +81,11 @@ struct RewriteStackElement {
NodeBuilder<> builder;
};
+RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n)
+{
+ return RewriteResponse(REWRITE_DONE, n);
+}
+
Node Rewriter::rewrite(TNode node) {
if (node.getNumChildren() == 0)
{
@@ -88,8 +93,33 @@ Node Rewriter::rewrite(TNode node) {
// eagerly for the sake of efficiency here.
return node;
}
- Rewriter& rewriter = getInstance();
- return rewriter.rewriteTo(theoryOf(node), node);
+ return getInstance().rewriteTo(theoryOf(node), node);
+}
+
+void Rewriter::registerPreRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_preRewriters[k] = fn;
+}
+
+void Rewriter::registerPostRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_postRewriters[k] = fn;
+}
+
+void Rewriter::registerPreRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_preRewritersEqual[tid] = fn;
+}
+
+void Rewriter::registerPostRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_postRewritersEqual[tid] = fn;
}
Rewriter& Rewriter::getInstance()
@@ -153,8 +183,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
for(;;) {
// Perform the pre-rewrite
RewriteResponse response =
- d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite(
- rewriteStackTop.node);
+ preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.node);
// Put the rewritten node to the top of the stack
rewriteStackTop.node = response.node;
TheoryId newTheory = theoryOf(rewriteStackTop.node);
@@ -222,8 +251,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
for(;;) {
// Do the post-rewrite
RewriteResponse response =
- d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite(
- rewriteStackTop.node);
+ postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.node);
// We continue with the response we got
TheoryId newTheoryId = theoryOf(response.node);
if (newTheoryId != rewriteStackTop.getTheoryId()
@@ -286,6 +314,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
Unreachable();
}/* Rewriter::rewriteTo() */
+RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n)
+{
+ Kind k = n.getKind();
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+ (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k];
+ if (fn == nullptr)
+ {
+ return d_theoryRewriters[theoryId]->preRewrite(n);
+ }
+ return fn(&d_re, n);
+}
+
+RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n)
+{
+ Kind k = n.getKind();
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+ (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k];
+ if (fn == nullptr)
+ {
+ return d_theoryRewriters[theoryId]->postRewrite(n);
+ }
+ return fn(&d_re, n);
+}
+
void Rewriter::clearCaches() {
Rewriter& rewriter = getInstance();
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback