summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAndres Noetzli <andres.noetzli@gmail.com>2020-02-10 20:16:12 -0800
committerAndres Noetzli <andres.noetzli@gmail.com>2020-02-10 20:16:12 -0800
commitf6fa60bfb45f8ade2a816f681911a01673968c1a (patch)
treefd73f4413dd1f08208f64f83aed17915246cd4bd /src
parentaa18f9e6a3ef18071af3636871dc62c8ec0227b2 (diff)
Introduce tables in the rewriter
This commit adds tables in the rewriter that store which function should be used to rewrite which kind. We have separate tables for `EQUAL` because every theory has its own equality rewriter.
Diffstat (limited to 'src')
-rw-r--r--src/theory/rewriter.cpp64
-rw-r--r--src/theory/rewriter.h80
-rw-r--r--src/theory/rewriter_tables_template.h13
-rw-r--r--src/theory/theory_rewriter.h9
4 files changed, 158 insertions, 8 deletions
diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp
index 7a99ed2d9..da06b053c 100644
--- a/src/theory/rewriter.cpp
+++ b/src/theory/rewriter.cpp
@@ -81,6 +81,11 @@ struct RewriteStackElement {
NodeBuilder<> builder;
};
+RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n)
+{
+ return RewriteResponse(REWRITE_DONE, n);
+}
+
Node Rewriter::rewrite(TNode node) {
if (node.getNumChildren() == 0)
{
@@ -88,8 +93,33 @@ Node Rewriter::rewrite(TNode node) {
// eagerly for the sake of efficiency here.
return node;
}
- Rewriter& rewriter = getInstance();
- return rewriter.rewriteTo(theoryOf(node), node);
+ return getInstance().rewriteTo(theoryOf(node), node);
+}
+
+void Rewriter::registerPreRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_preRewriters[k] = fn;
+}
+
+void Rewriter::registerPostRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_postRewriters[k] = fn;
+}
+
+void Rewriter::registerPreRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_preRewritersEqual[tid] = fn;
+}
+
+void Rewriter::registerPostRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_postRewritersEqual[tid] = fn;
}
Rewriter& Rewriter::getInstance()
@@ -153,8 +183,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
for(;;) {
// Perform the pre-rewrite
RewriteResponse response =
- d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite(
- rewriteStackTop.node);
+ preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.node);
// Put the rewritten node to the top of the stack
rewriteStackTop.node = response.node;
TheoryId newTheory = theoryOf(rewriteStackTop.node);
@@ -222,8 +251,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
for(;;) {
// Do the post-rewrite
RewriteResponse response =
- d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite(
- rewriteStackTop.node);
+ postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.node);
// We continue with the response we got
TheoryId newTheoryId = theoryOf(response.node);
if (newTheoryId != rewriteStackTop.getTheoryId()
@@ -286,6 +314,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
Unreachable();
}/* Rewriter::rewriteTo() */
+RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n)
+{
+ Kind k = n.getKind();
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+ (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k];
+ if (fn == nullptr)
+ {
+ return d_theoryRewriters[theoryId]->preRewrite(n);
+ }
+ return fn(&d_re, n);
+}
+
+RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n)
+{
+ Kind k = n.getKind();
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+ (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k];
+ if (fn == nullptr)
+ {
+ return d_theoryRewriters[theoryId]->postRewrite(n);
+ }
+ return fn(&d_re, n);
+}
+
void Rewriter::clearCaches() {
Rewriter& rewriter = getInstance();
diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h
index e55ca5d1c..f7298e1fb 100644
--- a/src/theory/rewriter.h
+++ b/src/theory/rewriter.h
@@ -28,6 +28,23 @@ namespace theory {
class RewriterInitializer;
/**
+ * The rewrite environment holds everything that the individual rewrites have
+ * access to.
+ */
+class RewriteEnvironment
+{
+};
+
+/**
+ * The identity rewrite just returns the original node.
+ *
+ * @param re The rewrite environment
+ * @param n The node to rewrite
+ * @return The original node
+ */
+RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n);
+
+/**
* The main rewriter class.
*/
class Rewriter {
@@ -45,6 +62,44 @@ class Rewriter {
*/
static void clearCaches();
+ /**
+ * Register a prerewrite for a given kind.
+ *
+ * @param k The kind to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPreRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+ /**
+ * Register a postrewrite for a given kind.
+ *
+ * @param k The kind to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPostRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+ /**
+ * Register a prerewrite for equalities belonging to a given theory.
+ *
+ * @param tid The theory to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPreRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+ /**
+ * Register a postrewrite for equalities belonging to a given theory.
+ *
+ * @param tid The theory to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPostRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
private:
/**
* Get the (singleton) instance of the rewriter.
@@ -71,10 +126,10 @@ class Rewriter {
Node rewriteTo(theory::TheoryId theoryId, Node node);
/** Calls the pre-rewriter for the given theory */
- RewriteResponse callPreRewrite(theory::TheoryId theoryId, TNode node);
+ RewriteResponse preRewrite(theory::TheoryId theoryId, TNode n);
/** Calls the post-rewriter for the given theory */
- RewriteResponse callPostRewrite(theory::TheoryId theoryId, TNode node);
+ RewriteResponse postRewrite(theory::TheoryId theoryId, TNode n);
/**
* Calls the equality-rewriter for the given theory.
@@ -88,6 +143,27 @@ class Rewriter {
unsigned long d_iterationCount = 0;
+ /** Rewriter table for prewrites. Maps kinds to rewriter function. */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_preRewriters[kind::LAST_KIND];
+ /** Rewriter table for postrewrites. Maps kinds to rewriter function. */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_postRewriters[kind::LAST_KIND];
+ /**
+ * Rewriter table for prerewrites of equalities. Maps theory to rewriter
+ * function.
+ */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_preRewritersEqual[theory::THEORY_LAST];
+ /**
+ * Rewriter table for postrewrites of equalities. Maps theory to rewriter
+ * function.
+ */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_postRewritersEqual[theory::THEORY_LAST];
+
+ RewriteEnvironment d_re;
+
#ifdef CVC4_ASSERTIONS
std::unique_ptr<std::unordered_set<Node, NodeHashFunction>> d_rewriteStack =
nullptr;
diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h
index e1be6355b..1bb03e253 100644
--- a/src/theory/rewriter_tables_template.h
+++ b/src/theory/rewriter_tables_template.h
@@ -64,6 +64,19 @@ ${post_rewrite_set_cache}
Rewriter::Rewriter()
{
${rewrite_init}
+
+for (size_t i = 0; i < kind::LAST_KIND; ++i)
+{
+ d_preRewriters[i] = nullptr;
+ d_postRewriters[i] = nullptr;
+}
+
+for (size_t i = 0; i < theory::THEORY_LAST; ++i)
+{
+ d_preRewritersEqual[i] = nullptr;
+ d_postRewritersEqual[i] = nullptr;
+ d_theoryRewriters[i]->registerRewrites(this);
+}
}
void Rewriter::clearCachesInternal() {
diff --git a/src/theory/theory_rewriter.h b/src/theory/theory_rewriter.h
index 61f0fc27a..93e03123b 100644
--- a/src/theory/theory_rewriter.h
+++ b/src/theory/theory_rewriter.h
@@ -24,6 +24,8 @@
namespace CVC4 {
namespace theory {
+class Rewriter;
+
/**
* Theory rewriters signal whether more rewriting is needed (or not)
* by using a member of this enumeration. See RewriteResponse, below.
@@ -66,6 +68,13 @@ class TheoryRewriter
virtual ~TheoryRewriter() = default;
/**
+ * Registers the rewrites of a given theory with the rewriter.
+ *
+ * @param rewriter The rewriter to register the rewrites with.
+ */
+ virtual void registerRewrites(Rewriter* rewriter) {}
+
+ /**
* Performs a pre-rewrite step.
*
* @param node The node to rewrite
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback