summaryrefslogtreecommitdiff
path: root/src/theory/rewriter.cpp
diff options
context:
space:
mode:
authorAndres Noetzli <andres.noetzli@gmail.com>2020-03-11 06:54:50 -0700
committerGitHub <noreply@github.com>2020-03-11 06:54:50 -0700
commit2b355305ef635ddfaad7fe75c29221cb2f744a62 (patch)
tree1667cc362fcf2f770bf7a47b81d887f648c8a857 /src/theory/rewriter.cpp
parentedcc81b08b4d6c67da81b7ba2fcefbab3286f02c (diff)
Introduce tables in the rewriter (#3742)
This commit adds tables in the rewriter that store which function should be used to rewrite which kind. We have separate tables for `EQUAL` because every theory has its own equality rewriter.
Diffstat (limited to 'src/theory/rewriter.cpp')
-rw-r--r--src/theory/rewriter.cpp66
1 files changed, 60 insertions, 6 deletions
diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp
index 765c2b4c8..b3f1e23d7 100644
--- a/src/theory/rewriter.cpp
+++ b/src/theory/rewriter.cpp
@@ -81,6 +81,11 @@ struct RewriteStackElement {
NodeBuilder<> d_builder;
};
+RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n)
+{
+ return RewriteResponse(REWRITE_DONE, n);
+}
+
Node Rewriter::rewrite(TNode node) {
if (node.getNumChildren() == 0)
{
@@ -88,8 +93,35 @@ Node Rewriter::rewrite(TNode node) {
// eagerly for the sake of efficiency here.
return node;
}
- Rewriter& rewriter = getInstance();
- return rewriter.rewriteTo(theoryOf(node), node);
+ return getInstance().rewriteTo(theoryOf(node), node);
+}
+
+void Rewriter::registerPreRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ Assert(k != kind::EQUAL) << "Register pre-rewrites for EQUAL with registerPreRewriteEqual.";
+ d_preRewriters[k] = fn;
+}
+
+void Rewriter::registerPostRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ Assert(k != kind::EQUAL) << "Register post-rewrites for EQUAL with registerPostRewriteEqual.";
+ d_postRewriters[k] = fn;
+}
+
+void Rewriter::registerPreRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_preRewritersEqual[tid] = fn;
+}
+
+void Rewriter::registerPostRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_postRewritersEqual[tid] = fn;
}
Rewriter& Rewriter::getInstance()
@@ -153,8 +185,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
for(;;) {
// Perform the pre-rewrite
RewriteResponse response =
- d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite(
- rewriteStackTop.d_node);
+ preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node);
// Put the rewritten node to the top of the stack
rewriteStackTop.d_node = response.d_node;
TheoryId newTheory = theoryOf(rewriteStackTop.d_node);
@@ -225,8 +256,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
for(;;) {
// Do the post-rewrite
RewriteResponse response =
- d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite(
- rewriteStackTop.d_node);
+ postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node);
// We continue with the response we got
TheoryId newTheoryId = theoryOf(response.d_node);
if (newTheoryId != rewriteStackTop.getTheoryId()
@@ -290,6 +320,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
Unreachable();
}/* Rewriter::rewriteTo() */
+RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n)
+{
+ Kind k = n.getKind();
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+ (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k];
+ if (fn == nullptr)
+ {
+ return d_theoryRewriters[theoryId]->preRewrite(n);
+ }
+ return fn(&d_re, n);
+}
+
+RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n)
+{
+ Kind k = n.getKind();
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+ (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k];
+ if (fn == nullptr)
+ {
+ return d_theoryRewriters[theoryId]->postRewrite(n);
+ }
+ return fn(&d_re, n);
+}
+
void Rewriter::clearCaches() {
Rewriter& rewriter = getInstance();
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback