From 70f0cddbce01fa17622b7b70b638794181aefec5 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Thu, 26 Nov 2020 04:43:19 -0600 Subject: Move expand definitions to its own file (#5528) We are changing our policy on how expand definitions is handled during preprocessing. This will require some additions to expand definitions handling. This PR makes a standalone module for expanding definitions. --- src/CMakeLists.txt | 2 + src/smt/expand_definitions.cpp | 278 +++++++++++++++++++++++++++++++++++++++++ src/smt/expand_definitions.h | 76 +++++++++++ src/smt/preprocessor.cpp | 9 +- src/smt/preprocessor.h | 3 + src/smt/process_assertions.cpp | 234 +--------------------------------- src/smt/process_assertions.h | 17 +-- 7 files changed, 371 insertions(+), 248 deletions(-) create mode 100644 src/smt/expand_definitions.cpp create mode 100644 src/smt/expand_definitions.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 70cb68431..869699ac5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -219,6 +219,8 @@ libcvc4_add_sources( smt/dump.h smt/dump_manager.cpp smt/dump_manager.h + smt/expand_definitions.cpp + smt/expand_definitions.h smt/listeners.cpp smt/listeners.h smt/logic_exception.h diff --git a/src/smt/expand_definitions.cpp b/src/smt/expand_definitions.cpp new file mode 100644 index 000000000..20c4f8ef6 --- /dev/null +++ b/src/smt/expand_definitions.cpp @@ -0,0 +1,278 @@ +/********************* */ +/*! \file expand_definitions.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds, Tim King, Haniel Barbosa + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of expand definitions for an SMT engine. + **/ + +#include "smt/expand_definitions.h" + +#include +#include + +#include "expr/node_manager_attributes.h" +#include "smt/defined_function.h" +#include "smt/smt_engine.h" +#include "theory/theory_engine.h" + +using namespace CVC4::preprocessing; +using namespace CVC4::theory; +using namespace CVC4::kind; + +namespace CVC4 { +namespace smt { + +ExpandDefs::ExpandDefs(SmtEngine& smt, + ResourceManager& rm, + SmtEngineStatistics& stats) + : d_smt(smt), d_resourceManager(rm), d_smtStats(stats) +{ +} + +ExpandDefs::~ExpandDefs() {} + +Node ExpandDefs::expandDefinitions( + TNode n, + std::unordered_map& cache, + bool expandOnly) +{ + NodeManager* nm = d_smt.getNodeManager(); + std::stack> worklist; + std::stack result; + worklist.push(std::make_tuple(Node(n), Node(n), false)); + // The worklist is made of triples, each is input / original node then the + // output / rewritten node and finally a flag tracking whether the children + // have been explored (i.e. if this is a downward or upward pass). + + do + { + d_resourceManager.spendResource(ResourceManager::Resource::PreprocessStep); + + // n is the input / original + // node is the output / result + Node node; + bool childrenPushed; + std::tie(n, node, childrenPushed) = worklist.top(); + worklist.pop(); + + // Working downwards + if (!childrenPushed) + { + Kind k = n.getKind(); + + // we can short circuit (variable) leaves + if (n.isVar()) + { + SmtEngine::DefinedFunctionMap* dfuns = d_smt.getDefinedFunctionMap(); + SmtEngine::DefinedFunctionMap::const_iterator i = dfuns->find(n); + if (i != dfuns->end()) + { + Node f = (*i).second.getFormula(); + // must expand its definition + Node fe = expandDefinitions(f, cache, expandOnly); + // replacement must be closed + if ((*i).second.getFormals().size() > 0) + { + result.push( + nm->mkNode(LAMBDA, + nm->mkNode(BOUND_VAR_LIST, (*i).second.getFormals()), + fe)); + continue; + } + // don't bother putting in the cache + result.push(fe); + continue; + } + // don't bother putting in the cache + result.push(n); + continue; + } + + // maybe it's in the cache + std::unordered_map::iterator cacheHit = + cache.find(n); + if (cacheHit != cache.end()) + { + TNode ret = (*cacheHit).second; + result.push(ret.isNull() ? n : ret); + continue; + } + + // otherwise expand it + bool doExpand = false; + if (k == APPLY_UF) + { + // Always do beta-reduction here. The reason is that there may be + // operators such as INTS_MODULUS in the body of the lambda that would + // otherwise be introduced by beta-reduction via the rewriter, but are + // not expanded here since the traversal in this function does not + // traverse the operators of nodes. Hence, we beta-reduce here to + // ensure terms in the body of the lambda are expanded during this + // call. + if (n.getOperator().getKind() == LAMBDA) + { + doExpand = true; + } + else + { + // We always check if this operator corresponds to a defined function. + doExpand = d_smt.isDefinedFunction(n.getOperator().toExpr()); + } + } + if (doExpand) + { + std::vector formals; + TNode fm; + if (n.getOperator().getKind() == LAMBDA) + { + TNode op = n.getOperator(); + // lambda + for (unsigned i = 0; i < op[0].getNumChildren(); i++) + { + formals.push_back(op[0][i]); + } + fm = op[1]; + } + else + { + // application of a user-defined symbol + TNode func = n.getOperator(); + SmtEngine::DefinedFunctionMap* dfuns = d_smt.getDefinedFunctionMap(); + SmtEngine::DefinedFunctionMap::const_iterator i = dfuns->find(func); + if (i == dfuns->end()) + { + throw TypeCheckingException( + n.toExpr(), + std::string("Undefined function: `") + func.toString() + "'"); + } + DefinedFunction def = (*i).second; + formals = def.getFormals(); + + if (Debug.isOn("expand")) + { + Debug("expand") << "found: " << n << std::endl; + Debug("expand") << " func: " << func << std::endl; + std::string name = func.getAttribute(expr::VarNameAttr()); + Debug("expand") << " : \"" << name << "\"" << std::endl; + } + if (Debug.isOn("expand")) + { + Debug("expand") << " defn: " << def.getFunction() << std::endl + << " ["; + if (formals.size() > 0) + { + copy(formals.begin(), + formals.end() - 1, + std::ostream_iterator(Debug("expand"), ", ")); + Debug("expand") << formals.back(); + } + Debug("expand") + << "]" << std::endl + << " " << def.getFunction().getType() << std::endl + << " " << def.getFormula() << std::endl; + } + + fm = def.getFormula(); + } + + Node instance = fm.substitute(formals.begin(), + formals.end(), + n.begin(), + n.begin() + formals.size()); + Debug("expand") << "made : " << instance << std::endl; + + Node expanded = expandDefinitions(instance, cache, expandOnly); + cache[n] = (n == expanded ? Node::null() : expanded); + result.push(expanded); + continue; + } + else if (!expandOnly) + { + // do not do any theory stuff if expandOnly is true + + theory::Theory* t = d_smt.getTheoryEngine()->theoryOf(node); + + Assert(t != NULL); + TrustNode trn = t->expandDefinition(n); + node = trn.isNull() ? Node(n) : trn.getNode(); + } + + // the partial functions can fall through, in which case we still + // consider their children + worklist.push(std::make_tuple( + Node(n), node, true)); // Original and rewritten result + + for (size_t i = 0; i < node.getNumChildren(); ++i) + { + worklist.push( + std::make_tuple(node[i], + node[i], + false)); // Rewrite the children of the result only + } + } + else + { + // Working upwards + // Reconstruct the node from it's (now rewritten) children on the stack + + Debug("expand") << "cons : " << node << std::endl; + if (node.getNumChildren() > 0) + { + // cout << "cons : " << node << std::endl; + NodeBuilder<> nb(node.getKind()); + if (node.getMetaKind() == metakind::PARAMETERIZED) + { + Debug("expand") << "op : " << node.getOperator() << std::endl; + // cout << "op : " << node.getOperator() << std::endl; + nb << node.getOperator(); + } + for (size_t i = 0, nchild = node.getNumChildren(); i < nchild; ++i) + { + Assert(!result.empty()); + Node expanded = result.top(); + result.pop(); + // cout << "exchld : " << expanded << std::endl; + Debug("expand") << "exchld : " << expanded << std::endl; + nb << expanded; + } + node = nb; + } + // Only cache once all subterms are expanded + cache[n] = n == node ? Node::null() : node; + result.push(node); + } + } while (!worklist.empty()); + + AlwaysAssert(result.size() == 1); + + return result.top(); +} + +void ExpandDefs::expandAssertions(AssertionPipeline& assertions, + bool expandOnly) +{ + Chat() << "expanding definitions in assertions..." << std::endl; + Trace("simplify") << "ExpandDefs::simplify(): expanding definitions" + << std::endl; + TimerStat::CodeTimer codeTimer(d_smtStats.d_definitionExpansionTime); + std::unordered_map cache; + for (size_t i = 0, nasserts = assertions.size(); i < nasserts; ++i) + { + Node assert = assertions[i]; + Node expd = expandDefinitions(assert, cache, expandOnly); + if (expd != assert) + { + assertions.replace(i, expd); + } + } +} + +} // namespace smt +} // namespace CVC4 diff --git a/src/smt/expand_definitions.h b/src/smt/expand_definitions.h new file mode 100644 index 000000000..f40ee4a4e --- /dev/null +++ b/src/smt/expand_definitions.h @@ -0,0 +1,76 @@ +/********************* */ +/*! \file process_assertions.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds, Tim King, Morgan Deters + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The module for processing assertions for an SMT engine. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__SMT__EXPAND_DEFINITIONS_H +#define CVC4__SMT__EXPAND_DEFINITIONS_H + +#include + +#include "expr/node.h" +#include "preprocessing/assertion_pipeline.h" +#include "smt/smt_engine_stats.h" +#include "util/resource_manager.h" + +namespace CVC4 { + +class SmtEngine; + +namespace smt { + +/** + * Module in charge of expanding definitions for an SMT engine. + * + * Its main features is expandDefinitions(TNode, ...), which returns the + * expanded formula of a term. + */ +class ExpandDefs +{ + public: + ExpandDefs(SmtEngine& smt, ResourceManager& rm, SmtEngineStatistics& stats); + ~ExpandDefs(); + /** + * Expand definitions in term n. Return the expanded form of n. + * + * @param n The node to expand + * @param cache Cache of previous results + * @param expandOnly if true, then the expandDefinitions function of + * TheoryEngine is not called on subterms of n. + * @return The expanded term. + */ + Node expandDefinitions( + TNode n, + std::unordered_map& cache, + bool expandOnly = false); + /** + * Expand defintitions in assertions. This applies this above method to each + * assertion in the given pipeline. + */ + void expandAssertions(preprocessing::AssertionPipeline& assertions, + bool expandOnly = false); + + private: + /** Reference to the SMT engine */ + SmtEngine& d_smt; + /** Reference to resource manager */ + ResourceManager& d_resourceManager; + /** Reference to the SMT stats */ + SmtEngineStatistics& d_smtStats; +}; + +} // namespace smt +} // namespace CVC4 + +#endif diff --git a/src/smt/preprocessor.cpp b/src/smt/preprocessor.cpp index 5a1ce63a4..2c8592657 100644 --- a/src/smt/preprocessor.cpp +++ b/src/smt/preprocessor.cpp @@ -36,7 +36,8 @@ Preprocessor::Preprocessor(SmtEngine& smt, d_absValues(abs), d_propagator(true, true), d_assertionsProcessed(u, false), - d_processor(smt, *smt.getResourceManager(), stats), + d_exDefs(smt, *smt.getResourceManager(), stats), + d_processor(smt, d_exDefs, *smt.getResourceManager(), stats), d_rtf(u), d_pnm(nullptr) { @@ -108,7 +109,7 @@ RemoveTermFormulas& Preprocessor::getTermFormulaRemover() { return d_rtf; } Node Preprocessor::expandDefinitions(const Node& n, bool expandOnly) { std::unordered_map cache; - return expandDefinitions(n, cache, expandOnly); + return d_exDefs.expandDefinitions(n, cache, expandOnly); } Node Preprocessor::expandDefinitions( @@ -125,7 +126,7 @@ Node Preprocessor::expandDefinitions( n.getType(true); } // expand only = true - return d_processor.expandDefinitions(n, cache, expandOnly); + return d_exDefs.expandDefinitions(n, cache, expandOnly); } Node Preprocessor::simplify(const Node& node, bool removeItes) @@ -143,7 +144,7 @@ Node Preprocessor::simplify(const Node& node, bool removeItes) nas.getType(true); } std::unordered_map cache; - Node n = d_processor.expandDefinitions(nas, cache); + Node n = d_exDefs.expandDefinitions(nas, cache); TrustNode ts = d_ppContext->getTopLevelSubstitutions().apply(n); Node ns = ts.isNull() ? n : ts.getNode(); if (removeItes) diff --git a/src/smt/preprocessor.h b/src/smt/preprocessor.h index cb83f969e..220a433fe 100644 --- a/src/smt/preprocessor.h +++ b/src/smt/preprocessor.h @@ -20,6 +20,7 @@ #include #include "preprocessing/preprocessing_pass_context.h" +#include "smt/expand_definitions.h" #include "smt/process_assertions.h" #include "smt/term_formula_removal.h" #include "theory/booleans/circuit_propagator.h" @@ -125,6 +126,8 @@ class Preprocessor context::CDO d_assertionsProcessed; /** The preprocessing pass context */ std::unique_ptr d_ppContext; + /** Expand definitions module, responsible for expanding definitions */ + ExpandDefs d_exDefs; /** * Process assertions module, responsible for implementing the preprocessing * passes. diff --git a/src/smt/process_assertions.cpp b/src/smt/process_assertions.cpp index 2011e7b83..c68b73336 100644 --- a/src/smt/process_assertions.cpp +++ b/src/smt/process_assertions.cpp @@ -53,9 +53,11 @@ class ScopeCounter }; ProcessAssertions::ProcessAssertions(SmtEngine& smt, + ExpandDefs& exDefs, ResourceManager& rm, SmtEngineStatistics& stats) : d_smt(smt), + d_exDefs(exDefs), d_resourceManager(rm), d_smtStats(stats), d_preprocessingPassContext(nullptr) @@ -128,21 +130,7 @@ bool ProcessAssertions::apply(Assertions& as) << "ProcessAssertions::processAssertions() : pre-definition-expansion" << endl; dumpAssertions("pre-definition-expansion", assertions); - { - Chat() << "expanding definitions..." << endl; - Trace("simplify") << "ProcessAssertions::simplify(): expanding definitions" - << endl; - TimerStat::CodeTimer codeTimer(d_smtStats.d_definitionExpansionTime); - unordered_map cache; - for (size_t i = 0, nasserts = assertions.size(); i < nasserts; ++i) - { - Node expd = expandDefinitions(assertions[i], cache); - if (expd != assertions[i]) - { - assertions.replace(i, expd); - } - } - } + d_exDefs.expandAssertions(assertions, false); Trace("smt-proc") << "ProcessAssertions::processAssertions() : post-definition-expansion" << endl; @@ -550,222 +538,6 @@ void ProcessAssertions::dumpAssertions(const char* key, } } -Node ProcessAssertions::expandDefinitions( - TNode n, - unordered_map& cache, - bool expandOnly) -{ - NodeManager* nm = d_smt.getNodeManager(); - std::stack> worklist; - std::stack result; - worklist.push(std::make_tuple(Node(n), Node(n), false)); - // The worklist is made of triples, each is input / original node then the - // output / rewritten node and finally a flag tracking whether the children - // have been explored (i.e. if this is a downward or upward pass). - - do - { - spendResource(ResourceManager::Resource::PreprocessStep); - - // n is the input / original - // node is the output / result - Node node; - bool childrenPushed; - std::tie(n, node, childrenPushed) = worklist.top(); - worklist.pop(); - - // Working downwards - if (!childrenPushed) - { - Kind k = n.getKind(); - - // we can short circuit (variable) leaves - if (n.isVar()) - { - SmtEngine::DefinedFunctionMap* dfuns = d_smt.getDefinedFunctionMap(); - SmtEngine::DefinedFunctionMap::const_iterator i = dfuns->find(n); - if (i != dfuns->end()) - { - Node f = (*i).second.getFormula(); - // must expand its definition - Node fe = expandDefinitions(f, cache, expandOnly); - // replacement must be closed - if ((*i).second.getFormals().size() > 0) - { - result.push( - nm->mkNode(LAMBDA, - nm->mkNode(BOUND_VAR_LIST, (*i).second.getFormals()), - fe)); - continue; - } - // don't bother putting in the cache - result.push(fe); - continue; - } - // don't bother putting in the cache - result.push(n); - continue; - } - - // maybe it's in the cache - unordered_map::iterator cacheHit = - cache.find(n); - if (cacheHit != cache.end()) - { - TNode ret = (*cacheHit).second; - result.push(ret.isNull() ? n : ret); - continue; - } - - // otherwise expand it - bool doExpand = false; - if (k == APPLY_UF) - { - // Always do beta-reduction here. The reason is that there may be - // operators such as INTS_MODULUS in the body of the lambda that would - // otherwise be introduced by beta-reduction via the rewriter, but are - // not expanded here since the traversal in this function does not - // traverse the operators of nodes. Hence, we beta-reduce here to - // ensure terms in the body of the lambda are expanded during this - // call. - if (n.getOperator().getKind() == LAMBDA) - { - doExpand = true; - } - else - { - // We always check if this operator corresponds to a defined function. - doExpand = d_smt.isDefinedFunction(n.getOperator().toExpr()); - } - } - if (doExpand) - { - vector formals; - TNode fm; - if (n.getOperator().getKind() == LAMBDA) - { - TNode op = n.getOperator(); - // lambda - for (unsigned i = 0; i < op[0].getNumChildren(); i++) - { - formals.push_back(op[0][i]); - } - fm = op[1]; - } - else - { - // application of a user-defined symbol - TNode func = n.getOperator(); - SmtEngine::DefinedFunctionMap* dfuns = d_smt.getDefinedFunctionMap(); - SmtEngine::DefinedFunctionMap::const_iterator i = dfuns->find(func); - if (i == dfuns->end()) - { - throw TypeCheckingException( - n.toExpr(), - string("Undefined function: `") + func.toString() + "'"); - } - DefinedFunction def = (*i).second; - formals = def.getFormals(); - - if (Debug.isOn("expand")) - { - Debug("expand") << "found: " << n << endl; - Debug("expand") << " func: " << func << endl; - string name = func.getAttribute(expr::VarNameAttr()); - Debug("expand") << " : \"" << name << "\"" << endl; - } - if (Debug.isOn("expand")) - { - Debug("expand") << " defn: " << def.getFunction() << endl - << " ["; - if (formals.size() > 0) - { - copy(formals.begin(), - formals.end() - 1, - ostream_iterator(Debug("expand"), ", ")); - Debug("expand") << formals.back(); - } - Debug("expand") << "]" << endl - << " " << def.getFunction().getType() << endl - << " " << def.getFormula() << endl; - } - - fm = def.getFormula(); - } - - Node instance = fm.substitute(formals.begin(), - formals.end(), - n.begin(), - n.begin() + formals.size()); - Debug("expand") << "made : " << instance << endl; - - Node expanded = expandDefinitions(instance, cache, expandOnly); - cache[n] = (n == expanded ? Node::null() : expanded); - result.push(expanded); - continue; - } - else if (!expandOnly) - { - // do not do any theory stuff if expandOnly is true - - theory::Theory* t = d_smt.getTheoryEngine()->theoryOf(node); - - Assert(t != NULL); - TrustNode trn = t->expandDefinition(n); - node = trn.isNull() ? Node(n) : trn.getNode(); - } - - // the partial functions can fall through, in which case we still - // consider their children - worklist.push(std::make_tuple( - Node(n), node, true)); // Original and rewritten result - - for (size_t i = 0; i < node.getNumChildren(); ++i) - { - worklist.push( - std::make_tuple(node[i], - node[i], - false)); // Rewrite the children of the result only - } - } - else - { - // Working upwards - // Reconstruct the node from it's (now rewritten) children on the stack - - Debug("expand") << "cons : " << node << endl; - if (node.getNumChildren() > 0) - { - // cout << "cons : " << node << endl; - NodeBuilder<> nb(node.getKind()); - if (node.getMetaKind() == metakind::PARAMETERIZED) - { - Debug("expand") << "op : " << node.getOperator() << endl; - // cout << "op : " << node.getOperator() << endl; - nb << node.getOperator(); - } - for (size_t i = 0, nchild = node.getNumChildren(); i < nchild; ++i) - { - Assert(!result.empty()); - Node expanded = result.top(); - result.pop(); - // cout << "exchld : " << expanded << endl; - Debug("expand") << "exchld : " << expanded << endl; - nb << expanded; - } - node = nb; - } - // Only cache once all subterms are expanded - cache[n] = n == node ? Node::null() : node; - result.push(node); - } - } while (!worklist.empty()); - - AlwaysAssert(result.size() == 1); - - return result.top(); -} - void ProcessAssertions::collectSkolems( IteSkolemMap& iskMap, TNode n, diff --git a/src/smt/process_assertions.h b/src/smt/process_assertions.h index d260edf14..072603e7d 100644 --- a/src/smt/process_assertions.h +++ b/src/smt/process_assertions.h @@ -26,6 +26,7 @@ #include "preprocessing/preprocessing_pass.h" #include "preprocessing/preprocessing_pass_context.h" #include "smt/assertions.h" +#include "smt/expand_definitions.h" #include "smt/smt_engine_stats.h" #include "util/resource_manager.h" @@ -53,11 +54,11 @@ class ProcessAssertions { /** The types for the recursive function definitions */ typedef context::CDList NodeList; - typedef unordered_map NodeToNodeHashMap; typedef unordered_map NodeToBoolHashMap; public: ProcessAssertions(SmtEngine& smt, + ExpandDefs& exDefs, ResourceManager& rm, SmtEngineStatistics& stats); ~ProcessAssertions(); @@ -76,22 +77,12 @@ class ProcessAssertions * processing the assertions. */ bool apply(Assertions& as); - /** - * Expand definitions in term n. Return the expanded form of n. - * - * @param n The node to expand - * @param cache Cache of previous results - * @param expandOnly if true, then the expandDefinitions function of - * TheoryEngine is not called on subterms of n. - * @return The expanded term. - */ - Node expandDefinitions(TNode n, - NodeToNodeHashMap& cache, - bool expandOnly = false); private: /** Reference to the SMT engine */ SmtEngine& d_smt; + /** Reference to expand definitions module */ + ExpandDefs& d_exDefs; /** Reference to resource manager */ ResourceManager& d_resourceManager; /** Reference to the SMT stats */ -- cgit v1.2.3