summaryrefslogtreecommitdiff
path: root/src/theory
diff options
context:
space:
mode:
authorAina Niemetz <aina.niemetz@gmail.com>2020-03-11 11:02:45 -0700
committerGitHub <noreply@github.com>2020-03-11 11:02:45 -0700
commitc34af1470198cb99e067615f98eceedf921fe1b2 (patch)
tree0076981a2f8caba321017c5a7a02e3ed81ffe72e /src/theory
parent80254c2bbfb679f419e8b50a2aa1a1cd51cbd295 (diff)
parent8a56e62da0a8940f0ae1ee9575398e5f21660097 (diff)
Merge branch 'master' into issue4028issue4028
Diffstat (limited to 'src/theory')
-rw-r--r--src/theory/quantifiers/conjecture_generator.cpp10
-rw-r--r--src/theory/quantifiers/local_theory_ext.cpp270
-rw-r--r--src/theory/quantifiers/local_theory_ext.h93
-rw-r--r--src/theory/quantifiers/quantifiers_rewriter.cpp1
-rw-r--r--src/theory/quantifiers/single_inv_partition.cpp2
-rw-r--r--src/theory/quantifiers/single_inv_partition.h2
-rw-r--r--src/theory/quantifiers_engine.cpp9
-rw-r--r--src/theory/rewriter.cpp66
-rw-r--r--src/theory/rewriter.h80
-rw-r--r--src/theory/rewriter_tables_template.h13
-rw-r--r--src/theory/theory_rewriter.h9
11 files changed, 168 insertions, 387 deletions
diff --git a/src/theory/quantifiers/conjecture_generator.cpp b/src/theory/quantifiers/conjecture_generator.cpp
index b82b958af..bccb33f1d 100644
--- a/src/theory/quantifiers/conjecture_generator.cpp
+++ b/src/theory/quantifiers/conjecture_generator.cpp
@@ -66,12 +66,14 @@ Node OpArgIndex::getGroundTerm( ConjectureGenerator * s, std::vector< TNode >& a
}
}
return Node::null();
- }else{
- std::vector< TNode > args2;
+ }
+ std::vector<TNode> args2;
+ if (d_op_terms[0].getMetaKind() == kind::metakind::PARAMETERIZED)
+ {
args2.push_back( d_ops[0] );
- args2.insert( args2.end(), args.begin(), args.end() );
- return NodeManager::currentNM()->mkNode( d_op_terms[0].getKind(), args2 );
}
+ args2.insert(args2.end(), args.begin(), args.end());
+ return NodeManager::currentNM()->mkNode(d_op_terms[0].getKind(), args2);
}
void OpArgIndex::getGroundTerms( ConjectureGenerator * s, std::vector< TNode >& terms ) {
diff --git a/src/theory/quantifiers/local_theory_ext.cpp b/src/theory/quantifiers/local_theory_ext.cpp
deleted file mode 100644
index a3de5ced9..000000000
--- a/src/theory/quantifiers/local_theory_ext.cpp
+++ /dev/null
@@ -1,270 +0,0 @@
-/********************* */
-/*! \file local_theory_ext.cpp
- ** \verbatim
- ** Top contributors (to current version):
- ** Andrew Reynolds, Morgan Deters, Paul Meng
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved. See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Implementation of local theory ext utilities
- **/
-
-#include "theory/quantifiers/local_theory_ext.h"
-#include "theory/quantifiers/term_database.h"
-#include "theory/quantifiers/term_util.h"
-#include "theory/quantifiers_engine.h"
-#include "theory/quantifiers/first_order_model.h"
-
-using namespace std;
-using namespace CVC4;
-using namespace CVC4::kind;
-using namespace CVC4::context;
-using namespace CVC4::theory;
-using namespace CVC4::theory::quantifiers;
-
-
-LtePartialInst::LtePartialInst( QuantifiersEngine * qe, context::Context* c ) :
-QuantifiersModule( qe ), d_wasInvoked( false ), d_needsCheck( false ){
-
-}
-
-/** add quantifier */
-void LtePartialInst::checkOwnership(Node q)
-{
- if( !q.getAttribute(LtePartialInstAttribute()) ){
- if( d_do_inst.find( q )!=d_do_inst.end() ){
- if( d_do_inst[q] ){
- d_lte_asserts.push_back( q );
- d_quantEngine->setOwner( q, this );
- }
- }else{
- d_vars[q].clear();
- d_pat_var_order[q].clear();
- //check if this quantified formula is eligible for partial instantiation
- std::map< Node, bool > vars;
- for( unsigned i=0; i<q[0].getNumChildren(); i++ ){
- vars[q[0][i]] = false;
- }
- getEligibleInstVars( q[1], vars );
-
- //instantiate only if we would force ground instances
- std::map< Node, int > var_order;
- bool doInst = true;
- for( unsigned i=0; i<q[0].getNumChildren(); i++ ){
- if( vars[q[0][i]] ){
- d_vars[q].push_back( q[0][i] );
- var_order[q[0][i]] = i;
- }else{
- Trace("lte-partial-inst-debug") << "...do not consider, variable " << q[0][i] << " was not found in correct position in body." << std::endl;
- doInst = false;
- break;
- }
- }
- if( doInst ){
- //also needs patterns
- if( q.getNumChildren()==3 && q[2].getNumChildren()==1 ){
- for( unsigned i=0; i<q[2][0].getNumChildren(); i++ ){
- Node pat = q[2][0][i];
- if( pat.getKind()==APPLY_UF ){
- for( unsigned j=0; j<pat.getNumChildren(); j++ ){
- if( !addVariableToPatternList( pat[j], d_pat_var_order[q], var_order ) ){
- doInst = false;
- }
- }
- }else if( !addVariableToPatternList( pat, d_pat_var_order[q], var_order ) ){
- doInst = false;
- }
- if( !doInst ){
- Trace("lte-partial-inst-debug") << "...do not consider, cannot resolve pattern : " << pat << std::endl;
- break;
- }
- }
- }else{
- Trace("lte-partial-inst-debug") << "...do not consider (must have exactly one pattern)." << std::endl;
- }
- }
-
-
- Trace("lte-partial-inst") << "LTE: ...will " << ( doInst ? "" : "not ") << "instantiate " << q << std::endl;
- d_do_inst[q] = doInst;
- if( doInst ){
- d_lte_asserts.push_back( q );
- d_needsCheck = true;
- d_quantEngine->setOwner( q, this );
- }
- }
- }
-}
-
-bool LtePartialInst::addVariableToPatternList( Node v, std::vector< int >& pat_var_order, std::map< Node, int >& var_order ) {
- std::map< Node, int >::iterator it = var_order.find( v );
- if( it==var_order.end() ){
- return false;
- }else if( std::find( pat_var_order.begin(), pat_var_order.end(), it->second )!=pat_var_order.end() ){
- return false;
- }else{
- pat_var_order.push_back( it->second );
- return true;
- }
-}
-
-void LtePartialInst::getEligibleInstVars( Node n, std::map< Node, bool >& vars ) {
- if( n.getKind()==APPLY_UF && !n.getType().isBoolean() ){
- for( unsigned i=0; i<n.getNumChildren(); i++ ){
- if( vars.find( n[i] )!=vars.end() ){
- vars[n[i]] = true;
- }
- }
- }
- for( unsigned i=0; i<n.getNumChildren(); i++ ){
- getEligibleInstVars( n[i], vars );
- }
-}
-
-/* whether this module needs to check this round */
-bool LtePartialInst::needsCheck( Theory::Effort e ) {
- return e>=Theory::EFFORT_FULL && d_needsCheck;
-}
-/* Call during quantifier engine's check */
-void LtePartialInst::check(Theory::Effort e, QEffort quant_e)
-{
- //flush lemmas ASAP (they are a reduction)
- if (quant_e == QEFFORT_CONFLICT && d_needsCheck)
- {
- std::vector< Node > lemmas;
- getInstantiations( lemmas );
- //add lemmas to quantifiers engine
- for( unsigned i=0; i<lemmas.size(); i++ ){
- d_quantEngine->addLemma( lemmas[i], false );
- }
- d_needsCheck = false;
- }
-}
-
-
-void LtePartialInst::reset() {
- d_reps.clear();
- eq::EqualityEngine* ee = d_quantEngine->getActiveEqualityEngine();
- eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( ee );
- while( !eqcs_i.isFinished() ){
- TNode r = (*eqcs_i);
- TypeNode tn = r.getType();
- d_reps[tn].push_back( r );
- ++eqcs_i;
- }
-}
-
-
-/** get instantiations */
-void LtePartialInst::getInstantiations( std::vector< Node >& lemmas ) {
- Trace("lte-partial-inst") << "LTE : get instantiations, # quant = " << d_lte_asserts.size() << std::endl;
- reset();
- for( unsigned i=0; i<d_lte_asserts.size(); i++ ){
- Node q = d_lte_asserts[i];
- Assert(d_do_inst.find(q) != d_do_inst.end() && d_do_inst[q]);
- if( d_inst.find( q )==d_inst.end() ){
- Trace("lte-partial-inst") << "LTE : Get partial instantiations for " << q << "..." << std::endl;
- d_inst[q] = true;
- Assert(!d_vars[q].empty());
- //make bound list
- Node bvl;
- std::vector< Node > bvs;
- for( unsigned j=0; j<q[0].getNumChildren(); j++ ){
- if( std::find( d_vars[q].begin(), d_vars[q].end(), q[0][j] )==d_vars[q].end() ){
- bvs.push_back( q[0][j] );
- }
- }
- if( !bvs.empty() ){
- bvl = NodeManager::currentNM()->mkNode( BOUND_VAR_LIST, bvs );
- }
- std::vector< Node > conj;
- std::vector< Node > terms;
- std::vector< TypeNode > types;
- for( unsigned j=0; j<d_vars[q].size(); j++ ){
- types.push_back( d_vars[q][j].getType() );
- terms.push_back( Node::null() );
- }
-
- getPartialInstantiations( conj, q, bvl, d_vars[q], terms, types, NULL, 0, 0, 0 );
- Assert(!conj.empty());
- lemmas.push_back( NodeManager::currentNM()->mkNode( OR, q.negate(), conj.size()==1 ? conj[0] : NodeManager::currentNM()->mkNode( AND, conj ) ) );
- d_wasInvoked = true;
- }
- }
-}
-
-void LtePartialInst::getPartialInstantiations(std::vector<Node>& conj,
- Node q,
- Node bvl,
- std::vector<Node>& vars,
- std::vector<Node>& terms,
- std::vector<TypeNode>& types,
- TNodeTrie* curr,
- unsigned pindex,
- unsigned paindex,
- unsigned iindex)
-{
- if( iindex==vars.size() ){
- Node body = q[1].substitute( vars.begin(), vars.end(), terms.begin(), terms.end() );
- if( bvl.isNull() ){
- conj.push_back( body );
- Trace("lte-partial-inst") << " - ground conjunct : " << body << std::endl;
- }else{
- Node nq;
- if( q.getNumChildren()==3 ){
- Node ipl = q[2].substitute( vars.begin(), vars.end(), terms.begin(), terms.end() );
- nq = NodeManager::currentNM()->mkNode( FORALL, bvl, body, ipl );
- }else{
- nq = NodeManager::currentNM()->mkNode( FORALL, bvl, body );
- }
- Trace("lte-partial-inst") << " - quantified conjunct : " << nq << std::endl;
- LtePartialInstAttribute ltpia;
- nq.setAttribute(ltpia,true);
- conj.push_back( nq );
- }
- }else{
- Assert(pindex < q[2][0].getNumChildren());
- Node pat = q[2][0][pindex];
- Assert(pat.getNumChildren() == 0 || paindex <= pat.getNumChildren());
- if( pat.getKind()==APPLY_UF ){
- Assert(paindex <= pat.getNumChildren());
- if( paindex==pat.getNumChildren() ){
- getPartialInstantiations( conj, q, bvl, vars, terms, types, NULL, pindex+1, 0, iindex );
- }else{
- if( !curr ){
- Assert(paindex == 0);
- //start traversing term index for the operator
- curr = d_quantEngine->getTermDatabase()->getTermArgTrie( pat.getOperator() );
- }
- for (std::pair<const TNode, TNodeTrie>& t : curr->d_data)
- {
- terms[d_pat_var_order[q][iindex]] = t.first;
- getPartialInstantiations(conj,
- q,
- bvl,
- vars,
- terms,
- types,
- &t.second,
- pindex,
- paindex + 1,
- iindex + 1);
- }
- }
- }else{
- std::map< TypeNode, std::vector< Node > >::iterator it = d_reps.find( types[iindex] );
- if( it!=d_reps.end() ){
- Trace("lte-partial-inst-debug") << it->second.size() << " reps of type " << types[iindex] << std::endl;
- for( unsigned i=0; i<it->second.size(); i++ ){
- terms[d_pat_var_order[q][iindex]] = it->second[i];
- getPartialInstantiations( conj, q, bvl, vars, terms, types, NULL, pindex+1, 0, iindex+1 );
- }
- }else{
- Trace("lte-partial-inst-debug") << "No reps found of type " << types[iindex] << std::endl;
- }
- }
- }
-}
diff --git a/src/theory/quantifiers/local_theory_ext.h b/src/theory/quantifiers/local_theory_ext.h
deleted file mode 100644
index d39ea3cfe..000000000
--- a/src/theory/quantifiers/local_theory_ext.h
+++ /dev/null
@@ -1,93 +0,0 @@
-/********************* */
-/*! \file local_theory_ext.h
- ** \verbatim
- ** Top contributors (to current version):
- ** Andrew Reynolds, Mathias Preiner
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved. See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief local theory extensions util
- **/
-
-#include "cvc4_private.h"
-
-#ifndef CVC4__THEORY__LOCAL_THEORY_EXT_H
-#define CVC4__THEORY__LOCAL_THEORY_EXT_H
-
-#include "context/cdo.h"
-#include "expr/attribute.h"
-#include "expr/node_trie.h"
-#include "theory/quantifiers/quant_util.h"
-
-namespace CVC4 {
-namespace theory {
-
-/** Attribute true for quantifiers that do not need to be partially instantiated
- */
-struct LtePartialInstAttributeId
-{
-};
-typedef expr::Attribute<LtePartialInstAttributeId, bool>
- LtePartialInstAttribute;
-
-namespace quantifiers {
-
-class LtePartialInst : public QuantifiersModule {
-private:
- // was this module invoked
- bool d_wasInvoked;
- // needs check
- bool d_needsCheck;
- //representatives per type
- std::map< TypeNode, std::vector< Node > > d_reps;
- // should we instantiate quantifier
- std::map< Node, bool > d_do_inst;
- // have we instantiated quantifier
- std::map< Node, bool > d_inst;
- std::map< Node, std::vector< Node > > d_vars;
- std::map< Node, std::vector< int > > d_pat_var_order;
- /** list of relevant quantifiers asserted in the current context */
- std::vector< Node > d_lte_asserts;
- /** reset */
- void reset();
- /** get instantiations */
- void getInstantiations( std::vector< Node >& lemmas );
- void getPartialInstantiations(std::vector<Node>& conj,
- Node q,
- Node bvl,
- std::vector<Node>& vars,
- std::vector<Node>& inst,
- std::vector<TypeNode>& types,
- TNodeTrie* curr,
- unsigned pindex,
- unsigned paindex,
- unsigned iindex);
- /** get eligible inst variables */
- void getEligibleInstVars( Node n, std::map< Node, bool >& vars );
-
- bool addVariableToPatternList( Node v, std::vector< int >& pat_var_order, std::map< Node, int >& var_order );
-public:
- LtePartialInst( QuantifiersEngine * qe, context::Context* c );
- /** determine whether this quantified formula will be reduced */
- void checkOwnership(Node q) override;
- /** was invoked */
- bool wasInvoked() { return d_wasInvoked; }
-
- /* whether this module needs to check this round */
- bool needsCheck(Theory::Effort e) override;
- /* Call during quantifier engine's check */
- void check(Theory::Effort e, QEffort quant_e) override;
- /* check complete */
- bool checkComplete() override { return !d_wasInvoked; }
- /** Identify this module (for debugging, dynamic configuration, etc..) */
- std::string identify() const override { return "LtePartialInst"; }
-};
-
-}
-}
-}
-
-#endif
diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp
index 6a54e8393..187c765d1 100644
--- a/src/theory/quantifiers/quantifiers_rewriter.cpp
+++ b/src/theory/quantifiers/quantifiers_rewriter.cpp
@@ -1792,7 +1792,6 @@ Node QuantifiersRewriter::computeAggressiveMiniscoping( std::vector< Node >& arg
}
}
Assert(!qvl1.empty());
- Assert(!qvl2.empty() || !qvsh.empty());
//check for literals that only contain shared variables
std::vector<Node> qlitsh;
std::vector<Node> qlit2;
diff --git a/src/theory/quantifiers/single_inv_partition.cpp b/src/theory/quantifiers/single_inv_partition.cpp
index a0e25b756..50831fdac 100644
--- a/src/theory/quantifiers/single_inv_partition.cpp
+++ b/src/theory/quantifiers/single_inv_partition.cpp
@@ -346,7 +346,7 @@ bool SingleInvocationPartition::init(std::vector<Node>& funcs,
d_conjuncts[2].push_back(cr);
std::unordered_set<Node, NodeHashFunction> fvs;
expr::getFreeVariables(cr, fvs);
- d_all_vars.insert(d_all_vars.end(), fvs.begin(), fvs.end());
+ d_all_vars.insert(fvs.begin(), fvs.end());
if (singleInvocation)
{
// replace with single invocation formulation
diff --git a/src/theory/quantifiers/single_inv_partition.h b/src/theory/quantifiers/single_inv_partition.h
index 0a4af3185..cdc56d1f0 100644
--- a/src/theory/quantifiers/single_inv_partition.h
+++ b/src/theory/quantifiers/single_inv_partition.h
@@ -201,7 +201,7 @@ class SingleInvocationPartition
std::vector<Node> d_si_vars;
/** every free variable of conjuncts[2] */
- std::vector<Node> d_all_vars;
+ std::unordered_set<Node, NodeHashFunction> d_all_vars;
/** map from functions to first-order variables that anti-skolemized them */
std::map<Node, Node> d_func_fo_var;
/** map from first-order variables to the function it anti-skolemized */
diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp
index ed4a79808..4339ee75f 100644
--- a/src/theory/quantifiers_engine.cpp
+++ b/src/theory/quantifiers_engine.cpp
@@ -25,7 +25,6 @@
#include "theory/quantifiers/fmf/full_model_check.h"
#include "theory/quantifiers/fmf/model_engine.h"
#include "theory/quantifiers/inst_strategy_enumerative.h"
-#include "theory/quantifiers/local_theory_ext.h"
#include "theory/quantifiers/quant_conflict_find.h"
#include "theory/quantifiers/quant_split.h"
#include "theory/quantifiers/quantifiers_rewriter.h"
@@ -52,7 +51,6 @@ class QuantifiersEnginePrivate
d_qcf(nullptr),
d_sg_gen(nullptr),
d_synth_e(nullptr),
- d_lte_part_inst(nullptr),
d_fs(nullptr),
d_i_cbqi(nullptr),
d_qsplit(nullptr),
@@ -79,8 +77,6 @@ class QuantifiersEnginePrivate
std::unique_ptr<quantifiers::ConjectureGenerator> d_sg_gen;
/** ceg instantiation */
std::unique_ptr<quantifiers::SynthEngine> d_synth_e;
- /** lte partial instantiation */
- std::unique_ptr<quantifiers::LtePartialInst> d_lte_part_inst;
/** full saturation */
std::unique_ptr<quantifiers::InstStrategyEnum> d_fs;
/** counterexample-based quantifier instantiation */
@@ -142,11 +138,6 @@ class QuantifiersEnginePrivate
// finite model finder has special ways of building the model
needsBuilder = true;
}
- if (options::ltePartialInst())
- {
- d_lte_part_inst.reset(new quantifiers::LtePartialInst(qe, c));
- modules.push_back(d_lte_part_inst.get());
- }
if (options::quantDynamicSplit() != options::QuantDSplitMode::NONE)
{
d_qsplit.reset(new quantifiers::QuantDSplit(qe, c));
diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp
index 765c2b4c8..b3f1e23d7 100644
--- a/src/theory/rewriter.cpp
+++ b/src/theory/rewriter.cpp
@@ -81,6 +81,11 @@ struct RewriteStackElement {
NodeBuilder<> d_builder;
};
+RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n)
+{
+ return RewriteResponse(REWRITE_DONE, n);
+}
+
Node Rewriter::rewrite(TNode node) {
if (node.getNumChildren() == 0)
{
@@ -88,8 +93,35 @@ Node Rewriter::rewrite(TNode node) {
// eagerly for the sake of efficiency here.
return node;
}
- Rewriter& rewriter = getInstance();
- return rewriter.rewriteTo(theoryOf(node), node);
+ return getInstance().rewriteTo(theoryOf(node), node);
+}
+
+void Rewriter::registerPreRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ Assert(k != kind::EQUAL) << "Register pre-rewrites for EQUAL with registerPreRewriteEqual.";
+ d_preRewriters[k] = fn;
+}
+
+void Rewriter::registerPostRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ Assert(k != kind::EQUAL) << "Register post-rewrites for EQUAL with registerPostRewriteEqual.";
+ d_postRewriters[k] = fn;
+}
+
+void Rewriter::registerPreRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_preRewritersEqual[tid] = fn;
+}
+
+void Rewriter::registerPostRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+ d_postRewritersEqual[tid] = fn;
}
Rewriter& Rewriter::getInstance()
@@ -153,8 +185,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
for(;;) {
// Perform the pre-rewrite
RewriteResponse response =
- d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite(
- rewriteStackTop.d_node);
+ preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node);
// Put the rewritten node to the top of the stack
rewriteStackTop.d_node = response.d_node;
TheoryId newTheory = theoryOf(rewriteStackTop.d_node);
@@ -225,8 +256,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
for(;;) {
// Do the post-rewrite
RewriteResponse response =
- d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite(
- rewriteStackTop.d_node);
+ postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node);
// We continue with the response we got
TheoryId newTheoryId = theoryOf(response.d_node);
if (newTheoryId != rewriteStackTop.getTheoryId()
@@ -290,6 +320,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
Unreachable();
}/* Rewriter::rewriteTo() */
+RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n)
+{
+ Kind k = n.getKind();
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+ (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k];
+ if (fn == nullptr)
+ {
+ return d_theoryRewriters[theoryId]->preRewrite(n);
+ }
+ return fn(&d_re, n);
+}
+
+RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n)
+{
+ Kind k = n.getKind();
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+ (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k];
+ if (fn == nullptr)
+ {
+ return d_theoryRewriters[theoryId]->postRewrite(n);
+ }
+ return fn(&d_re, n);
+}
+
void Rewriter::clearCaches() {
Rewriter& rewriter = getInstance();
diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h
index e55ca5d1c..f7298e1fb 100644
--- a/src/theory/rewriter.h
+++ b/src/theory/rewriter.h
@@ -28,6 +28,23 @@ namespace theory {
class RewriterInitializer;
/**
+ * The rewrite environment holds everything that the individual rewrites have
+ * access to.
+ */
+class RewriteEnvironment
+{
+};
+
+/**
+ * The identity rewrite just returns the original node.
+ *
+ * @param re The rewrite environment
+ * @param n The node to rewrite
+ * @return The original node
+ */
+RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n);
+
+/**
* The main rewriter class.
*/
class Rewriter {
@@ -45,6 +62,44 @@ class Rewriter {
*/
static void clearCaches();
+ /**
+ * Register a prerewrite for a given kind.
+ *
+ * @param k The kind to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPreRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+ /**
+ * Register a postrewrite for a given kind.
+ *
+ * @param k The kind to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPostRewrite(
+ Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+ /**
+ * Register a prerewrite for equalities belonging to a given theory.
+ *
+ * @param tid The theory to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPreRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+ /**
+ * Register a postrewrite for equalities belonging to a given theory.
+ *
+ * @param tid The theory to register a rewrite for.
+ * @param fn The function that performs the rewrite.
+ */
+ void registerPostRewriteEqual(
+ theory::TheoryId tid,
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
private:
/**
* Get the (singleton) instance of the rewriter.
@@ -71,10 +126,10 @@ class Rewriter {
Node rewriteTo(theory::TheoryId theoryId, Node node);
/** Calls the pre-rewriter for the given theory */
- RewriteResponse callPreRewrite(theory::TheoryId theoryId, TNode node);
+ RewriteResponse preRewrite(theory::TheoryId theoryId, TNode n);
/** Calls the post-rewriter for the given theory */
- RewriteResponse callPostRewrite(theory::TheoryId theoryId, TNode node);
+ RewriteResponse postRewrite(theory::TheoryId theoryId, TNode n);
/**
* Calls the equality-rewriter for the given theory.
@@ -88,6 +143,27 @@ class Rewriter {
unsigned long d_iterationCount = 0;
+ /** Rewriter table for prewrites. Maps kinds to rewriter function. */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_preRewriters[kind::LAST_KIND];
+ /** Rewriter table for postrewrites. Maps kinds to rewriter function. */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_postRewriters[kind::LAST_KIND];
+ /**
+ * Rewriter table for prerewrites of equalities. Maps theory to rewriter
+ * function.
+ */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_preRewritersEqual[theory::THEORY_LAST];
+ /**
+ * Rewriter table for postrewrites of equalities. Maps theory to rewriter
+ * function.
+ */
+ std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+ d_postRewritersEqual[theory::THEORY_LAST];
+
+ RewriteEnvironment d_re;
+
#ifdef CVC4_ASSERTIONS
std::unique_ptr<std::unordered_set<Node, NodeHashFunction>> d_rewriteStack =
nullptr;
diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h
index e1be6355b..1bb03e253 100644
--- a/src/theory/rewriter_tables_template.h
+++ b/src/theory/rewriter_tables_template.h
@@ -64,6 +64,19 @@ ${post_rewrite_set_cache}
Rewriter::Rewriter()
{
${rewrite_init}
+
+for (size_t i = 0; i < kind::LAST_KIND; ++i)
+{
+ d_preRewriters[i] = nullptr;
+ d_postRewriters[i] = nullptr;
+}
+
+for (size_t i = 0; i < theory::THEORY_LAST; ++i)
+{
+ d_preRewritersEqual[i] = nullptr;
+ d_postRewritersEqual[i] = nullptr;
+ d_theoryRewriters[i]->registerRewrites(this);
+}
}
void Rewriter::clearCachesInternal() {
diff --git a/src/theory/theory_rewriter.h b/src/theory/theory_rewriter.h
index e7dc782bb..311ab9020 100644
--- a/src/theory/theory_rewriter.h
+++ b/src/theory/theory_rewriter.h
@@ -24,6 +24,8 @@
namespace CVC4 {
namespace theory {
+class Rewriter;
+
/**
* Theory rewriters signal whether more rewriting is needed (or not)
* by using a member of this enumeration. See RewriteResponse, below.
@@ -64,6 +66,13 @@ class TheoryRewriter
virtual ~TheoryRewriter() = default;
/**
+ * Registers the rewrites of a given theory with the rewriter.
+ *
+ * @param rewriter The rewriter to register the rewrites with.
+ */
+ virtual void registerRewrites(Rewriter* rewriter) {}
+
+ /**
* Performs a pre-rewrite step.
*
* @param node The node to rewrite
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback