diff options
-rw-r--r-- | src/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/printer/ast/ast_printer.cpp | 3 | ||||
-rw-r--r-- | src/printer/ast/ast_printer.h | 2 | ||||
-rw-r--r-- | src/printer/cvc/cvc_printer.cpp | 30 | ||||
-rw-r--r-- | src/printer/cvc/cvc_printer.h | 2 | ||||
-rw-r--r-- | src/printer/printer.cpp | 8 | ||||
-rw-r--r-- | src/printer/printer.h | 6 | ||||
-rw-r--r-- | src/printer/smt2/smt2_printer.cpp | 31 | ||||
-rw-r--r-- | src/printer/smt2/smt2_printer.h | 4 | ||||
-rw-r--r-- | src/printer/tptp/tptp_printer.cpp | 13 | ||||
-rw-r--r-- | src/printer/tptp/tptp_printer.h | 2 | ||||
-rw-r--r-- | src/smt/dump.h | 16 | ||||
-rw-r--r-- | src/smt/dump_manager.cpp | 16 | ||||
-rw-r--r-- | src/smt/dump_manager.h | 14 | ||||
-rw-r--r-- | src/smt/listeners.cpp | 19 | ||||
-rw-r--r-- | src/smt/model.cpp | 4 | ||||
-rw-r--r-- | src/smt/model.h | 4 | ||||
-rw-r--r-- | src/smt/node_command.cpp | 180 | ||||
-rw-r--r-- | src/smt/node_command.h | 157 | ||||
-rw-r--r-- | src/smt/smt_engine.cpp | 13 |
20 files changed, 448 insertions, 78 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0ad9526a5..692ae09ac 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -210,6 +210,8 @@ libcvc4_add_sources( smt/model_core_builder.h smt/model_blocker.cpp smt/model_blocker.h + smt/node_command.cpp + smt/node_command.h smt/options_manager.cpp smt/options_manager.h smt/quant_elim_solver.cpp diff --git a/src/printer/ast/ast_printer.cpp b/src/printer/ast/ast_printer.cpp index d4f28c186..f235721f1 100644 --- a/src/printer/ast/ast_printer.cpp +++ b/src/printer/ast/ast_printer.cpp @@ -26,6 +26,7 @@ #include "options/language.h" // for LANG_AST #include "printer/dagification_visitor.h" #include "smt/command.h" +#include "smt/node_command.h" #include "theory/substitutions.h" using namespace std; @@ -156,7 +157,7 @@ void AstPrinter::toStream(std::ostream& out, const Model& m) const void AstPrinter::toStream(std::ostream& out, const Model& m, - const Command* c) const + const NodeCommand* c) const { // shouldn't be called; only the non-Command* version above should be Unreachable(); diff --git a/src/printer/ast/ast_printer.h b/src/printer/ast/ast_printer.h index 17e052037..969240930 100644 --- a/src/printer/ast/ast_printer.h +++ b/src/printer/ast/ast_printer.h @@ -175,7 +175,7 @@ class AstPrinter : public CVC4::Printer void toStream(std::ostream& out, TNode n, int toDepth, bool types) const; void toStream(std::ostream& out, const Model& m, - const Command* c) const override; + const NodeCommand* c) const override; }; /* class AstPrinter */ } // namespace ast diff --git a/src/printer/cvc/cvc_printer.cpp b/src/printer/cvc/cvc_printer.cpp index 89b516511..b94977cfe 100644 --- a/src/printer/cvc/cvc_printer.cpp +++ b/src/printer/cvc/cvc_printer.cpp @@ -33,6 +33,7 @@ #include "options/smt_options.h" #include "printer/dagification_visitor.h" #include "smt/command.h" +#include "smt/node_command.h" #include "smt/smt_engine.h" #include "theory/arrays/theory_arrays_rewriter.h" #include "theory/substitutions.h" @@ -1059,11 +1060,11 @@ void CvcPrinter::toStream(std::ostream& out, const CommandStatus* s) const namespace { -void DeclareTypeCommandToStream(std::ostream& out, - const theory::TheoryModel& model, - const DeclareTypeCommand& command) +void DeclareTypeNodeCommandToStream(std::ostream& out, + const theory::TheoryModel& model, + const DeclareTypeNodeCommand& command) { - TypeNode type_node = TypeNode::fromType(command.getType()); + TypeNode type_node = command.getType(); const std::vector<Node>* type_reps = model.getRepSet()->getTypeRepsOrNull(type_node); if (options::modelUninterpDtEnum() && type_node.isSort() @@ -1104,11 +1105,12 @@ void DeclareTypeCommandToStream(std::ostream& out, } } -void DeclareFunctionCommandToStream(std::ostream& out, - const theory::TheoryModel& model, - const DeclareFunctionCommand& command) +void DeclareFunctionNodeCommandToStream( + std::ostream& out, + const theory::TheoryModel& model, + const DeclareFunctionNodeCommand& command) { - Node n = Node::fromExpr(command.getFunction()); + Node n = command.getFunction(); if (n.getKind() == kind::SKOLEM) { // don't print out internal stuff @@ -1172,23 +1174,23 @@ void CvcPrinter::toStream(std::ostream& out, const Model& m) const void CvcPrinter::toStream(std::ostream& out, const Model& model, - const Command* command) const + const NodeCommand* command) const { const auto* theory_model = dynamic_cast<const theory::TheoryModel*>(&model); AlwaysAssert(theory_model != nullptr); if (const auto* declare_type_command = - dynamic_cast<const DeclareTypeCommand*>(command)) + dynamic_cast<const DeclareTypeNodeCommand*>(command)) { - DeclareTypeCommandToStream(out, *theory_model, *declare_type_command); + DeclareTypeNodeCommandToStream(out, *theory_model, *declare_type_command); } else if (const auto* dfc = - dynamic_cast<const DeclareFunctionCommand*>(command)) + dynamic_cast<const DeclareFunctionNodeCommand*>(command)) { - DeclareFunctionCommandToStream(out, *theory_model, *dfc); + DeclareFunctionNodeCommandToStream(out, *theory_model, *dfc); } else { - out << command << std::endl; + out << *command << std::endl; } } diff --git a/src/printer/cvc/cvc_printer.h b/src/printer/cvc/cvc_printer.h index 0fd3d3a49..3c61fb74f 100644 --- a/src/printer/cvc/cvc_printer.h +++ b/src/printer/cvc/cvc_printer.h @@ -177,7 +177,7 @@ class CvcPrinter : public CVC4::Printer std::ostream& out, TNode n, int toDepth, bool types, bool bracket) const; void toStream(std::ostream& out, const Model& m, - const Command* c) const override; + const NodeCommand* c) const override; bool d_cvc3Mode; }; /* class CvcPrinter */ diff --git a/src/printer/printer.cpp b/src/printer/printer.cpp index 0e7550518..d13fc55f1 100644 --- a/src/printer/printer.cpp +++ b/src/printer/printer.cpp @@ -23,6 +23,7 @@ #include "printer/cvc/cvc_printer.h" #include "printer/smt2/smt2_printer.h" #include "printer/tptp/tptp_printer.h" +#include "smt/node_command.h" using namespace std; @@ -72,9 +73,10 @@ unique_ptr<Printer> Printer::makePrinter(OutputLanguage lang) void Printer::toStream(std::ostream& out, const Model& m) const { for(size_t i = 0; i < m.getNumCommands(); ++i) { - const Command* cmd = m.getCommand(i); - const DeclareFunctionCommand* dfc = dynamic_cast<const DeclareFunctionCommand*>(cmd); - if (dfc != NULL && !m.isModelCoreSymbol(dfc->getFunction())) + const NodeCommand* cmd = m.getCommand(i); + const DeclareFunctionNodeCommand* dfc = + dynamic_cast<const DeclareFunctionNodeCommand*>(cmd); + if (dfc != NULL && !m.isModelCoreSymbol(dfc->getFunction().toExpr())) { continue; } diff --git a/src/printer/printer.h b/src/printer/printer.h index 3b737ec5f..8c95e3e9b 100644 --- a/src/printer/printer.h +++ b/src/printer/printer.h @@ -30,6 +30,8 @@ namespace CVC4 { +class NodeCommand; + class Printer { public: @@ -271,13 +273,13 @@ class Printer /** write model response to command */ virtual void toStream(std::ostream& out, const Model& m, - const Command* c) const = 0; + const NodeCommand* c) const = 0; /** write model response to command using another language printer */ void toStreamUsing(OutputLanguage lang, std::ostream& out, const Model& m, - const Command* c) const + const NodeCommand* c) const { getPrinter(lang)->toStream(out, m, c); } diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 3d76c81dc..da0423956 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -31,11 +31,12 @@ #include "options/printer_options.h" #include "options/smt_options.h" #include "printer/dagification_visitor.h" +#include "smt/node_command.h" #include "smt/smt_engine.h" #include "smt_util/boolean_simplification.h" #include "theory/arrays/theory_arrays_rewriter.h" -#include "theory/quantifiers/quantifiers_attributes.h" #include "theory/datatypes/sygus_datatype_utils.h" +#include "theory/quantifiers/quantifiers_attributes.h" #include "theory/substitutions.h" #include "theory/theory_model.h" #include "util/smt2_quote_string.h" @@ -1331,23 +1332,23 @@ void Smt2Printer::toStream(std::ostream& out, const Model& m) const void Smt2Printer::toStream(std::ostream& out, const Model& model, - const Command* command) const + const NodeCommand* command) const { const theory::TheoryModel* theory_model = dynamic_cast<const theory::TheoryModel*>(&model); AlwaysAssert(theory_model != nullptr); - if (const DeclareTypeCommand* dtc = - dynamic_cast<const DeclareTypeCommand*>(command)) + if (const DeclareTypeNodeCommand* dtc = + dynamic_cast<const DeclareTypeNodeCommand*>(command)) { // print out the DeclareTypeCommand - Type t = (*dtc).getType(); - if (!t.isSort()) + TypeNode tn = dtc->getType(); + if (!tn.isSort()) { out << (*dtc) << endl; } else { - std::vector<Expr> elements = theory_model->getDomainElements(t); + std::vector<Expr> elements = theory_model->getDomainElements(tn.toType()); if (options::modelUninterpDtEnum()) { if (isVariant_2_6(d_variant)) @@ -1367,7 +1368,7 @@ void Smt2Printer::toStream(std::ostream& out, else { // print the cardinality - out << "; cardinality of " << t << " is " << elements.size() << endl; + out << "; cardinality of " << tn << " is " << elements.size() << endl; out << (*dtc) << endl; // print the representatives for (const Expr& type_ref : elements) @@ -1375,7 +1376,7 @@ void Smt2Printer::toStream(std::ostream& out, Node trn = Node::fromExpr(type_ref); if (trn.isVar()) { - out << "(declare-fun " << quoteSymbol(trn) << " () " << t << ")" + out << "(declare-fun " << quoteSymbol(trn) << " () " << tn << ")" << endl; } else @@ -1386,11 +1387,11 @@ void Smt2Printer::toStream(std::ostream& out, } } } - else if (const DeclareFunctionCommand* dfc = - dynamic_cast<const DeclareFunctionCommand*>(command)) + else if (const DeclareFunctionNodeCommand* dfc = + dynamic_cast<const DeclareFunctionNodeCommand*>(command)) { // print out the DeclareFunctionCommand - Node n = Node::fromExpr((*dfc).getFunction()); + Node n = dfc->getFunction(); if ((*dfc).getPrintInModelSetByUser()) { if (!(*dfc).getPrintInModel()) @@ -1432,10 +1433,10 @@ void Smt2Printer::toStream(std::ostream& out, out << ")" << endl; } } - else if (const DatatypeDeclarationCommand* datatype_declaration_command = - dynamic_cast<const DatatypeDeclarationCommand*>(command)) + else if (const DeclareDatatypeNodeCommand* declare_datatype_command = + dynamic_cast<const DeclareDatatypeNodeCommand*>(command)) { - out << datatype_declaration_command; + out << *declare_datatype_command; } else { diff --git a/src/printer/smt2/smt2_printer.h b/src/printer/smt2/smt2_printer.h index 6b57823a4..0cf06dd6b 100644 --- a/src/printer/smt2/smt2_printer.h +++ b/src/printer/smt2/smt2_printer.h @@ -19,8 +19,6 @@ #ifndef CVC4__PRINTER__SMT2_PRINTER_H #define CVC4__PRINTER__SMT2_PRINTER_H -#include <iostream> - #include "printer/printer.h" namespace CVC4 { @@ -234,7 +232,7 @@ class Smt2Printer : public CVC4::Printer std::ostream& out, TNode n, int toDepth, bool types, TypeNode nt) const; void toStream(std::ostream& out, const Model& m, - const Command* c) const override; + const NodeCommand* c) const override; void toStream(std::ostream& out, const SExpr& sexpr) const; void toStream(std::ostream& out, const DType& dt) const; diff --git a/src/printer/tptp/tptp_printer.cpp b/src/printer/tptp/tptp_printer.cpp index c4623f76a..d25666d70 100644 --- a/src/printer/tptp/tptp_printer.cpp +++ b/src/printer/tptp/tptp_printer.cpp @@ -20,12 +20,13 @@ #include <typeinfo> #include <vector> -#include "expr/expr.h" // for ExprSetDepth etc.. -#include "expr/node_manager.h" // for VarNameAttr -#include "options/language.h" // for LANG_AST -#include "options/smt_options.h" // for unsat cores -#include "smt/smt_engine.h" +#include "expr/expr.h" // for ExprSetDepth etc.. +#include "expr/node_manager.h" // for VarNameAttr +#include "options/language.h" // for LANG_AST +#include "options/smt_options.h" // for unsat cores #include "smt/command.h" +#include "smt/node_command.h" +#include "smt/smt_engine.h" using namespace std; @@ -59,7 +60,7 @@ void TptpPrinter::toStream(std::ostream& out, const Model& m) const void TptpPrinter::toStream(std::ostream& out, const Model& m, - const Command* c) const + const NodeCommand* c) const { // shouldn't be called; only the non-Command* version above should be Unreachable(); diff --git a/src/printer/tptp/tptp_printer.h b/src/printer/tptp/tptp_printer.h index 6682b495e..9377a8895 100644 --- a/src/printer/tptp/tptp_printer.h +++ b/src/printer/tptp/tptp_printer.h @@ -47,7 +47,7 @@ class TptpPrinter : public CVC4::Printer private: void toStream(std::ostream& out, const Model& m, - const Command* c) const override; + const NodeCommand* c) const override; }; /* class TptpPrinter */ diff --git a/src/smt/dump.h b/src/smt/dump.h index 050935422..4c0efeb6e 100644 --- a/src/smt/dump.h +++ b/src/smt/dump.h @@ -21,6 +21,7 @@ #include "base/output.h" #include "smt/command.h" +#include "smt/node_command.h" namespace CVC4 { @@ -40,6 +41,20 @@ class CVC4_PUBLIC CVC4dumpstream return *this; } + /** A convenience function for dumping internal commands. + * + * Since Commands are now part of the public API, internal code should use + * NodeCommands and this function (instead of the one above) to dump them. + */ + CVC4dumpstream& operator<<(const NodeCommand& nc) + { + if (d_os != nullptr) + { + (*d_os) << nc << std::endl; + } + return *this; + } + private: std::ostream* d_os; }; /* class CVC4dumpstream */ @@ -56,6 +71,7 @@ class CVC4_PUBLIC CVC4dumpstream CVC4dumpstream() {} CVC4dumpstream(std::ostream& os) {} CVC4dumpstream& operator<<(const Command& c) { return *this; } + CVC4dumpstream& operator<<(const NodeCommand& nc) { return *this; } }; /* class CVC4dumpstream */ #endif /* CVC4_DUMPING && !CVC4_MUZZLE */ diff --git a/src/smt/dump_manager.cpp b/src/smt/dump_manager.cpp index d5fd65c4c..033be405f 100644 --- a/src/smt/dump_manager.cpp +++ b/src/smt/dump_manager.cpp @@ -51,7 +51,7 @@ void DumpManager::finishInit() void DumpManager::resetAssertions() { d_modelGlobalCommands.clear(); } -void DumpManager::addToModelCommandAndDump(const Command& c, +void DumpManager::addToModelCommandAndDump(const NodeCommand& c, uint32_t flags, bool userVisible, const char* dumpTag) @@ -70,14 +70,14 @@ void DumpManager::addToModelCommandAndDump(const Command& c, { if (flags & ExprManager::VAR_FLAG_GLOBAL) { - d_modelGlobalCommands.push_back(std::unique_ptr<Command>(c.clone())); + d_modelGlobalCommands.push_back(std::unique_ptr<NodeCommand>(c.clone())); } else { - Command* cc = c.clone(); + NodeCommand* cc = c.clone(); d_modelCommands.push_back(cc); // also remember for memory management purposes - d_modelCommandsAlloc.push_back(std::unique_ptr<Command>(cc)); + d_modelCommandsAlloc.push_back(std::unique_ptr<NodeCommand>(cc)); } } if (Dump.isOn(dumpTag)) @@ -88,7 +88,7 @@ void DumpManager::addToModelCommandAndDump(const Command& c, } else { - d_dumpCommands.push_back(std::unique_ptr<Command>(c.clone())); + d_dumpCommands.push_back(std::unique_ptr<NodeCommand>(c.clone())); } } } @@ -96,7 +96,7 @@ void DumpManager::addToModelCommandAndDump(const Command& c, void DumpManager::setPrintFuncInModel(Node f, bool p) { Trace("setp-model") << "Set printInModel " << f << " to " << p << std::endl; - for (std::unique_ptr<Command>& c : d_modelGlobalCommands) + for (std::unique_ptr<NodeCommand>& c : d_modelGlobalCommands) { DeclareFunctionCommand* dfc = dynamic_cast<DeclareFunctionCommand*>(c.get()); @@ -109,7 +109,7 @@ void DumpManager::setPrintFuncInModel(Node f, bool p) } } } - for (Command* c : d_modelCommands) + for (NodeCommand* c : d_modelCommands) { DeclareFunctionCommand* dfc = dynamic_cast<DeclareFunctionCommand*>(c); if (dfc != NULL) @@ -128,7 +128,7 @@ size_t DumpManager::getNumModelCommands() const return d_modelCommands.size() + d_modelGlobalCommands.size(); } -const Command* DumpManager::getModelCommand(size_t i) const +const NodeCommand* DumpManager::getModelCommand(size_t i) const { Assert(i < getNumModelCommands()); // index the global commands first, then the locals diff --git a/src/smt/dump_manager.h b/src/smt/dump_manager.h index 6f2ee37a1..2ce0570e4 100644 --- a/src/smt/dump_manager.h +++ b/src/smt/dump_manager.h @@ -22,7 +22,7 @@ #include "context/cdlist.h" #include "expr/node.h" -#include "smt/command.h" +#include "smt/node_command.h" namespace CVC4 { namespace smt { @@ -36,7 +36,7 @@ namespace smt { */ class DumpManager { - typedef context::CDList<Command*> CommandList; + typedef context::CDList<NodeCommand*> CommandList; public: DumpManager(context::UserContext* u); @@ -54,7 +54,7 @@ class DumpManager * Add to Model command. This is used for recording a command * that should be reported during a get-model call. */ - void addToModelCommandAndDump(const Command& c, + void addToModelCommandAndDump(const NodeCommand& c, uint32_t flags = 0, bool userVisible = true, const char* dumpTag = "declarations"); @@ -66,7 +66,7 @@ class DumpManager /** get number of commands to report in a model */ size_t getNumModelCommands() const; /** get model command at index i */ - const Command* getModelCommand(size_t i) const; + const NodeCommand* getModelCommand(size_t i) const; private: /** Fully inited */ @@ -77,7 +77,7 @@ class DumpManager * regardless of push/pop). Only maintained if produce-models option * is on. */ - std::vector<std::unique_ptr<Command>> d_modelGlobalCommands; + std::vector<std::unique_ptr<NodeCommand>> d_modelGlobalCommands; /** * A list of commands that should be in the Model locally (i.e., @@ -89,7 +89,7 @@ class DumpManager * A list of model commands allocated to d_modelCommands at any time. This * is maintained for memory management purposes. */ - std::vector<std::unique_ptr<Command>> d_modelCommandsAlloc; + std::vector<std::unique_ptr<NodeCommand>> d_modelCommandsAlloc; /** * A vector of declaration commands waiting to be dumped out. @@ -97,7 +97,7 @@ class DumpManager * This ensures the declarations come after the set-logic and * any necessary set-option commands are dumped. */ - std::vector<std::unique_ptr<Command>> d_dumpCommands; + std::vector<std::unique_ptr<NodeCommand>> d_dumpCommands; }; } // namespace smt diff --git a/src/smt/listeners.cpp b/src/smt/listeners.cpp index 539d6ba2f..52ddcf156 100644 --- a/src/smt/listeners.cpp +++ b/src/smt/listeners.cpp @@ -18,7 +18,7 @@ #include "expr/expr.h" #include "expr/node_manager_attributes.h" #include "options/smt_options.h" -#include "smt/command.h" +#include "smt/node_command.h" #include "smt/dump.h" #include "smt/dump_manager.h" #include "smt/smt_engine.h" @@ -40,7 +40,7 @@ SmtNodeManagerListener::SmtNodeManagerListener(DumpManager& dm) : d_dm(dm) {} void SmtNodeManagerListener::nmNotifyNewSort(TypeNode tn, uint32_t flags) { - DeclareTypeCommand c(tn.getAttribute(expr::VarNameAttr()), 0, tn.toType()); + DeclareTypeNodeCommand c(tn.getAttribute(expr::VarNameAttr()), 0, tn); if ((flags & ExprManager::SORT_FLAG_PLACEHOLDER) == 0) { d_dm.addToModelCommandAndDump(c, flags); @@ -50,9 +50,9 @@ void SmtNodeManagerListener::nmNotifyNewSort(TypeNode tn, uint32_t flags) void SmtNodeManagerListener::nmNotifyNewSortConstructor(TypeNode tn, uint32_t flags) { - DeclareTypeCommand c(tn.getAttribute(expr::VarNameAttr()), - tn.getAttribute(expr::SortArityAttr()), - tn.toType()); + DeclareTypeNodeCommand c(tn.getAttribute(expr::VarNameAttr()), + tn.getAttribute(expr::SortArityAttr()), + tn); if ((flags & ExprManager::SORT_FLAG_PLACEHOLDER) == 0) { d_dm.addToModelCommandAndDump(c); @@ -68,17 +68,16 @@ void SmtNodeManagerListener::nmNotifyNewDatatypes( for (const TypeNode& dt : dtts) { Assert(dt.isDatatype()); - types.push_back(dt.toType()); } - DatatypeDeclarationCommand c(types); + DeclareDatatypeNodeCommand c(dtts); d_dm.addToModelCommandAndDump(c); } } void SmtNodeManagerListener::nmNotifyNewVar(TNode n, uint32_t flags) { - DeclareFunctionCommand c( - n.getAttribute(expr::VarNameAttr()), n.toExpr(), n.getType().toType()); + DeclareFunctionNodeCommand c( + n.getAttribute(expr::VarNameAttr()), n, n.getType()); if ((flags & ExprManager::VAR_FLAG_DEFINED) == 0) { d_dm.addToModelCommandAndDump(c, flags); @@ -90,7 +89,7 @@ void SmtNodeManagerListener::nmNotifyNewSkolem(TNode n, uint32_t flags) { std::string id = n.getAttribute(expr::VarNameAttr()); - DeclareFunctionCommand c(id, n.toExpr(), n.getType().toType()); + DeclareFunctionNodeCommand c(id, n, n.getType()); if (Dump.isOn("skolems") && comment != "") { Dump("skolems") << CommentCommand(id + " is " + comment); diff --git a/src/smt/model.cpp b/src/smt/model.cpp index 7924698ff..a23b885ff 100644 --- a/src/smt/model.cpp +++ b/src/smt/model.cpp @@ -19,8 +19,8 @@ #include "expr/expr_iomanip.h" #include "options/base_options.h" #include "printer/printer.h" -#include "smt/command.h" #include "smt/dump_manager.h" +#include "smt/node_command.h" #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" @@ -42,7 +42,7 @@ size_t Model::getNumCommands() const return d_smt.getDumpManager()->getNumModelCommands(); } -const Command* Model::getCommand(size_t i) const +const NodeCommand* Model::getCommand(size_t i) const { return d_smt.getDumpManager()->getModelCommand(i); } diff --git a/src/smt/model.h b/src/smt/model.h index 8f4409b07..4c28704c3 100644 --- a/src/smt/model.h +++ b/src/smt/model.h @@ -25,7 +25,7 @@ namespace CVC4 { -class Command; +class NodeCommand; class SmtEngine; class Model; @@ -48,7 +48,7 @@ class Model { /** get number of commands to report */ size_t getNumCommands() const; /** get command */ - const Command* getCommand(size_t i) const; + const NodeCommand* getCommand(size_t i) const; /** get the smt engine that this model is hooked up to */ SmtEngine* getSmtEngine() { return &d_smt; } /** get the smt engine (as a pointer-to-const) that this model is hooked up to */ diff --git a/src/smt/node_command.cpp b/src/smt/node_command.cpp new file mode 100644 index 000000000..265b35b3e --- /dev/null +++ b/src/smt/node_command.cpp @@ -0,0 +1,180 @@ +/********************* */ +/*! \file node_command.cpp + ** \verbatim + ** Top contributors (to current version): + ** Abdalrhman Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of NodeCommand functions. + ** + ** Implementation of NodeCommand functions. + **/ + +#include "smt/node_command.h" + +#include "printer/printer.h" + +namespace CVC4 { + +/* -------------------------------------------------------------------------- */ +/* class NodeCommand */ +/* -------------------------------------------------------------------------- */ + +NodeCommand::~NodeCommand() {} + +std::string NodeCommand::toString() const +{ + std::stringstream ss; + toStream(ss); + return ss.str(); +} + +std::ostream& operator<<(std::ostream& out, const NodeCommand& nc) +{ + nc.toStream(out, + Node::setdepth::getDepth(out), + Node::printtypes::getPrintTypes(out), + Node::dag::getDag(out), + Node::setlanguage::getLanguage(out)); + return out; +} + +/* -------------------------------------------------------------------------- */ +/* class DeclareFunctionNodeCommand */ +/* -------------------------------------------------------------------------- */ + +DeclareFunctionNodeCommand::DeclareFunctionNodeCommand(const std::string& id, + Node expr, + TypeNode type) + : d_id(id), + d_fun(expr), + d_type(type), + d_printInModel(true), + d_printInModelSetByUser(false) +{ +} + +void DeclareFunctionNodeCommand::toStream(std::ostream& out, + int toDepth, + bool types, + size_t dag, + OutputLanguage language) const +{ + Printer::getPrinter(language)->toStreamCmdDeclareFunction(out, d_id, d_type); +} + +NodeCommand* DeclareFunctionNodeCommand::clone() const +{ + return new DeclareFunctionNodeCommand(d_id, d_fun, d_type); +} + +const Node& DeclareFunctionNodeCommand::getFunction() const { return d_fun; } + +bool DeclareFunctionNodeCommand::getPrintInModel() const +{ + return d_printInModel; +} + +bool DeclareFunctionNodeCommand::getPrintInModelSetByUser() const +{ + return d_printInModelSetByUser; +} + +void DeclareFunctionNodeCommand::setPrintInModel(bool p) +{ + d_printInModel = p; + d_printInModelSetByUser = true; +} + +/* -------------------------------------------------------------------------- */ +/* class DeclareTypeNodeCommand */ +/* -------------------------------------------------------------------------- */ + +DeclareTypeNodeCommand::DeclareTypeNodeCommand(const std::string& id, + size_t arity, + TypeNode type) + : d_id(id), d_arity(arity), d_type(type) +{ +} + +void DeclareTypeNodeCommand::toStream(std::ostream& out, + int toDepth, + bool types, + size_t dag, + OutputLanguage language) const +{ + Printer::getPrinter(language)->toStreamCmdDeclareType( + out, d_id, d_arity, d_type); +} + +NodeCommand* DeclareTypeNodeCommand::clone() const +{ + return new DeclareTypeNodeCommand(d_id, d_arity, d_type); +} + +const std::string DeclareTypeNodeCommand::getSymbol() const { return d_id; } + +const TypeNode& DeclareTypeNodeCommand::getType() const { return d_type; } + +/* -------------------------------------------------------------------------- */ +/* class DeclareDatatypeNodeCommand */ +/* -------------------------------------------------------------------------- */ + +DeclareDatatypeNodeCommand::DeclareDatatypeNodeCommand( + const std::vector<TypeNode>& datatypes) + : d_datatypes(datatypes) +{ +} + +void DeclareDatatypeNodeCommand::toStream(std::ostream& out, + int toDepth, + bool types, + size_t dag, + OutputLanguage language) const +{ + Printer::getPrinter(language)->toStreamCmdDatatypeDeclaration(out, + d_datatypes); +} + +NodeCommand* DeclareDatatypeNodeCommand::clone() const +{ + return new DeclareDatatypeNodeCommand(d_datatypes); +} + +/* -------------------------------------------------------------------------- */ +/* class DefineFunctionNodeCommand */ +/* -------------------------------------------------------------------------- */ + +DefineFunctionNodeCommand::DefineFunctionNodeCommand( + const std::string& id, + Node fun, + const std::vector<Node>& formals, + Node formula) + : d_id(id), d_fun(fun), d_formals(formals), d_formula(formula) +{ +} + +void DefineFunctionNodeCommand::toStream(std::ostream& out, + int toDepth, + bool types, + size_t dag, + OutputLanguage language) const +{ + Printer::getPrinter(language)->toStreamCmdDefineFunction( + out, + d_fun.toString(), + d_formals, + d_fun.getType().getRangeType(), + d_formula); +} + +NodeCommand* DefineFunctionNodeCommand::clone() const +{ + return new DefineFunctionNodeCommand(d_id, d_fun, d_formals, d_formula); +} + +} // namespace CVC4 diff --git a/src/smt/node_command.h b/src/smt/node_command.h new file mode 100644 index 000000000..2ca166bb6 --- /dev/null +++ b/src/smt/node_command.h @@ -0,0 +1,157 @@ +/********************* */ +/*! \file node_command.h + ** \verbatim + ** Top contributors (to current version): + ** Abdalrhman Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Datastructures used for printing commands internally. + ** + ** Datastructures used for printing commands internally. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__SMT__NODE_COMMAND_H +#define CVC4__SMT__NODE_COMMAND_H + +#include <string> + +#include "expr/node.h" +#include "expr/type_node.h" +#include "options/language.h" + +namespace CVC4 { + +/** + * A node version of Command. DO NOT use this version unless there is a need + * to buffer commands for later use (e.g., printing models). + */ +class NodeCommand +{ + public: + /** Destructor */ + virtual ~NodeCommand(); + + /** Print this NodeCommand to output stream */ + virtual void toStream( + std::ostream& out, + int toDepth = -1, + bool types = false, + size_t dag = 1, + OutputLanguage language = language::output::LANG_AUTO) const = 0; + + /** Get a string representation of this NodeCommand */ + std::string toString() const; + + /** Clone this NodeCommand (make a shallow copy). */ + virtual NodeCommand* clone() const = 0; +}; + +std::ostream& operator<<(std::ostream& out, const NodeCommand& nc); + +/** + * Declare n-ary function symbol. + * SMT-LIB: ( declare-fun <id> ( <type.getArgTypes()> ) <type.getRangeType()> ) + */ +class DeclareFunctionNodeCommand : public NodeCommand +{ + public: + DeclareFunctionNodeCommand(const std::string& id, Node fun, TypeNode type); + void toStream( + std::ostream& out, + int toDepth = -1, + bool types = false, + size_t dag = 1, + OutputLanguage language = language::output::LANG_AUTO) const override; + NodeCommand* clone() const override; + const Node& getFunction() const; + bool getPrintInModel() const; + bool getPrintInModelSetByUser() const; + void setPrintInModel(bool p); + + private: + std::string d_id; + Node d_fun; + TypeNode d_type; + bool d_printInModel; + bool d_printInModelSetByUser; +}; + +/** + * Create datatype sort. + * SMT-LIB: ( declare-datatypes ( <datatype decls>{n+1} ) ( <datatypes>{n+1} ) ) + */ +class DeclareDatatypeNodeCommand : public NodeCommand +{ + public: + DeclareDatatypeNodeCommand(const std::vector<TypeNode>& datatypes); + void toStream( + std::ostream& out, + int toDepth = -1, + bool types = false, + size_t dag = 1, + OutputLanguage language = language::output::LANG_AUTO) const override; + NodeCommand* clone() const override; + + private: + std::vector<TypeNode> d_datatypes; +}; + +/** + * Declare uninterpreted sort. + * SMT-LIB: ( declare-sort <id> <arity> ) + */ +class DeclareTypeNodeCommand : public NodeCommand +{ + public: + DeclareTypeNodeCommand(const std::string& id, size_t arity, TypeNode type); + void toStream( + std::ostream& out, + int toDepth = -1, + bool types = false, + size_t dag = 1, + OutputLanguage language = language::output::LANG_AUTO) const override; + NodeCommand* clone() const override; + const std::string getSymbol() const; + const TypeNode& getType() const; + + private: + std::string d_id; + size_t d_arity; + TypeNode d_type; +}; + +/** + * Define n-ary function. + * SMT-LIB: ( define-fun <id> ( <formals> ) <fun.getType()> <formula> ) + */ +class DefineFunctionNodeCommand : public NodeCommand +{ + public: + DefineFunctionNodeCommand(const std::string& id, + Node fun, + const std::vector<Node>& formals, + Node formula); + void toStream( + std::ostream& out, + int toDepth = -1, + bool types = false, + size_t dag = 1, + OutputLanguage language = language::output::LANG_AUTO) const override; + NodeCommand* clone() const override; + + private: + std::string d_id; + Node d_fun; + std::vector<Node> d_formals; + Node d_formula; +}; + +} // namespace CVC4 + +#endif /* CVC4__SMT__NODE_COMMAND_H */ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 531dbff0d..81d4f594d 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -688,9 +688,18 @@ void SmtEngine::defineFunction(Expr func, ss << language::SetLanguage( language::SetLanguage::getLanguage(Dump.getStream())) << func; - DefineFunctionCommand c(ss.str(), func, formals, formula, global); + std::vector<Node> nFormals; + nFormals.reserve(formals.size()); + + for (const Expr& formal : formals) + { + nFormals.push_back(formal.getNode()); + } + + DefineFunctionNodeCommand nc( + ss.str(), func.getNode(), nFormals, formula.getNode()); d_dumpm->addToModelCommandAndDump( - c, ExprManager::VAR_FLAG_DEFINED, true, "declarations"); + nc, ExprManager::VAR_FLAG_DEFINED, true, "declarations"); // type check body debugCheckFunctionBody(formula, formals, func); |