summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/theory/rewriter.cpp66
-rw-r--r--src/theory/rewriter.h80
-rw-r--r--src/theory/rewriter_tables_template.h13
-rw-r--r--src/theory/theory_rewriter.h9
4 files changed, 160 insertions, 8 deletions
diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp
index 765c2b4c8..b3f1e23d7 100644
--- a/src/theory/rewriter.cpp
+++ b/src/theory/rewriter.cpp
@@ -81,6 +81,11 @@ struct RewriteStackElement {
NodeBuilder<> d_builder;
};
+RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n)
+{
+ return RewriteResponse(REWRITE_DONE, n);
+}
+
Node Rewriter::rewrite(TNode node) {
if (node.getNumChildren() == 0)
{
@@ -88,8 +93,35 @@ Node Rewriter::rewrite(TNode node) {
// eagerly for the sake of efficiency here.
return node;
}
- Rewriter& rewriter = getInstance();
- return rewriter.rewriteTo(theoryOf(node), node);
+ return getInstance().rewriteTo(theoryOf(node), node);
+}
+
+void Rewriter::registerPreRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ Assert(k != kind::EQUAL) << "Register pre-rewrites for EQUAL with registerPreRewriteEqual.";
+ d_preRewriters[k] = fn;
+}
+
+void Rewriter::registerPostRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ Assert(k != kind::EQUAL) << "Register post-rewrites for EQUAL with registerPostRewriteEqual.";
+ d_postRewriters[k] = fn;
+}
+
+void Rewriter::registerPreRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_preRewritersEqual[tid] = fn;
+}
+
+void Rewriter::registerPostRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_postRewritersEqual[tid] = fn;
}
Rewriter& Rewriter::getInstance()
@@ -153,8 +185,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
for(;;) {
// Perform the pre-rewrite
RewriteResponse response =
- d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite(
- rewriteStackTop.d_node);
+ preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node);
// Put the rewritten node to the top of the stack
rewriteStackTop.d_node = response.d_node;
TheoryId newTheory = theoryOf(rewriteStackTop.d_node);
@@ -225,8 +256,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
for(;;) {
// Do the post-rewrite
RewriteResponse response =
- d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite(
- rewriteStackTop.d_node);
+ postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node);
// We continue with the response we got
TheoryId newTheoryId = theoryOf(response.d_node);
if (newTheoryId != rewriteStackTop.getTheoryId()
@@ -290,6 +320,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
Unreachable();
}/* Rewriter::rewriteTo() */
+RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n)
+{
+ Kind k = n.getKind();
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+ (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k];
+ if (fn == nullptr)
+ {
+ return d_theoryRewriters[theoryId]->preRewrite(n);
+ }
+ return fn(&d_re, n);
+}
+
+RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n)
+{
+ Kind k = n.getKind();
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+ (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k];
+ if (fn == nullptr)
+ {
+ return d_theoryRewriters[theoryId]->postRewrite(n);
+ }
+ return fn(&d_re, n);
+}
+
void Rewriter::clearCaches() {
Rewriter& rewriter = getInstance();
diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h
index e55ca5d1c..f7298e1fb 100644
--- a/src/theory/rewriter.h
+++ b/src/theory/rewriter.h
@@ -28,6 +28,23 @@ namespace theory {
class RewriterInitializer;
/**
+ * The rewrite environment holds everything that the individual rewrites have
+ * access to.
+ */
+class RewriteEnvironment
+{
+};
+
+/**
+ * The identity rewrite just returns the original node.
+ *
+ * @param re The rewrite environment
+ * @param n The node to rewrite
+ * @return The original node
+ */
+RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n);
+
+/**
* The main rewriter class.
*/
class Rewriter {
@@ -45,6 +62,44 @@ class Rewriter {
*/
static void clearCaches();
+ /**
+ * Register a prerewrite for a given kind.
+ *
+ * @param k The kind to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPreRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+ /**
+ * Register a postrewrite for a given kind.
+ *
+ * @param k The kind to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPostRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+ /**
+ * Register a prerewrite for equalities belonging to a given theory.
+ *
+ * @param tid The theory to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPreRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+ /**
+ * Register a postrewrite for equalities belonging to a given theory.
+ *
+ * @param tid The theory to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPostRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
private:
/**
* Get the (singleton) instance of the rewriter.
@@ -71,10 +126,10 @@ class Rewriter {
Node rewriteTo(theory::TheoryId theoryId, Node node);
/** Calls the pre-rewriter for the given theory */
- RewriteResponse callPreRewrite(theory::TheoryId theoryId, TNode node);
+ RewriteResponse preRewrite(theory::TheoryId theoryId, TNode n);
/** Calls the post-rewriter for the given theory */
- RewriteResponse callPostRewrite(theory::TheoryId theoryId, TNode node);
+ RewriteResponse postRewrite(theory::TheoryId theoryId, TNode n);
/**
* Calls the equality-rewriter for the given theory.
@@ -88,6 +143,27 @@ class Rewriter {
unsigned long d_iterationCount = 0;
+ /** Rewriter table for prewrites. Maps kinds to rewriter function. */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_preRewriters[kind::LAST_KIND];
+ /** Rewriter table for postrewrites. Maps kinds to rewriter function. */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_postRewriters[kind::LAST_KIND];
+ /**
+ * Rewriter table for prerewrites of equalities. Maps theory to rewriter
+ * function.
+ */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_preRewritersEqual[theory::THEORY_LAST];
+ /**
+ * Rewriter table for postrewrites of equalities. Maps theory to rewriter
+ * function.
+ */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_postRewritersEqual[theory::THEORY_LAST];
+
+ RewriteEnvironment d_re;
+
#ifdef CVC4_ASSERTIONS
std::unique_ptr<std::unordered_set<Node, NodeHashFunction>> d_rewriteStack =
nullptr;
diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h
index e1be6355b..1bb03e253 100644
--- a/src/theory/rewriter_tables_template.h
+++ b/src/theory/rewriter_tables_template.h
@@ -64,6 +64,19 @@ ${post_rewrite_set_cache}
Rewriter::Rewriter()
{
${rewrite_init}
+
+for (size_t i = 0; i < kind::LAST_KIND; ++i)
+{
+ d_preRewriters[i] = nullptr;
+ d_postRewriters[i] = nullptr;
+}
+
+for (size_t i = 0; i < theory::THEORY_LAST; ++i)
+{
+ d_preRewritersEqual[i] = nullptr;
+ d_postRewritersEqual[i] = nullptr;
+ d_theoryRewriters[i]->registerRewrites(this);
+}
}
void Rewriter::clearCachesInternal() {
diff --git a/src/theory/theory_rewriter.h b/src/theory/theory_rewriter.h
index e7dc782bb..311ab9020 100644
--- a/src/theory/theory_rewriter.h
+++ b/src/theory/theory_rewriter.h
@@ -24,6 +24,8 @@
namespace CVC4 {
namespace theory {
+class Rewriter;
+
/**
* Theory rewriters signal whether more rewriting is needed (or not)
* by using a member of this enumeration. See RewriteResponse, below.
@@ -64,6 +66,13 @@ class TheoryRewriter
virtual ~TheoryRewriter() = default;
/**
+ * Registers the rewrites of a given theory with the rewriter.
+ *
+ * @param rewriter The rewriter to register the rewrites with.
+ */
+ virtual void registerRewrites(Rewriter* rewriter) {}
+
+ /**
* Performs a pre-rewrite step.
*
* @param node The node to rewrite
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback