diff options
Diffstat (limited to 'src/smt')
-rw-r--r-- | src/smt/check_models.cpp | 14 | ||||
-rw-r--r-- | src/smt/command.cpp | 41 | ||||
-rw-r--r-- | src/smt/command.h | 5 | ||||
-rw-r--r-- | src/smt/dump_manager.cpp | 76 | ||||
-rw-r--r-- | src/smt/dump_manager.h | 30 | ||||
-rw-r--r-- | src/smt/expand_definitions.cpp | 278 | ||||
-rw-r--r-- | src/smt/expand_definitions.h | 76 | ||||
-rw-r--r-- | src/smt/model.cpp | 28 | ||||
-rw-r--r-- | src/smt/model.h | 44 | ||||
-rw-r--r-- | src/smt/node_command.cpp | 23 | ||||
-rw-r--r-- | src/smt/node_command.h | 5 | ||||
-rw-r--r-- | src/smt/preprocessor.cpp | 10 | ||||
-rw-r--r-- | src/smt/preprocessor.h | 3 | ||||
-rw-r--r-- | src/smt/process_assertions.cpp | 234 | ||||
-rw-r--r-- | src/smt/process_assertions.h | 17 | ||||
-rw-r--r-- | src/smt/smt_engine.cpp | 42 | ||||
-rw-r--r-- | src/smt/smt_engine.h | 22 |
17 files changed, 472 insertions, 476 deletions
diff --git a/src/smt/check_models.cpp b/src/smt/check_models.cpp index 612084de2..56d54eec9 100644 --- a/src/smt/check_models.cpp +++ b/src/smt/check_models.cpp @@ -67,24 +67,14 @@ void CheckModels::checkModel(Model* m, /* substituteUnderQuantifiers = */ false); Trace("check-model") << "checkModel: Collect substitution..." << std::endl; - for (size_t k = 0, ncmd = m->getNumCommands(); k < ncmd; ++k) + const std::vector<Node>& decTerms = m->getDeclaredTerms(); + for (const Node& func : decTerms) { - const DeclareFunctionNodeCommand* c = - dynamic_cast<const DeclareFunctionNodeCommand*>(m->getCommand(k)); - Notice() << "SmtEngine::checkModel(): model command " << k << " : " - << m->getCommand(k)->toString() << std::endl; - if (c == nullptr) - { - // we don't care about DECLARE-DATATYPES, DECLARE-SORT, ... - Notice() << "SmtEngine::checkModel(): skipping..." << std::endl; - continue; - } // We have a DECLARE-FUN: // // We'll first do some checks, then add to our substitution map // the mapping: function symbol |-> value - Node func = c->getFunction(); Node val = m->getValue(func); Notice() << "SmtEngine::checkModel(): adding substitution: " << func diff --git a/src/smt/command.cpp b/src/smt/command.cpp index 717d423fe..e6361be9e 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -1069,28 +1069,17 @@ DeclareFunctionCommand::DeclareFunctionCommand(const std::string& id, api::Sort sort) : DeclarationDefinitionCommand(id), d_func(func), - d_sort(sort), - d_printInModel(true), - d_printInModelSetByUser(false) + d_sort(sort) { } api::Term DeclareFunctionCommand::getFunction() const { return d_func; } api::Sort DeclareFunctionCommand::getSort() const { return d_sort; } -bool DeclareFunctionCommand::getPrintInModel() const { return d_printInModel; } -bool DeclareFunctionCommand::getPrintInModelSetByUser() const -{ - return d_printInModelSetByUser; -} - -void DeclareFunctionCommand::setPrintInModel(bool p) -{ - d_printInModel = p; - d_printInModelSetByUser = true; -} void DeclareFunctionCommand::invoke(api::Solver* solver, SymbolManager* sm) { + // mark that it will be printed in the model + sm->addModelDeclarationTerm(d_func); d_commandStatus = CommandSuccess::instance(); } @@ -1098,8 +1087,6 @@ Command* DeclareFunctionCommand::clone() const { DeclareFunctionCommand* dfc = new DeclareFunctionCommand(d_symbol, d_func, d_sort); - dfc->d_printInModel = d_printInModel; - dfc->d_printInModelSetByUser = d_printInModelSetByUser; return dfc; } @@ -1132,6 +1119,8 @@ size_t DeclareSortCommand::getArity() const { return d_arity; } api::Sort DeclareSortCommand::getSort() const { return d_sort; } void DeclareSortCommand::invoke(api::Solver* solver, SymbolManager* sm) { + // mark that it will be printed in the model + sm->addModelDeclarationSort(d_sort); d_commandStatus = CommandSuccess::instance(); } @@ -1150,8 +1139,8 @@ void DeclareSortCommand::toStream(std::ostream& out, size_t dag, OutputLanguage language) const { - Printer::getPrinter(language)->toStreamCmdDeclareType( - out, d_sort.toString(), d_arity, d_sort.getTypeNode()); + Printer::getPrinter(language)->toStreamCmdDeclareType(out, + d_sort.getTypeNode()); } /* -------------------------------------------------------------------------- */ @@ -1438,8 +1427,8 @@ void SetUserAttributeCommand::invoke(api::Solver* solver, SymbolManager* sm) { solver->getSmtEngine()->setUserAttribute( d_attr, - d_term.getExpr(), - api::termVectorToExprs(d_termValues), + d_term.getNode(), + api::termVectorToNodes(d_termValues), d_strValue); } d_commandStatus = CommandSuccess::instance(); @@ -1693,6 +1682,18 @@ void GetModelCommand::invoke(api::Solver* solver, SymbolManager* sm) try { d_result = solver->getSmtEngine()->getModel(); + // set the model declarations, which determines what is printed in the model + d_result->clearModelDeclarations(); + std::vector<api::Sort> declareSorts = sm->getModelDeclareSorts(); + for (const api::Sort& s : declareSorts) + { + d_result->addDeclarationSort(s.getTypeNode()); + } + std::vector<api::Term> declareTerms = sm->getModelDeclareTerms(); + for (const api::Term& t : declareTerms) + { + d_result->addDeclarationTerm(t.getNode()); + } d_commandStatus = CommandSuccess::instance(); } catch (RecoverableModalException& e) diff --git a/src/smt/command.h b/src/smt/command.h index 96a6938d6..0b86f3539 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -387,16 +387,11 @@ class CVC4_PUBLIC DeclareFunctionCommand : public DeclarationDefinitionCommand protected: api::Term d_func; api::Sort d_sort; - bool d_printInModel; - bool d_printInModelSetByUser; public: DeclareFunctionCommand(const std::string& id, api::Term func, api::Sort sort); api::Term getFunction() const; api::Sort getSort() const; - bool getPrintInModel() const; - bool getPrintInModelSetByUser() const; - void setPrintInModel(bool p); void invoke(api::Solver* solver, SymbolManager* sm) override; Command* clone() const override; diff --git a/src/smt/dump_manager.cpp b/src/smt/dump_manager.cpp index 9b7fba5a2..9d3031b4d 100644 --- a/src/smt/dump_manager.cpp +++ b/src/smt/dump_manager.cpp @@ -24,8 +24,6 @@ namespace smt { DumpManager::DumpManager(context::UserContext* u) : d_fullyInited(false), - d_modelGlobalCommands(), - d_modelCommands(u), d_dumpCommands() { } @@ -33,8 +31,6 @@ DumpManager::DumpManager(context::UserContext* u) DumpManager::~DumpManager() { d_dumpCommands.clear(); - d_modelCommandsAlloc.clear(); - d_modelGlobalCommands.clear(); } void DumpManager::finishInit() @@ -49,8 +45,10 @@ void DumpManager::finishInit() d_fullyInited = true; } - -void DumpManager::resetAssertions() { d_modelGlobalCommands.clear(); } +void DumpManager::resetAssertions() +{ + // currently, do nothing +} void DumpManager::addToModelCommandAndDump(const NodeCommand& c, uint32_t flags, @@ -58,29 +56,6 @@ void DumpManager::addToModelCommandAndDump(const NodeCommand& c, const char* dumpTag) { Trace("smt") << "SMT addToModelCommandAndDump(" << c << ")" << std::endl; - // If we aren't yet fully inited, the user might still turn on - // produce-models. So let's keep any commands around just in - // case. This is useful in two cases: (1) SMT-LIBv1 auto-declares - // sort "U" in QF_UF before setLogic() is run and we still want to - // support finding card(U) with --finite-model-find, and (2) to - // decouple SmtEngine and ExprManager if the user does a few - // ExprManager::mkSort() before SmtEngine::setOption("produce-models") - // and expects to find their cardinalities in the model. - if ((!d_fullyInited || options::produceModels()) - && (flags & ExprManager::VAR_FLAG_DEFINED) == 0) - { - if (flags & ExprManager::VAR_FLAG_GLOBAL) - { - d_modelGlobalCommands.push_back(std::unique_ptr<NodeCommand>(c.clone())); - } - else - { - NodeCommand* cc = c.clone(); - d_modelCommands.push_back(cc); - // also remember for memory management purposes - d_modelCommandsAlloc.push_back(std::unique_ptr<NodeCommand>(cc)); - } - } if (Dump.isOn(dumpTag)) { if (d_fullyInited) @@ -97,48 +72,7 @@ void DumpManager::addToModelCommandAndDump(const NodeCommand& c, void DumpManager::setPrintFuncInModel(Node f, bool p) { Trace("setp-model") << "Set printInModel " << f << " to " << p << std::endl; - for (std::unique_ptr<NodeCommand>& c : d_modelGlobalCommands) - { - DeclareFunctionNodeCommand* dfc = - dynamic_cast<DeclareFunctionNodeCommand*>(c.get()); - if (dfc != NULL) - { - Node df = dfc->getFunction(); - if (df == f) - { - dfc->setPrintInModel(p); - } - } - } - for (NodeCommand* c : d_modelCommands) - { - DeclareFunctionNodeCommand* dfc = - dynamic_cast<DeclareFunctionNodeCommand*>(c); - if (dfc != NULL) - { - Node df = dfc->getFunction(); - if (df == f) - { - dfc->setPrintInModel(p); - } - } - } -} - -size_t DumpManager::getNumModelCommands() const -{ - return d_modelCommands.size() + d_modelGlobalCommands.size(); -} - -const NodeCommand* DumpManager::getModelCommand(size_t i) const -{ - Assert(i < getNumModelCommands()); - // index the global commands first, then the locals - if (i < d_modelGlobalCommands.size()) - { - return d_modelGlobalCommands[i].get(); - } - return d_modelCommands[i - d_modelGlobalCommands.size()]; + // TODO (cvc4-wishues/issues/75): implement } } // namespace smt diff --git a/src/smt/dump_manager.h b/src/smt/dump_manager.h index 0ba8e0b8b..eaedf39a1 100644 --- a/src/smt/dump_manager.h +++ b/src/smt/dump_manager.h @@ -31,14 +31,10 @@ namespace smt { /** * This utility is responsible for: - * (1) storing information required for SMT-LIB queries such as get-model, - * which requires knowing what symbols are declared and should be printed in - * the model. - * (2) implementing some dumping traces e.g. --dump=declarations. + * implementing some dumping traces e.g. --dump=declarations. */ class DumpManager { - typedef context::CDList<NodeCommand*> CommandList; public: DumpManager(context::UserContext* u); @@ -65,34 +61,10 @@ class DumpManager * Set that function f should print in the model if and only if p is true. */ void setPrintFuncInModel(Node f, bool p); - /** get number of commands to report in a model */ - size_t getNumModelCommands() const; - /** get model command at index i */ - const NodeCommand* getModelCommand(size_t i) const; private: /** Fully inited */ bool d_fullyInited; - - /** - * A list of commands that should be in the Model globally (i.e., - * regardless of push/pop). Only maintained if produce-models option - * is on. - */ - std::vector<std::unique_ptr<NodeCommand>> d_modelGlobalCommands; - - /** - * A list of commands that should be in the Model locally (i.e., - * it is context-dependent on push/pop). Only maintained if - * produce-models option is on. - */ - CommandList d_modelCommands; - /** - * A list of model commands allocated to d_modelCommands at any time. This - * is maintained for memory management purposes. - */ - std::vector<std::unique_ptr<NodeCommand>> d_modelCommandsAlloc; - /** * A vector of declaration commands waiting to be dumped out. * Once the SmtEngine is fully initialized, we'll dump them. diff --git a/src/smt/expand_definitions.cpp b/src/smt/expand_definitions.cpp new file mode 100644 index 000000000..20c4f8ef6 --- /dev/null +++ b/src/smt/expand_definitions.cpp @@ -0,0 +1,278 @@ +/********************* */ +/*! \file expand_definitions.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds, Tim King, Haniel Barbosa + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of expand definitions for an SMT engine. + **/ + +#include "smt/expand_definitions.h" + +#include <stack> +#include <utility> + +#include "expr/node_manager_attributes.h" +#include "smt/defined_function.h" +#include "smt/smt_engine.h" +#include "theory/theory_engine.h" + +using namespace CVC4::preprocessing; +using namespace CVC4::theory; +using namespace CVC4::kind; + +namespace CVC4 { +namespace smt { + +ExpandDefs::ExpandDefs(SmtEngine& smt, + ResourceManager& rm, + SmtEngineStatistics& stats) + : d_smt(smt), d_resourceManager(rm), d_smtStats(stats) +{ +} + +ExpandDefs::~ExpandDefs() {} + +Node ExpandDefs::expandDefinitions( + TNode n, + std::unordered_map<Node, Node, NodeHashFunction>& cache, + bool expandOnly) +{ + NodeManager* nm = d_smt.getNodeManager(); + std::stack<std::tuple<Node, Node, bool>> worklist; + std::stack<Node> result; + worklist.push(std::make_tuple(Node(n), Node(n), false)); + // The worklist is made of triples, each is input / original node then the + // output / rewritten node and finally a flag tracking whether the children + // have been explored (i.e. if this is a downward or upward pass). + + do + { + d_resourceManager.spendResource(ResourceManager::Resource::PreprocessStep); + + // n is the input / original + // node is the output / result + Node node; + bool childrenPushed; + std::tie(n, node, childrenPushed) = worklist.top(); + worklist.pop(); + + // Working downwards + if (!childrenPushed) + { + Kind k = n.getKind(); + + // we can short circuit (variable) leaves + if (n.isVar()) + { + SmtEngine::DefinedFunctionMap* dfuns = d_smt.getDefinedFunctionMap(); + SmtEngine::DefinedFunctionMap::const_iterator i = dfuns->find(n); + if (i != dfuns->end()) + { + Node f = (*i).second.getFormula(); + // must expand its definition + Node fe = expandDefinitions(f, cache, expandOnly); + // replacement must be closed + if ((*i).second.getFormals().size() > 0) + { + result.push( + nm->mkNode(LAMBDA, + nm->mkNode(BOUND_VAR_LIST, (*i).second.getFormals()), + fe)); + continue; + } + // don't bother putting in the cache + result.push(fe); + continue; + } + // don't bother putting in the cache + result.push(n); + continue; + } + + // maybe it's in the cache + std::unordered_map<Node, Node, NodeHashFunction>::iterator cacheHit = + cache.find(n); + if (cacheHit != cache.end()) + { + TNode ret = (*cacheHit).second; + result.push(ret.isNull() ? n : ret); + continue; + } + + // otherwise expand it + bool doExpand = false; + if (k == APPLY_UF) + { + // Always do beta-reduction here. The reason is that there may be + // operators such as INTS_MODULUS in the body of the lambda that would + // otherwise be introduced by beta-reduction via the rewriter, but are + // not expanded here since the traversal in this function does not + // traverse the operators of nodes. Hence, we beta-reduce here to + // ensure terms in the body of the lambda are expanded during this + // call. + if (n.getOperator().getKind() == LAMBDA) + { + doExpand = true; + } + else + { + // We always check if this operator corresponds to a defined function. + doExpand = d_smt.isDefinedFunction(n.getOperator().toExpr()); + } + } + if (doExpand) + { + std::vector<Node> formals; + TNode fm; + if (n.getOperator().getKind() == LAMBDA) + { + TNode op = n.getOperator(); + // lambda + for (unsigned i = 0; i < op[0].getNumChildren(); i++) + { + formals.push_back(op[0][i]); + } + fm = op[1]; + } + else + { + // application of a user-defined symbol + TNode func = n.getOperator(); + SmtEngine::DefinedFunctionMap* dfuns = d_smt.getDefinedFunctionMap(); + SmtEngine::DefinedFunctionMap::const_iterator i = dfuns->find(func); + if (i == dfuns->end()) + { + throw TypeCheckingException( + n.toExpr(), + std::string("Undefined function: `") + func.toString() + "'"); + } + DefinedFunction def = (*i).second; + formals = def.getFormals(); + + if (Debug.isOn("expand")) + { + Debug("expand") << "found: " << n << std::endl; + Debug("expand") << " func: " << func << std::endl; + std::string name = func.getAttribute(expr::VarNameAttr()); + Debug("expand") << " : \"" << name << "\"" << std::endl; + } + if (Debug.isOn("expand")) + { + Debug("expand") << " defn: " << def.getFunction() << std::endl + << " ["; + if (formals.size() > 0) + { + copy(formals.begin(), + formals.end() - 1, + std::ostream_iterator<Node>(Debug("expand"), ", ")); + Debug("expand") << formals.back(); + } + Debug("expand") + << "]" << std::endl + << " " << def.getFunction().getType() << std::endl + << " " << def.getFormula() << std::endl; + } + + fm = def.getFormula(); + } + + Node instance = fm.substitute(formals.begin(), + formals.end(), + n.begin(), + n.begin() + formals.size()); + Debug("expand") << "made : " << instance << std::endl; + + Node expanded = expandDefinitions(instance, cache, expandOnly); + cache[n] = (n == expanded ? Node::null() : expanded); + result.push(expanded); + continue; + } + else if (!expandOnly) + { + // do not do any theory stuff if expandOnly is true + + theory::Theory* t = d_smt.getTheoryEngine()->theoryOf(node); + + Assert(t != NULL); + TrustNode trn = t->expandDefinition(n); + node = trn.isNull() ? Node(n) : trn.getNode(); + } + + // the partial functions can fall through, in which case we still + // consider their children + worklist.push(std::make_tuple( + Node(n), node, true)); // Original and rewritten result + + for (size_t i = 0; i < node.getNumChildren(); ++i) + { + worklist.push( + std::make_tuple(node[i], + node[i], + false)); // Rewrite the children of the result only + } + } + else + { + // Working upwards + // Reconstruct the node from it's (now rewritten) children on the stack + + Debug("expand") << "cons : " << node << std::endl; + if (node.getNumChildren() > 0) + { + // cout << "cons : " << node << std::endl; + NodeBuilder<> nb(node.getKind()); + if (node.getMetaKind() == metakind::PARAMETERIZED) + { + Debug("expand") << "op : " << node.getOperator() << std::endl; + // cout << "op : " << node.getOperator() << std::endl; + nb << node.getOperator(); + } + for (size_t i = 0, nchild = node.getNumChildren(); i < nchild; ++i) + { + Assert(!result.empty()); + Node expanded = result.top(); + result.pop(); + // cout << "exchld : " << expanded << std::endl; + Debug("expand") << "exchld : " << expanded << std::endl; + nb << expanded; + } + node = nb; + } + // Only cache once all subterms are expanded + cache[n] = n == node ? Node::null() : node; + result.push(node); + } + } while (!worklist.empty()); + + AlwaysAssert(result.size() == 1); + + return result.top(); +} + +void ExpandDefs::expandAssertions(AssertionPipeline& assertions, + bool expandOnly) +{ + Chat() << "expanding definitions in assertions..." << std::endl; + Trace("simplify") << "ExpandDefs::simplify(): expanding definitions" + << std::endl; + TimerStat::CodeTimer codeTimer(d_smtStats.d_definitionExpansionTime); + std::unordered_map<Node, Node, NodeHashFunction> cache; + for (size_t i = 0, nasserts = assertions.size(); i < nasserts; ++i) + { + Node assert = assertions[i]; + Node expd = expandDefinitions(assert, cache, expandOnly); + if (expd != assert) + { + assertions.replace(i, expd); + } + } +} + +} // namespace smt +} // namespace CVC4 diff --git a/src/smt/expand_definitions.h b/src/smt/expand_definitions.h new file mode 100644 index 000000000..f40ee4a4e --- /dev/null +++ b/src/smt/expand_definitions.h @@ -0,0 +1,76 @@ +/********************* */ +/*! \file process_assertions.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds, Tim King, Morgan Deters + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The module for processing assertions for an SMT engine. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__SMT__EXPAND_DEFINITIONS_H +#define CVC4__SMT__EXPAND_DEFINITIONS_H + +#include <unordered_map> + +#include "expr/node.h" +#include "preprocessing/assertion_pipeline.h" +#include "smt/smt_engine_stats.h" +#include "util/resource_manager.h" + +namespace CVC4 { + +class SmtEngine; + +namespace smt { + +/** + * Module in charge of expanding definitions for an SMT engine. + * + * Its main features is expandDefinitions(TNode, ...), which returns the + * expanded formula of a term. + */ +class ExpandDefs +{ + public: + ExpandDefs(SmtEngine& smt, ResourceManager& rm, SmtEngineStatistics& stats); + ~ExpandDefs(); + /** + * Expand definitions in term n. Return the expanded form of n. + * + * @param n The node to expand + * @param cache Cache of previous results + * @param expandOnly if true, then the expandDefinitions function of + * TheoryEngine is not called on subterms of n. + * @return The expanded term. + */ + Node expandDefinitions( + TNode n, + std::unordered_map<Node, Node, NodeHashFunction>& cache, + bool expandOnly = false); + /** + * Expand defintitions in assertions. This applies this above method to each + * assertion in the given pipeline. + */ + void expandAssertions(preprocessing::AssertionPipeline& assertions, + bool expandOnly = false); + + private: + /** Reference to the SMT engine */ + SmtEngine& d_smt; + /** Reference to resource manager */ + ResourceManager& d_resourceManager; + /** Reference to the SMT stats */ + SmtEngineStatistics& d_smtStats; +}; + +} // namespace smt +} // namespace CVC4 + +#endif diff --git a/src/smt/model.cpp b/src/smt/model.cpp index fc9ea8fbb..8a9f944d2 100644 --- a/src/smt/model.cpp +++ b/src/smt/model.cpp @@ -26,29 +26,17 @@ namespace CVC4 { namespace smt { -Model::Model(SmtEngine& smt, theory::TheoryModel* tm) - : d_smt(smt), d_isKnownSat(false), d_tmodel(tm) +Model::Model(theory::TheoryModel* tm) : d_isKnownSat(false), d_tmodel(tm) { Assert(d_tmodel != nullptr); } std::ostream& operator<<(std::ostream& out, const Model& m) { - smt::SmtScope smts(&m.d_smt); expr::ExprDag::Scope scope(out, false); Printer::getPrinter(options::outputLanguage())->toStream(out, m); return out; } -size_t Model::getNumCommands() const -{ - return d_smt.getDumpManager()->getNumModelCommands(); -} - -const NodeCommand* Model::getCommand(size_t i) const -{ - return d_smt.getDumpManager()->getModelCommand(i); -} - theory::TheoryModel* Model::getTheoryModel() { return d_tmodel; } const theory::TheoryModel* Model::getTheoryModel() const { return d_tmodel; } @@ -61,5 +49,19 @@ Node Model::getValue(TNode n) const { return d_tmodel->getValue(n); } bool Model::hasApproximations() const { return d_tmodel->hasApproximations(); } +void Model::clearModelDeclarations() { d_declareSorts.clear(); } + +void Model::addDeclarationSort(TypeNode tn) { d_declareSorts.push_back(tn); } + +void Model::addDeclarationTerm(Node n) { d_declareTerms.push_back(n); } +const std::vector<TypeNode>& Model::getDeclaredSorts() const +{ + return d_declareSorts; +} +const std::vector<Node>& Model::getDeclaredTerms() const +{ + return d_declareTerms; +} + } // namespace smt }/* CVC4 namespace */ diff --git a/src/smt/model.h b/src/smt/model.h index dc36b5d29..18675040a 100644 --- a/src/smt/model.h +++ b/src/smt/model.h @@ -27,7 +27,6 @@ namespace CVC4 { class SmtEngine; -class NodeCommand; namespace smt { @@ -39,6 +38,9 @@ std::ostream& operator<<(std::ostream&, const Model&); * This is the SMT-level model object, that is responsible for maintaining * the necessary information for how to print the model, as well as * holding a pointer to the underlying implementation of the theory model. + * + * The model declarations maintained by this class are context-independent + * and should be updated when this model is printed. */ class Model { friend std::ostream& operator<<(std::ostream&, const Model&); @@ -46,17 +48,9 @@ class Model { public: /** construct */ - Model(SmtEngine& smt, theory::TheoryModel* tm); + Model(theory::TheoryModel* tm); /** virtual destructor */ ~Model() {} - /** get number of commands to report */ - size_t getNumCommands() const; - /** get command */ - const NodeCommand* getCommand(size_t i) const; - /** get the smt engine that this model is hooked up to */ - SmtEngine* getSmtEngine() { return &d_smt; } - /** get the smt engine (as a pointer-to-const) that this model is hooked up to */ - const SmtEngine* getSmtEngine() const { return &d_smt; } /** get the input name (file name, etc.) this model is associated to */ std::string getInputName() const { return d_inputName; } /** @@ -78,9 +72,25 @@ class Model { /** Does this model have approximations? */ bool hasApproximations() const; //----------------------- end helper methods + //----------------------- model declarations + /** Clear the current model declarations. */ + void clearModelDeclarations(); + /** + * Set that tn is a sort that should be printed in the model, when applicable, + * based on the output language. + */ + void addDeclarationSort(TypeNode tn); + /** + * Set that n is a variable that should be printed in the model, when + * applicable, based on the output language. + */ + void addDeclarationTerm(Node n); + /** get declared sorts */ + const std::vector<TypeNode>& getDeclaredSorts() const; + /** get declared terms */ + const std::vector<Node>& getDeclaredTerms() const; + //----------------------- end model declarations protected: - /** The SmtEngine we're associated with */ - SmtEngine& d_smt; /** the input name (file name, etc.) this model is associated to */ std::string d_inputName; /** @@ -93,6 +103,16 @@ class Model { * the values of sorts and terms. */ theory::TheoryModel* d_tmodel; + /** + * The list of types to print, generally corresponding to declare-sort + * commands. + */ + std::vector<TypeNode> d_declareSorts; + /** + * The list of terms to print, is typically one-to-one with declare-fun + * commands. + */ + std::vector<Node> d_declareTerms; }; } // namespace smt diff --git a/src/smt/node_command.cpp b/src/smt/node_command.cpp index eb2493c87..815f99132 100644 --- a/src/smt/node_command.cpp +++ b/src/smt/node_command.cpp @@ -51,9 +51,7 @@ DeclareFunctionNodeCommand::DeclareFunctionNodeCommand(const std::string& id, TypeNode type) : d_id(id), d_fun(expr), - d_type(type), - d_printInModel(true), - d_printInModelSetByUser(false) + d_type(type) { } @@ -72,22 +70,6 @@ NodeCommand* DeclareFunctionNodeCommand::clone() const const Node& DeclareFunctionNodeCommand::getFunction() const { return d_fun; } -bool DeclareFunctionNodeCommand::getPrintInModel() const -{ - return d_printInModel; -} - -bool DeclareFunctionNodeCommand::getPrintInModelSetByUser() const -{ - return d_printInModelSetByUser; -} - -void DeclareFunctionNodeCommand::setPrintInModel(bool p) -{ - d_printInModel = p; - d_printInModelSetByUser = true; -} - /* -------------------------------------------------------------------------- */ /* class DeclareTypeNodeCommand */ /* -------------------------------------------------------------------------- */ @@ -104,8 +86,7 @@ void DeclareTypeNodeCommand::toStream(std::ostream& out, size_t dag, OutputLanguage language) const { - Printer::getPrinter(language)->toStreamCmdDeclareType( - out, d_id, d_arity, d_type); + Printer::getPrinter(language)->toStreamCmdDeclareType(out, d_type); } NodeCommand* DeclareTypeNodeCommand::clone() const diff --git a/src/smt/node_command.h b/src/smt/node_command.h index 8cf9a5e10..e1a15e62c 100644 --- a/src/smt/node_command.h +++ b/src/smt/node_command.h @@ -68,16 +68,11 @@ class DeclareFunctionNodeCommand : public NodeCommand OutputLanguage language = language::output::LANG_AUTO) const override; NodeCommand* clone() const override; const Node& getFunction() const; - bool getPrintInModel() const; - bool getPrintInModelSetByUser() const; - void setPrintInModel(bool p); private: std::string d_id; Node d_fun; TypeNode d_type; - bool d_printInModel; - bool d_printInModelSetByUser; }; /** diff --git a/src/smt/preprocessor.cpp b/src/smt/preprocessor.cpp index bf7009081..bd6988645 100644 --- a/src/smt/preprocessor.cpp +++ b/src/smt/preprocessor.cpp @@ -35,7 +35,8 @@ Preprocessor::Preprocessor(SmtEngine& smt, d_absValues(abs), d_propagator(true, true), d_assertionsProcessed(u, false), - d_processor(smt, *smt.getResourceManager(), stats), + d_exDefs(smt, *smt.getResourceManager(), stats), + d_processor(smt, d_exDefs, *smt.getResourceManager(), stats), d_rtf(u), d_pnm(nullptr) { @@ -107,7 +108,7 @@ RemoveTermFormulas& Preprocessor::getTermFormulaRemover() { return d_rtf; } Node Preprocessor::expandDefinitions(const Node& n, bool expandOnly) { std::unordered_map<Node, Node, NodeHashFunction> cache; - return expandDefinitions(n, cache, expandOnly); + return d_exDefs.expandDefinitions(n, cache, expandOnly); } Node Preprocessor::expandDefinitions( @@ -124,7 +125,7 @@ Node Preprocessor::expandDefinitions( n.getType(true); } // expand only = true - return d_processor.expandDefinitions(n, cache, expandOnly); + return d_exDefs.expandDefinitions(n, cache, expandOnly); } Node Preprocessor::simplify(const Node& node, bool removeItes) @@ -142,7 +143,7 @@ Node Preprocessor::simplify(const Node& node, bool removeItes) nas.getType(true); } std::unordered_map<Node, Node, NodeHashFunction> cache; - Node n = d_processor.expandDefinitions(nas, cache); + Node n = d_exDefs.expandDefinitions(nas, cache); TrustNode ts = d_ppContext->getTopLevelSubstitutions().apply(n); Node ns = ts.isNull() ? n : ts.getNode(); if (removeItes) @@ -157,6 +158,7 @@ void Preprocessor::setProofGenerator(PreprocessProofGenerator* pppg) { Assert(pppg != nullptr); d_pnm = pppg->getManager(); + d_propagator.setProof(d_pnm, d_context, pppg); d_rtf.setProofNodeManager(d_pnm); } diff --git a/src/smt/preprocessor.h b/src/smt/preprocessor.h index 8700c3885..696261d9e 100644 --- a/src/smt/preprocessor.h +++ b/src/smt/preprocessor.h @@ -20,6 +20,7 @@ #include <vector> #include "preprocessing/preprocessing_pass_context.h" +#include "smt/expand_definitions.h" #include "smt/process_assertions.h" #include "smt/term_formula_removal.h" #include "theory/booleans/circuit_propagator.h" @@ -123,6 +124,8 @@ class Preprocessor context::CDO<bool> d_assertionsProcessed; /** The preprocessing pass context */ std::unique_ptr<preprocessing::PreprocessingPassContext> d_ppContext; + /** Expand definitions module, responsible for expanding definitions */ + ExpandDefs d_exDefs; /** * Process assertions module, responsible for implementing the preprocessing * passes. diff --git a/src/smt/process_assertions.cpp b/src/smt/process_assertions.cpp index 2011e7b83..c68b73336 100644 --- a/src/smt/process_assertions.cpp +++ b/src/smt/process_assertions.cpp @@ -53,9 +53,11 @@ class ScopeCounter }; ProcessAssertions::ProcessAssertions(SmtEngine& smt, + ExpandDefs& exDefs, ResourceManager& rm, SmtEngineStatistics& stats) : d_smt(smt), + d_exDefs(exDefs), d_resourceManager(rm), d_smtStats(stats), d_preprocessingPassContext(nullptr) @@ -128,21 +130,7 @@ bool ProcessAssertions::apply(Assertions& as) << "ProcessAssertions::processAssertions() : pre-definition-expansion" << endl; dumpAssertions("pre-definition-expansion", assertions); - { - Chat() << "expanding definitions..." << endl; - Trace("simplify") << "ProcessAssertions::simplify(): expanding definitions" - << endl; - TimerStat::CodeTimer codeTimer(d_smtStats.d_definitionExpansionTime); - unordered_map<Node, Node, NodeHashFunction> cache; - for (size_t i = 0, nasserts = assertions.size(); i < nasserts; ++i) - { - Node expd = expandDefinitions(assertions[i], cache); - if (expd != assertions[i]) - { - assertions.replace(i, expd); - } - } - } + d_exDefs.expandAssertions(assertions, false); Trace("smt-proc") << "ProcessAssertions::processAssertions() : post-definition-expansion" << endl; @@ -550,222 +538,6 @@ void ProcessAssertions::dumpAssertions(const char* key, } } -Node ProcessAssertions::expandDefinitions( - TNode n, - unordered_map<Node, Node, NodeHashFunction>& cache, - bool expandOnly) -{ - NodeManager* nm = d_smt.getNodeManager(); - std::stack<std::tuple<Node, Node, bool>> worklist; - std::stack<Node> result; - worklist.push(std::make_tuple(Node(n), Node(n), false)); - // The worklist is made of triples, each is input / original node then the - // output / rewritten node and finally a flag tracking whether the children - // have been explored (i.e. if this is a downward or upward pass). - - do - { - spendResource(ResourceManager::Resource::PreprocessStep); - - // n is the input / original - // node is the output / result - Node node; - bool childrenPushed; - std::tie(n, node, childrenPushed) = worklist.top(); - worklist.pop(); - - // Working downwards - if (!childrenPushed) - { - Kind k = n.getKind(); - - // we can short circuit (variable) leaves - if (n.isVar()) - { - SmtEngine::DefinedFunctionMap* dfuns = d_smt.getDefinedFunctionMap(); - SmtEngine::DefinedFunctionMap::const_iterator i = dfuns->find(n); - if (i != dfuns->end()) - { - Node f = (*i).second.getFormula(); - // must expand its definition - Node fe = expandDefinitions(f, cache, expandOnly); - // replacement must be closed - if ((*i).second.getFormals().size() > 0) - { - result.push( - nm->mkNode(LAMBDA, - nm->mkNode(BOUND_VAR_LIST, (*i).second.getFormals()), - fe)); - continue; - } - // don't bother putting in the cache - result.push(fe); - continue; - } - // don't bother putting in the cache - result.push(n); - continue; - } - - // maybe it's in the cache - unordered_map<Node, Node, NodeHashFunction>::iterator cacheHit = - cache.find(n); - if (cacheHit != cache.end()) - { - TNode ret = (*cacheHit).second; - result.push(ret.isNull() ? n : ret); - continue; - } - - // otherwise expand it - bool doExpand = false; - if (k == APPLY_UF) - { - // Always do beta-reduction here. The reason is that there may be - // operators such as INTS_MODULUS in the body of the lambda that would - // otherwise be introduced by beta-reduction via the rewriter, but are - // not expanded here since the traversal in this function does not - // traverse the operators of nodes. Hence, we beta-reduce here to - // ensure terms in the body of the lambda are expanded during this - // call. - if (n.getOperator().getKind() == LAMBDA) - { - doExpand = true; - } - else - { - // We always check if this operator corresponds to a defined function. - doExpand = d_smt.isDefinedFunction(n.getOperator().toExpr()); - } - } - if (doExpand) - { - vector<Node> formals; - TNode fm; - if (n.getOperator().getKind() == LAMBDA) - { - TNode op = n.getOperator(); - // lambda - for (unsigned i = 0; i < op[0].getNumChildren(); i++) - { - formals.push_back(op[0][i]); - } - fm = op[1]; - } - else - { - // application of a user-defined symbol - TNode func = n.getOperator(); - SmtEngine::DefinedFunctionMap* dfuns = d_smt.getDefinedFunctionMap(); - SmtEngine::DefinedFunctionMap::const_iterator i = dfuns->find(func); - if (i == dfuns->end()) - { - throw TypeCheckingException( - n.toExpr(), - string("Undefined function: `") + func.toString() + "'"); - } - DefinedFunction def = (*i).second; - formals = def.getFormals(); - - if (Debug.isOn("expand")) - { - Debug("expand") << "found: " << n << endl; - Debug("expand") << " func: " << func << endl; - string name = func.getAttribute(expr::VarNameAttr()); - Debug("expand") << " : \"" << name << "\"" << endl; - } - if (Debug.isOn("expand")) - { - Debug("expand") << " defn: " << def.getFunction() << endl - << " ["; - if (formals.size() > 0) - { - copy(formals.begin(), - formals.end() - 1, - ostream_iterator<Node>(Debug("expand"), ", ")); - Debug("expand") << formals.back(); - } - Debug("expand") << "]" << endl - << " " << def.getFunction().getType() << endl - << " " << def.getFormula() << endl; - } - - fm = def.getFormula(); - } - - Node instance = fm.substitute(formals.begin(), - formals.end(), - n.begin(), - n.begin() + formals.size()); - Debug("expand") << "made : " << instance << endl; - - Node expanded = expandDefinitions(instance, cache, expandOnly); - cache[n] = (n == expanded ? Node::null() : expanded); - result.push(expanded); - continue; - } - else if (!expandOnly) - { - // do not do any theory stuff if expandOnly is true - - theory::Theory* t = d_smt.getTheoryEngine()->theoryOf(node); - - Assert(t != NULL); - TrustNode trn = t->expandDefinition(n); - node = trn.isNull() ? Node(n) : trn.getNode(); - } - - // the partial functions can fall through, in which case we still - // consider their children - worklist.push(std::make_tuple( - Node(n), node, true)); // Original and rewritten result - - for (size_t i = 0; i < node.getNumChildren(); ++i) - { - worklist.push( - std::make_tuple(node[i], - node[i], - false)); // Rewrite the children of the result only - } - } - else - { - // Working upwards - // Reconstruct the node from it's (now rewritten) children on the stack - - Debug("expand") << "cons : " << node << endl; - if (node.getNumChildren() > 0) - { - // cout << "cons : " << node << endl; - NodeBuilder<> nb(node.getKind()); - if (node.getMetaKind() == metakind::PARAMETERIZED) - { - Debug("expand") << "op : " << node.getOperator() << endl; - // cout << "op : " << node.getOperator() << endl; - nb << node.getOperator(); - } - for (size_t i = 0, nchild = node.getNumChildren(); i < nchild; ++i) - { - Assert(!result.empty()); - Node expanded = result.top(); - result.pop(); - // cout << "exchld : " << expanded << endl; - Debug("expand") << "exchld : " << expanded << endl; - nb << expanded; - } - node = nb; - } - // Only cache once all subterms are expanded - cache[n] = n == node ? Node::null() : node; - result.push(node); - } - } while (!worklist.empty()); - - AlwaysAssert(result.size() == 1); - - return result.top(); -} - void ProcessAssertions::collectSkolems( IteSkolemMap& iskMap, TNode n, diff --git a/src/smt/process_assertions.h b/src/smt/process_assertions.h index d260edf14..072603e7d 100644 --- a/src/smt/process_assertions.h +++ b/src/smt/process_assertions.h @@ -26,6 +26,7 @@ #include "preprocessing/preprocessing_pass.h" #include "preprocessing/preprocessing_pass_context.h" #include "smt/assertions.h" +#include "smt/expand_definitions.h" #include "smt/smt_engine_stats.h" #include "util/resource_manager.h" @@ -53,11 +54,11 @@ class ProcessAssertions { /** The types for the recursive function definitions */ typedef context::CDList<Node> NodeList; - typedef unordered_map<Node, Node, NodeHashFunction> NodeToNodeHashMap; typedef unordered_map<Node, bool, NodeHashFunction> NodeToBoolHashMap; public: ProcessAssertions(SmtEngine& smt, + ExpandDefs& exDefs, ResourceManager& rm, SmtEngineStatistics& stats); ~ProcessAssertions(); @@ -76,22 +77,12 @@ class ProcessAssertions * processing the assertions. */ bool apply(Assertions& as); - /** - * Expand definitions in term n. Return the expanded form of n. - * - * @param n The node to expand - * @param cache Cache of previous results - * @param expandOnly if true, then the expandDefinitions function of - * TheoryEngine is not called on subterms of n. - * @return The expanded term. - */ - Node expandDefinitions(TNode n, - NodeToNodeHashMap& cache, - bool expandOnly = false); private: /** Reference to the SMT engine */ SmtEngine& d_smt; + /** Reference to expand definitions module */ + ExpandDefs& d_exDefs; /** Reference to resource manager */ ResourceManager& d_resourceManager; /** Reference to the SMT stats */ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 2a0cde015..0f40db530 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -73,10 +73,9 @@ using namespace CVC4::theory; namespace CVC4 { -SmtEngine::SmtEngine(ExprManager* em, Options* optr) +SmtEngine::SmtEngine(NodeManager* nm, Options* optr) : d_state(new SmtEngineState(*this)), - d_exprManager(em), - d_nodeManager(d_exprManager->getNodeManager()), + d_nodeManager(nm), d_absValues(new AbstractValues(d_nodeManager)), d_asserts(new Assertions(getUserContext(), *d_absValues.get())), d_dumpm(new DumpManager(getUserContext())), @@ -244,7 +243,7 @@ void SmtEngine::finishInit() TheoryModel* tm = te->getModel(); if (tm != nullptr) { - d_model.reset(new Model(*this, tm)); + d_model.reset(new Model(tm)); // make the check models utility d_checkModels.reset(new CheckModels(*d_smtSolver.get())); } @@ -506,7 +505,7 @@ bool SmtEngine::isValidGetInfoFlag(const std::string& key) const if (key == "all-statistics" || key == "error-behavior" || key == "name" || key == "version" || key == "authors" || key == "status" || key == "reason-unknown" || key == "assertion-stack-levels" - || key == "all-options") + || key == "all-options" || key == "time") { return true; } @@ -525,14 +524,8 @@ CVC4::SExpr SmtEngine::getInfo(const std::string& key) const if (key == "all-statistics") { vector<SExpr> stats; - for (StatisticsRegistry::const_iterator i = - NodeManager::fromExprManager(d_exprManager) - ->getStatisticsRegistry() - ->begin(); - i - != NodeManager::fromExprManager(d_exprManager) - ->getStatisticsRegistry() - ->end(); + StatisticsRegistry* sr = d_nodeManager->getStatisticsRegistry(); + for (StatisticsRegistry::const_iterator i = sr->begin(); i != sr->end(); ++i) { vector<SExpr> v; @@ -578,6 +571,10 @@ CVC4::SExpr SmtEngine::getInfo(const std::string& key) const default: return SExpr(SExpr::Keyword("unknown")); } } + if (key == "time") + { + return SExpr(std::clock()); + } if (key == "reason-unknown") { Result status = d_state->getStatus(); @@ -1633,7 +1630,6 @@ void SmtEngine::pop() { void SmtEngine::reset() { SmtScope smts(this); - ExprManager *em = d_exprManager; Trace("smt") << "SMT reset()" << endl; if(Dump.isOn("benchmark")) { getOutputManager().getPrinter().toStreamCmdReset( @@ -1643,7 +1639,7 @@ void SmtEngine::reset() Options opts; opts.copyValues(d_originalOptions); this->~SmtEngine(); - new (this) SmtEngine(em, &opts); + new (this) SmtEngine(d_nodeManager, &opts); // Restore data set after creation notifyStartParsing(filename); } @@ -1709,10 +1705,7 @@ unsigned long SmtEngine::getResourceRemaining() const return d_resourceManager->getResourceRemaining(); } -NodeManager* SmtEngine::getNodeManager() const -{ - return d_exprManager->getNodeManager(); -} +NodeManager* SmtEngine::getNodeManager() const { return d_nodeManager; } Statistics SmtEngine::getStatistics() const { @@ -1729,20 +1722,15 @@ void SmtEngine::safeFlushStatistics(int fd) const { } void SmtEngine::setUserAttribute(const std::string& attr, - Expr expr, - const std::vector<Expr>& expr_values, + Node expr, + const std::vector<Node>& expr_values, const std::string& str_value) { SmtScope smts(this); finishInit(); - std::vector<Node> node_values; - for (std::size_t i = 0, n = expr_values.size(); i < n; i++) - { - node_values.push_back( expr_values[i].getNode() ); - } TheoryEngine* te = getTheoryEngine(); Assert(te != nullptr); - te->setUserAttribute(attr, expr.getNode(), node_values, str_value); + te->setUserAttribute(attr, expr, expr_values, str_value); } void SmtEngine::setOption(const std::string& key, const CVC4::SExpr& value) diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 1c83fa61f..a55428b55 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -21,13 +21,12 @@ #include <string> #include <vector> +#include <map> #include "base/modal_exception.h" #include "context/cdhashmap_forward.h" #include "context/cdhashset_forward.h" #include "context/cdlist_forward.h" -#include "expr/expr.h" -#include "expr/expr_manager.h" #include "options/options.h" #include "smt/logic_exception.h" #include "smt/output_manager.h" @@ -48,9 +47,10 @@ namespace CVC4 { template <bool ref_count> class NodeTemplate; typedef NodeTemplate<true> Node; typedef NodeTemplate<false> TNode; +class TypeNode; struct NodeHashFunction; -class SmtEngine; +class NodeManager; class DecisionEngine; class TheoryEngine; class ProofManager; @@ -58,6 +58,7 @@ class UnsatCore; class LogicRequest; class StatisticsRegistry; class Printer; +class ResourceManager; /* -------------------------------------------------------------------------- */ @@ -147,7 +148,7 @@ class CVC4_PUBLIC SmtEngine * If provided, optr is a pointer to a set of options that should initialize the values * of the options object owned by this class. */ - SmtEngine(ExprManager* em, Options* optr = nullptr); + SmtEngine(NodeManager* nm, Options* optr = nullptr); /** Destruct the SMT engine. */ ~SmtEngine(); @@ -691,7 +692,7 @@ class CVC4_PUBLIC SmtEngine /** * Completely reset the state of the solver, as though destroyed and * recreated. The result is as if newly constructed (so it still - * retains the same options structure and ExprManager). + * retains the same options structure and NodeManager). */ void reset(); @@ -785,9 +786,6 @@ class CVC4_PUBLIC SmtEngine */ unsigned long getResourceRemaining() const; - /** Permit access to the underlying ExprManager. */ - ExprManager* getExprManager() const { return d_exprManager; } - /** Permit access to the underlying NodeManager. */ NodeManager* getNodeManager() const; @@ -806,8 +804,8 @@ class CVC4_PUBLIC SmtEngine * In SMT-LIBv2 this is done via the syntax (! expr :attr) */ void setUserAttribute(const std::string& attr, - Expr expr, - const std::vector<Expr>& expr_values, + Node expr, + const std::vector<Node>& expr_values, const std::string& str_value); /** Get the options object (const and non-const versions) */ @@ -1013,9 +1011,7 @@ class CVC4_PUBLIC SmtEngine */ std::unique_ptr<smt::SmtEngineState> d_state; - /** Our expression manager */ - ExprManager* d_exprManager; - /** Our internal expression/node manager */ + /** Our internal node manager */ NodeManager* d_nodeManager; /** Abstract values */ std::unique_ptr<smt::AbstractValues> d_absValues; |