diff options
author | Aina Niemetz <aina.niemetz@gmail.com> | 2020-03-11 11:02:45 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-11 11:02:45 -0700 |
commit | c34af1470198cb99e067615f98eceedf921fe1b2 (patch) | |
tree | 0076981a2f8caba321017c5a7a02e3ed81ffe72e /src/theory | |
parent | 80254c2bbfb679f419e8b50a2aa1a1cd51cbd295 (diff) | |
parent | 8a56e62da0a8940f0ae1ee9575398e5f21660097 (diff) |
Merge branch 'master' into issue4028issue4028
Diffstat (limited to 'src/theory')
-rw-r--r-- | src/theory/quantifiers/conjecture_generator.cpp | 10 | ||||
-rw-r--r-- | src/theory/quantifiers/local_theory_ext.cpp | 270 | ||||
-rw-r--r-- | src/theory/quantifiers/local_theory_ext.h | 93 | ||||
-rw-r--r-- | src/theory/quantifiers/quantifiers_rewriter.cpp | 1 | ||||
-rw-r--r-- | src/theory/quantifiers/single_inv_partition.cpp | 2 | ||||
-rw-r--r-- | src/theory/quantifiers/single_inv_partition.h | 2 | ||||
-rw-r--r-- | src/theory/quantifiers_engine.cpp | 9 | ||||
-rw-r--r-- | src/theory/rewriter.cpp | 66 | ||||
-rw-r--r-- | src/theory/rewriter.h | 80 | ||||
-rw-r--r-- | src/theory/rewriter_tables_template.h | 13 | ||||
-rw-r--r-- | src/theory/theory_rewriter.h | 9 |
11 files changed, 168 insertions, 387 deletions
diff --git a/src/theory/quantifiers/conjecture_generator.cpp b/src/theory/quantifiers/conjecture_generator.cpp index b82b958af..bccb33f1d 100644 --- a/src/theory/quantifiers/conjecture_generator.cpp +++ b/src/theory/quantifiers/conjecture_generator.cpp @@ -66,12 +66,14 @@ Node OpArgIndex::getGroundTerm( ConjectureGenerator * s, std::vector< TNode >& a } } return Node::null(); - }else{ - std::vector< TNode > args2; + } + std::vector<TNode> args2; + if (d_op_terms[0].getMetaKind() == kind::metakind::PARAMETERIZED) + { args2.push_back( d_ops[0] ); - args2.insert( args2.end(), args.begin(), args.end() ); - return NodeManager::currentNM()->mkNode( d_op_terms[0].getKind(), args2 ); } + args2.insert(args2.end(), args.begin(), args.end()); + return NodeManager::currentNM()->mkNode(d_op_terms[0].getKind(), args2); } void OpArgIndex::getGroundTerms( ConjectureGenerator * s, std::vector< TNode >& terms ) { diff --git a/src/theory/quantifiers/local_theory_ext.cpp b/src/theory/quantifiers/local_theory_ext.cpp deleted file mode 100644 index a3de5ced9..000000000 --- a/src/theory/quantifiers/local_theory_ext.cpp +++ /dev/null @@ -1,270 +0,0 @@ -/********************* */ -/*! \file local_theory_ext.cpp - ** \verbatim - ** Top contributors (to current version): - ** Andrew Reynolds, Morgan Deters, Paul Meng - ** This file is part of the CVC4 project. - ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS - ** in the top-level source directory) and their institutional affiliations. - ** All rights reserved. See the file COPYING in the top-level source - ** directory for licensing information.\endverbatim - ** - ** \brief Implementation of local theory ext utilities - **/ - -#include "theory/quantifiers/local_theory_ext.h" -#include "theory/quantifiers/term_database.h" -#include "theory/quantifiers/term_util.h" -#include "theory/quantifiers_engine.h" -#include "theory/quantifiers/first_order_model.h" - -using namespace std; -using namespace CVC4; -using namespace CVC4::kind; -using namespace CVC4::context; -using namespace CVC4::theory; -using namespace CVC4::theory::quantifiers; - - -LtePartialInst::LtePartialInst( QuantifiersEngine * qe, context::Context* c ) : -QuantifiersModule( qe ), d_wasInvoked( false ), d_needsCheck( false ){ - -} - -/** add quantifier */ -void LtePartialInst::checkOwnership(Node q) -{ - if( !q.getAttribute(LtePartialInstAttribute()) ){ - if( d_do_inst.find( q )!=d_do_inst.end() ){ - if( d_do_inst[q] ){ - d_lte_asserts.push_back( q ); - d_quantEngine->setOwner( q, this ); - } - }else{ - d_vars[q].clear(); - d_pat_var_order[q].clear(); - //check if this quantified formula is eligible for partial instantiation - std::map< Node, bool > vars; - for( unsigned i=0; i<q[0].getNumChildren(); i++ ){ - vars[q[0][i]] = false; - } - getEligibleInstVars( q[1], vars ); - - //instantiate only if we would force ground instances - std::map< Node, int > var_order; - bool doInst = true; - for( unsigned i=0; i<q[0].getNumChildren(); i++ ){ - if( vars[q[0][i]] ){ - d_vars[q].push_back( q[0][i] ); - var_order[q[0][i]] = i; - }else{ - Trace("lte-partial-inst-debug") << "...do not consider, variable " << q[0][i] << " was not found in correct position in body." << std::endl; - doInst = false; - break; - } - } - if( doInst ){ - //also needs patterns - if( q.getNumChildren()==3 && q[2].getNumChildren()==1 ){ - for( unsigned i=0; i<q[2][0].getNumChildren(); i++ ){ - Node pat = q[2][0][i]; - if( pat.getKind()==APPLY_UF ){ - for( unsigned j=0; j<pat.getNumChildren(); j++ ){ - if( !addVariableToPatternList( pat[j], d_pat_var_order[q], var_order ) ){ - doInst = false; - } - } - }else if( !addVariableToPatternList( pat, d_pat_var_order[q], var_order ) ){ - doInst = false; - } - if( !doInst ){ - Trace("lte-partial-inst-debug") << "...do not consider, cannot resolve pattern : " << pat << std::endl; - break; - } - } - }else{ - Trace("lte-partial-inst-debug") << "...do not consider (must have exactly one pattern)." << std::endl; - } - } - - - Trace("lte-partial-inst") << "LTE: ...will " << ( doInst ? "" : "not ") << "instantiate " << q << std::endl; - d_do_inst[q] = doInst; - if( doInst ){ - d_lte_asserts.push_back( q ); - d_needsCheck = true; - d_quantEngine->setOwner( q, this ); - } - } - } -} - -bool LtePartialInst::addVariableToPatternList( Node v, std::vector< int >& pat_var_order, std::map< Node, int >& var_order ) { - std::map< Node, int >::iterator it = var_order.find( v ); - if( it==var_order.end() ){ - return false; - }else if( std::find( pat_var_order.begin(), pat_var_order.end(), it->second )!=pat_var_order.end() ){ - return false; - }else{ - pat_var_order.push_back( it->second ); - return true; - } -} - -void LtePartialInst::getEligibleInstVars( Node n, std::map< Node, bool >& vars ) { - if( n.getKind()==APPLY_UF && !n.getType().isBoolean() ){ - for( unsigned i=0; i<n.getNumChildren(); i++ ){ - if( vars.find( n[i] )!=vars.end() ){ - vars[n[i]] = true; - } - } - } - for( unsigned i=0; i<n.getNumChildren(); i++ ){ - getEligibleInstVars( n[i], vars ); - } -} - -/* whether this module needs to check this round */ -bool LtePartialInst::needsCheck( Theory::Effort e ) { - return e>=Theory::EFFORT_FULL && d_needsCheck; -} -/* Call during quantifier engine's check */ -void LtePartialInst::check(Theory::Effort e, QEffort quant_e) -{ - //flush lemmas ASAP (they are a reduction) - if (quant_e == QEFFORT_CONFLICT && d_needsCheck) - { - std::vector< Node > lemmas; - getInstantiations( lemmas ); - //add lemmas to quantifiers engine - for( unsigned i=0; i<lemmas.size(); i++ ){ - d_quantEngine->addLemma( lemmas[i], false ); - } - d_needsCheck = false; - } -} - - -void LtePartialInst::reset() { - d_reps.clear(); - eq::EqualityEngine* ee = d_quantEngine->getActiveEqualityEngine(); - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( ee ); - while( !eqcs_i.isFinished() ){ - TNode r = (*eqcs_i); - TypeNode tn = r.getType(); - d_reps[tn].push_back( r ); - ++eqcs_i; - } -} - - -/** get instantiations */ -void LtePartialInst::getInstantiations( std::vector< Node >& lemmas ) { - Trace("lte-partial-inst") << "LTE : get instantiations, # quant = " << d_lte_asserts.size() << std::endl; - reset(); - for( unsigned i=0; i<d_lte_asserts.size(); i++ ){ - Node q = d_lte_asserts[i]; - Assert(d_do_inst.find(q) != d_do_inst.end() && d_do_inst[q]); - if( d_inst.find( q )==d_inst.end() ){ - Trace("lte-partial-inst") << "LTE : Get partial instantiations for " << q << "..." << std::endl; - d_inst[q] = true; - Assert(!d_vars[q].empty()); - //make bound list - Node bvl; - std::vector< Node > bvs; - for( unsigned j=0; j<q[0].getNumChildren(); j++ ){ - if( std::find( d_vars[q].begin(), d_vars[q].end(), q[0][j] )==d_vars[q].end() ){ - bvs.push_back( q[0][j] ); - } - } - if( !bvs.empty() ){ - bvl = NodeManager::currentNM()->mkNode( BOUND_VAR_LIST, bvs ); - } - std::vector< Node > conj; - std::vector< Node > terms; - std::vector< TypeNode > types; - for( unsigned j=0; j<d_vars[q].size(); j++ ){ - types.push_back( d_vars[q][j].getType() ); - terms.push_back( Node::null() ); - } - - getPartialInstantiations( conj, q, bvl, d_vars[q], terms, types, NULL, 0, 0, 0 ); - Assert(!conj.empty()); - lemmas.push_back( NodeManager::currentNM()->mkNode( OR, q.negate(), conj.size()==1 ? conj[0] : NodeManager::currentNM()->mkNode( AND, conj ) ) ); - d_wasInvoked = true; - } - } -} - -void LtePartialInst::getPartialInstantiations(std::vector<Node>& conj, - Node q, - Node bvl, - std::vector<Node>& vars, - std::vector<Node>& terms, - std::vector<TypeNode>& types, - TNodeTrie* curr, - unsigned pindex, - unsigned paindex, - unsigned iindex) -{ - if( iindex==vars.size() ){ - Node body = q[1].substitute( vars.begin(), vars.end(), terms.begin(), terms.end() ); - if( bvl.isNull() ){ - conj.push_back( body ); - Trace("lte-partial-inst") << " - ground conjunct : " << body << std::endl; - }else{ - Node nq; - if( q.getNumChildren()==3 ){ - Node ipl = q[2].substitute( vars.begin(), vars.end(), terms.begin(), terms.end() ); - nq = NodeManager::currentNM()->mkNode( FORALL, bvl, body, ipl ); - }else{ - nq = NodeManager::currentNM()->mkNode( FORALL, bvl, body ); - } - Trace("lte-partial-inst") << " - quantified conjunct : " << nq << std::endl; - LtePartialInstAttribute ltpia; - nq.setAttribute(ltpia,true); - conj.push_back( nq ); - } - }else{ - Assert(pindex < q[2][0].getNumChildren()); - Node pat = q[2][0][pindex]; - Assert(pat.getNumChildren() == 0 || paindex <= pat.getNumChildren()); - if( pat.getKind()==APPLY_UF ){ - Assert(paindex <= pat.getNumChildren()); - if( paindex==pat.getNumChildren() ){ - getPartialInstantiations( conj, q, bvl, vars, terms, types, NULL, pindex+1, 0, iindex ); - }else{ - if( !curr ){ - Assert(paindex == 0); - //start traversing term index for the operator - curr = d_quantEngine->getTermDatabase()->getTermArgTrie( pat.getOperator() ); - } - for (std::pair<const TNode, TNodeTrie>& t : curr->d_data) - { - terms[d_pat_var_order[q][iindex]] = t.first; - getPartialInstantiations(conj, - q, - bvl, - vars, - terms, - types, - &t.second, - pindex, - paindex + 1, - iindex + 1); - } - } - }else{ - std::map< TypeNode, std::vector< Node > >::iterator it = d_reps.find( types[iindex] ); - if( it!=d_reps.end() ){ - Trace("lte-partial-inst-debug") << it->second.size() << " reps of type " << types[iindex] << std::endl; - for( unsigned i=0; i<it->second.size(); i++ ){ - terms[d_pat_var_order[q][iindex]] = it->second[i]; - getPartialInstantiations( conj, q, bvl, vars, terms, types, NULL, pindex+1, 0, iindex+1 ); - } - }else{ - Trace("lte-partial-inst-debug") << "No reps found of type " << types[iindex] << std::endl; - } - } - } -} diff --git a/src/theory/quantifiers/local_theory_ext.h b/src/theory/quantifiers/local_theory_ext.h deleted file mode 100644 index d39ea3cfe..000000000 --- a/src/theory/quantifiers/local_theory_ext.h +++ /dev/null @@ -1,93 +0,0 @@ -/********************* */ -/*! \file local_theory_ext.h - ** \verbatim - ** Top contributors (to current version): - ** Andrew Reynolds, Mathias Preiner - ** This file is part of the CVC4 project. - ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS - ** in the top-level source directory) and their institutional affiliations. - ** All rights reserved. See the file COPYING in the top-level source - ** directory for licensing information.\endverbatim - ** - ** \brief local theory extensions util - **/ - -#include "cvc4_private.h" - -#ifndef CVC4__THEORY__LOCAL_THEORY_EXT_H -#define CVC4__THEORY__LOCAL_THEORY_EXT_H - -#include "context/cdo.h" -#include "expr/attribute.h" -#include "expr/node_trie.h" -#include "theory/quantifiers/quant_util.h" - -namespace CVC4 { -namespace theory { - -/** Attribute true for quantifiers that do not need to be partially instantiated - */ -struct LtePartialInstAttributeId -{ -}; -typedef expr::Attribute<LtePartialInstAttributeId, bool> - LtePartialInstAttribute; - -namespace quantifiers { - -class LtePartialInst : public QuantifiersModule { -private: - // was this module invoked - bool d_wasInvoked; - // needs check - bool d_needsCheck; - //representatives per type - std::map< TypeNode, std::vector< Node > > d_reps; - // should we instantiate quantifier - std::map< Node, bool > d_do_inst; - // have we instantiated quantifier - std::map< Node, bool > d_inst; - std::map< Node, std::vector< Node > > d_vars; - std::map< Node, std::vector< int > > d_pat_var_order; - /** list of relevant quantifiers asserted in the current context */ - std::vector< Node > d_lte_asserts; - /** reset */ - void reset(); - /** get instantiations */ - void getInstantiations( std::vector< Node >& lemmas ); - void getPartialInstantiations(std::vector<Node>& conj, - Node q, - Node bvl, - std::vector<Node>& vars, - std::vector<Node>& inst, - std::vector<TypeNode>& types, - TNodeTrie* curr, - unsigned pindex, - unsigned paindex, - unsigned iindex); - /** get eligible inst variables */ - void getEligibleInstVars( Node n, std::map< Node, bool >& vars ); - - bool addVariableToPatternList( Node v, std::vector< int >& pat_var_order, std::map< Node, int >& var_order ); -public: - LtePartialInst( QuantifiersEngine * qe, context::Context* c ); - /** determine whether this quantified formula will be reduced */ - void checkOwnership(Node q) override; - /** was invoked */ - bool wasInvoked() { return d_wasInvoked; } - - /* whether this module needs to check this round */ - bool needsCheck(Theory::Effort e) override; - /* Call during quantifier engine's check */ - void check(Theory::Effort e, QEffort quant_e) override; - /* check complete */ - bool checkComplete() override { return !d_wasInvoked; } - /** Identify this module (for debugging, dynamic configuration, etc..) */ - std::string identify() const override { return "LtePartialInst"; } -}; - -} -} -} - -#endif diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp index 6a54e8393..187c765d1 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.cpp +++ b/src/theory/quantifiers/quantifiers_rewriter.cpp @@ -1792,7 +1792,6 @@ Node QuantifiersRewriter::computeAggressiveMiniscoping( std::vector< Node >& arg } } Assert(!qvl1.empty()); - Assert(!qvl2.empty() || !qvsh.empty()); //check for literals that only contain shared variables std::vector<Node> qlitsh; std::vector<Node> qlit2; diff --git a/src/theory/quantifiers/single_inv_partition.cpp b/src/theory/quantifiers/single_inv_partition.cpp index a0e25b756..50831fdac 100644 --- a/src/theory/quantifiers/single_inv_partition.cpp +++ b/src/theory/quantifiers/single_inv_partition.cpp @@ -346,7 +346,7 @@ bool SingleInvocationPartition::init(std::vector<Node>& funcs, d_conjuncts[2].push_back(cr); std::unordered_set<Node, NodeHashFunction> fvs; expr::getFreeVariables(cr, fvs); - d_all_vars.insert(d_all_vars.end(), fvs.begin(), fvs.end()); + d_all_vars.insert(fvs.begin(), fvs.end()); if (singleInvocation) { // replace with single invocation formulation diff --git a/src/theory/quantifiers/single_inv_partition.h b/src/theory/quantifiers/single_inv_partition.h index 0a4af3185..cdc56d1f0 100644 --- a/src/theory/quantifiers/single_inv_partition.h +++ b/src/theory/quantifiers/single_inv_partition.h @@ -201,7 +201,7 @@ class SingleInvocationPartition std::vector<Node> d_si_vars; /** every free variable of conjuncts[2] */ - std::vector<Node> d_all_vars; + std::unordered_set<Node, NodeHashFunction> d_all_vars; /** map from functions to first-order variables that anti-skolemized them */ std::map<Node, Node> d_func_fo_var; /** map from first-order variables to the function it anti-skolemized */ diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp index ed4a79808..4339ee75f 100644 --- a/src/theory/quantifiers_engine.cpp +++ b/src/theory/quantifiers_engine.cpp @@ -25,7 +25,6 @@ #include "theory/quantifiers/fmf/full_model_check.h" #include "theory/quantifiers/fmf/model_engine.h" #include "theory/quantifiers/inst_strategy_enumerative.h" -#include "theory/quantifiers/local_theory_ext.h" #include "theory/quantifiers/quant_conflict_find.h" #include "theory/quantifiers/quant_split.h" #include "theory/quantifiers/quantifiers_rewriter.h" @@ -52,7 +51,6 @@ class QuantifiersEnginePrivate d_qcf(nullptr), d_sg_gen(nullptr), d_synth_e(nullptr), - d_lte_part_inst(nullptr), d_fs(nullptr), d_i_cbqi(nullptr), d_qsplit(nullptr), @@ -79,8 +77,6 @@ class QuantifiersEnginePrivate std::unique_ptr<quantifiers::ConjectureGenerator> d_sg_gen; /** ceg instantiation */ std::unique_ptr<quantifiers::SynthEngine> d_synth_e; - /** lte partial instantiation */ - std::unique_ptr<quantifiers::LtePartialInst> d_lte_part_inst; /** full saturation */ std::unique_ptr<quantifiers::InstStrategyEnum> d_fs; /** counterexample-based quantifier instantiation */ @@ -142,11 +138,6 @@ class QuantifiersEnginePrivate // finite model finder has special ways of building the model needsBuilder = true; } - if (options::ltePartialInst()) - { - d_lte_part_inst.reset(new quantifiers::LtePartialInst(qe, c)); - modules.push_back(d_lte_part_inst.get()); - } if (options::quantDynamicSplit() != options::QuantDSplitMode::NONE) { d_qsplit.reset(new quantifiers::QuantDSplit(qe, c)); diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index 765c2b4c8..b3f1e23d7 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -81,6 +81,11 @@ struct RewriteStackElement { NodeBuilder<> d_builder; }; +RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n) +{ + return RewriteResponse(REWRITE_DONE, n); +} + Node Rewriter::rewrite(TNode node) { if (node.getNumChildren() == 0) { @@ -88,8 +93,35 @@ Node Rewriter::rewrite(TNode node) { // eagerly for the sake of efficiency here. return node; } - Rewriter& rewriter = getInstance(); - return rewriter.rewriteTo(theoryOf(node), node); + return getInstance().rewriteTo(theoryOf(node), node); +} + +void Rewriter::registerPreRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + Assert(k != kind::EQUAL) << "Register pre-rewrites for EQUAL with registerPreRewriteEqual."; + d_preRewriters[k] = fn; +} + +void Rewriter::registerPostRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + Assert(k != kind::EQUAL) << "Register post-rewrites for EQUAL with registerPostRewriteEqual."; + d_postRewriters[k] = fn; +} + +void Rewriter::registerPreRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_preRewritersEqual[tid] = fn; +} + +void Rewriter::registerPostRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn) +{ + d_postRewritersEqual[tid] = fn; } Rewriter& Rewriter::getInstance() @@ -153,8 +185,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { for(;;) { // Perform the pre-rewrite RewriteResponse response = - d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite( - rewriteStackTop.d_node); + preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node); // Put the rewritten node to the top of the stack rewriteStackTop.d_node = response.d_node; TheoryId newTheory = theoryOf(rewriteStackTop.d_node); @@ -225,8 +256,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { for(;;) { // Do the post-rewrite RewriteResponse response = - d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite( - rewriteStackTop.d_node); + postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node); // We continue with the response we got TheoryId newTheoryId = theoryOf(response.d_node); if (newTheoryId != rewriteStackTop.getTheoryId() @@ -290,6 +320,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { Unreachable(); }/* Rewriter::rewriteTo() */ +RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n) +{ + Kind k = n.getKind(); + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn = + (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k]; + if (fn == nullptr) + { + return d_theoryRewriters[theoryId]->preRewrite(n); + } + return fn(&d_re, n); +} + +RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n) +{ + Kind k = n.getKind(); + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn = + (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k]; + if (fn == nullptr) + { + return d_theoryRewriters[theoryId]->postRewrite(n); + } + return fn(&d_re, n); +} + void Rewriter::clearCaches() { Rewriter& rewriter = getInstance(); diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h index e55ca5d1c..f7298e1fb 100644 --- a/src/theory/rewriter.h +++ b/src/theory/rewriter.h @@ -28,6 +28,23 @@ namespace theory { class RewriterInitializer; /** + * The rewrite environment holds everything that the individual rewrites have + * access to. + */ +class RewriteEnvironment +{ +}; + +/** + * The identity rewrite just returns the original node. + * + * @param re The rewrite environment + * @param n The node to rewrite + * @return The original node + */ +RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n); + +/** * The main rewriter class. */ class Rewriter { @@ -45,6 +62,44 @@ class Rewriter { */ static void clearCaches(); + /** + * Register a prerewrite for a given kind. + * + * @param k The kind to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPreRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn); + + /** + * Register a postrewrite for a given kind. + * + * @param k The kind to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPostRewrite( + Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn); + + /** + * Register a prerewrite for equalities belonging to a given theory. + * + * @param tid The theory to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPreRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn); + + /** + * Register a postrewrite for equalities belonging to a given theory. + * + * @param tid The theory to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPostRewriteEqual( + theory::TheoryId tid, + std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn); + private: /** * Get the (singleton) instance of the rewriter. @@ -71,10 +126,10 @@ class Rewriter { Node rewriteTo(theory::TheoryId theoryId, Node node); /** Calls the pre-rewriter for the given theory */ - RewriteResponse callPreRewrite(theory::TheoryId theoryId, TNode node); + RewriteResponse preRewrite(theory::TheoryId theoryId, TNode n); /** Calls the post-rewriter for the given theory */ - RewriteResponse callPostRewrite(theory::TheoryId theoryId, TNode node); + RewriteResponse postRewrite(theory::TheoryId theoryId, TNode n); /** * Calls the equality-rewriter for the given theory. @@ -88,6 +143,27 @@ class Rewriter { unsigned long d_iterationCount = 0; + /** Rewriter table for prewrites. Maps kinds to rewriter function. */ + std::function<RewriteResponse(RewriteEnvironment*, TNode)> + d_preRewriters[kind::LAST_KIND]; + /** Rewriter table for postrewrites. Maps kinds to rewriter function. */ + std::function<RewriteResponse(RewriteEnvironment*, TNode)> + d_postRewriters[kind::LAST_KIND]; + /** + * Rewriter table for prerewrites of equalities. Maps theory to rewriter + * function. + */ + std::function<RewriteResponse(RewriteEnvironment*, TNode)> + d_preRewritersEqual[theory::THEORY_LAST]; + /** + * Rewriter table for postrewrites of equalities. Maps theory to rewriter + * function. + */ + std::function<RewriteResponse(RewriteEnvironment*, TNode)> + d_postRewritersEqual[theory::THEORY_LAST]; + + RewriteEnvironment d_re; + #ifdef CVC4_ASSERTIONS std::unique_ptr<std::unordered_set<Node, NodeHashFunction>> d_rewriteStack = nullptr; diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h index e1be6355b..1bb03e253 100644 --- a/src/theory/rewriter_tables_template.h +++ b/src/theory/rewriter_tables_template.h @@ -64,6 +64,19 @@ ${post_rewrite_set_cache} Rewriter::Rewriter() { ${rewrite_init} + +for (size_t i = 0; i < kind::LAST_KIND; ++i) +{ + d_preRewriters[i] = nullptr; + d_postRewriters[i] = nullptr; +} + +for (size_t i = 0; i < theory::THEORY_LAST; ++i) +{ + d_preRewritersEqual[i] = nullptr; + d_postRewritersEqual[i] = nullptr; + d_theoryRewriters[i]->registerRewrites(this); +} } void Rewriter::clearCachesInternal() { diff --git a/src/theory/theory_rewriter.h b/src/theory/theory_rewriter.h index e7dc782bb..311ab9020 100644 --- a/src/theory/theory_rewriter.h +++ b/src/theory/theory_rewriter.h @@ -24,6 +24,8 @@ namespace CVC4 { namespace theory { +class Rewriter; + /** * Theory rewriters signal whether more rewriting is needed (or not) * by using a member of this enumeration. See RewriteResponse, below. @@ -64,6 +66,13 @@ class TheoryRewriter virtual ~TheoryRewriter() = default; /** + * Registers the rewrites of a given theory with the rewriter. + * + * @param rewriter The rewriter to register the rewrites with. + */ + virtual void registerRewrites(Rewriter* rewriter) {} + + /** * Performs a pre-rewrite step. * * @param node The node to rewrite |