From 7c249b3efdeeb51fd3dfc2571bc529c55880cf5c Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 16 Oct 2020 13:32:42 -0500 Subject: Refactor SMT-level model object (#5277) This refactors the SMT-level model object so that it is a wrapper around TheoryModel instead of a base class. This inheritance was unnecessary. Moreover, it removes the virtual base models of the SMT-level model which were based on Expr. Now the interface is more minimal and in terms of Node only. This PR further simplifies a few places in the code that interface with the SmtEngine with things related to models. --- src/smt/command.h | 5 ++- src/smt/model.cpp | 27 +++++++++--- src/smt/model.h | 93 ++++++++++++++++-------------------------- src/smt/model_blocker.cpp | 4 +- src/smt/model_core_builder.cpp | 6 +-- src/smt/model_core_builder.h | 4 +- src/smt/smt_engine.cpp | 65 ++++++++++++++++++----------- src/smt/smt_engine.h | 21 ++++++---- 8 files changed, 122 insertions(+), 103 deletions(-) (limited to 'src/smt') diff --git a/src/smt/command.h b/src/smt/command.h index b823b5730..41776cee5 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -46,7 +46,10 @@ class Term; class SmtEngine; class Command; class CommandStatus; + +namespace smt { class Model; +} std::ostream& operator<<(std::ostream&, const Command&) CVC4_PUBLIC; std::ostream& operator<<(std::ostream&, const Command*) CVC4_PUBLIC; @@ -995,7 +998,7 @@ class CVC4_PUBLIC GetModelCommand : public Command OutputLanguage language = language::output::LANG_AUTO) const override; protected: - Model* d_result; + smt::Model* d_result; }; /* class GetModelCommand */ /** The command to block models. */ diff --git a/src/smt/model.cpp b/src/smt/model.cpp index 60640def1..fc9ea8fbb 100644 --- a/src/smt/model.cpp +++ b/src/smt/model.cpp @@ -14,8 +14,6 @@ #include "smt/model.h" -#include - #include "expr/expr_iomanip.h" #include "options/base_options.h" #include "printer/printer.h" @@ -23,10 +21,16 @@ #include "smt/node_command.h" #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" - -using namespace std; +#include "theory/theory_model.h" namespace CVC4 { +namespace smt { + +Model::Model(SmtEngine& smt, theory::TheoryModel* tm) + : d_smt(smt), d_isKnownSat(false), d_tmodel(tm) +{ + Assert(d_tmodel != nullptr); +} std::ostream& operator<<(std::ostream& out, const Model& m) { smt::SmtScope smts(&m.d_smt); @@ -35,8 +39,6 @@ std::ostream& operator<<(std::ostream& out, const Model& m) { return out; } -Model::Model() : d_smt(*smt::currentSmtEngine()), d_isKnownSat(false) {} - size_t Model::getNumCommands() const { return d_smt.getDumpManager()->getNumModelCommands(); @@ -47,4 +49,17 @@ const NodeCommand* Model::getCommand(size_t i) const return d_smt.getDumpManager()->getModelCommand(i); } +theory::TheoryModel* Model::getTheoryModel() { return d_tmodel; } + +const theory::TheoryModel* Model::getTheoryModel() const { return d_tmodel; } + +bool Model::isModelCoreSymbol(TNode sym) const +{ + return d_tmodel->isModelCoreSymbol(sym); +} +Node Model::getValue(TNode n) const { return d_tmodel->getValue(n); } + +bool Model::hasApproximations() const { return d_tmodel->hasApproximations(); } + +} // namespace smt }/* CVC4 namespace */ diff --git a/src/smt/model.h b/src/smt/model.h index eb959ba7e..dc36b5d29 100644 --- a/src/smt/model.h +++ b/src/smt/model.h @@ -21,30 +21,34 @@ #include #include "expr/expr.h" +#include "theory/theory_model.h" #include "util/cardinality.h" namespace CVC4 { -class NodeCommand; class SmtEngine; +class NodeCommand; + +namespace smt { + class Model; std::ostream& operator<<(std::ostream&, const Model&); +/** + * This is the SMT-level model object, that is responsible for maintaining + * the necessary information for how to print the model, as well as + * holding a pointer to the underlying implementation of the theory model. + */ class Model { friend std::ostream& operator<<(std::ostream&, const Model&); - friend class SmtEngine; - - protected: - /** The SmtEngine we're associated with */ - SmtEngine& d_smt; - - /** construct the base class; users cannot do this, only CVC4 internals */ - Model(); + friend class ::CVC4::SmtEngine; public: + /** construct */ + Model(SmtEngine& smt, theory::TheoryModel* tm); /** virtual destructor */ - virtual ~Model() { } + ~Model() {} /** get number of commands to report */ size_t getNumCommands() const; /** get command */ @@ -62,54 +66,21 @@ class Model { * only a candidate solution. */ bool isKnownSat() const { return d_isKnownSat; } - //--------------------------- model cores - /** set using model core - * - * This sets that this model is minimized to be a "model core" for some - * formula (typically the input formula). - * - * For example, given formula ( a>5 OR b>5 ) AND f( c ) = 0, - * a model for this formula is: a -> 6, b -> 0, c -> 0, f -> lambda x. 0. - * A "model core" is a subset of this model that suffices to show the - * above formula is true, for example { a -> 6, f -> lambda x. 0 } is a - * model core for this formula. - */ - virtual void setUsingModelCore() = 0; - /** record model core symbol - * - * This marks that sym is a "model core symbol". In other words, its value is - * critical to the satisfiability of the formula this model is for. - */ - virtual void recordModelCoreSymbol(Expr sym) = 0; - /** Check whether this expr is in the model core */ - virtual bool isModelCoreSymbol(Expr expr) const = 0; - //--------------------------- end model cores - /** get value for expression */ - virtual Expr getValue(Expr expr) const = 0; - /** get cardinality for sort */ - virtual Cardinality getCardinality(Type t) const = 0; - /** print comments */ - virtual void getComments(std::ostream& out) const {} - /** get heap model (for separation logic) */ - virtual bool getHeapModel( Expr& h, Expr& ne ) const { return false; } - /** are there any approximations in this model? */ - virtual bool hasApproximations() const { return false; } - /** get the list of approximations - * - * This is a list of pairs of the form (t,p), where t is a term and p - * is a predicate over t that indicates a property that t satisfies. - */ - virtual std::vector > getApproximations() const = 0; - /** get the domain elements for uninterpreted sort t - * - * This method gets the interpretation of an uninterpreted sort t. - * All models interpret uninterpreted sorts t as finite sets - * of domain elements v_1, ..., v_n. This method returns this list for t in - * this model. - */ - virtual std::vector getDomainElements(Type t) const = 0; - + /** Get the underlying theory model */ + theory::TheoryModel* getTheoryModel(); + /** Get the underlying theory model (const version) */ + const theory::TheoryModel* getTheoryModel() const; + //----------------------- helper methods in the underlying theory model + /** Is the node n a model core symbol? */ + bool isModelCoreSymbol(TNode sym) const; + /** Get value */ + Node getValue(TNode n) const; + /** Does this model have approximations? */ + bool hasApproximations() const; + //----------------------- end helper methods protected: + /** The SmtEngine we're associated with */ + SmtEngine& d_smt; /** the input name (file name, etc.) this model is associated to */ std::string d_inputName; /** @@ -117,8 +88,14 @@ class Model { * from the solver. */ bool d_isKnownSat; -};/* class Model */ + /** + * Pointer to the underlying theory model, which maintains all data regarding + * the values of sorts and terms. + */ + theory::TheoryModel* d_tmodel; +}; +} // namespace smt }/* CVC4 namespace */ #endif /* CVC4__MODEL_H */ diff --git a/src/smt/model_blocker.cpp b/src/smt/model_blocker.cpp index 9d15b5690..cabd7bd20 100644 --- a/src/smt/model_blocker.cpp +++ b/src/smt/model_blocker.cpp @@ -66,7 +66,7 @@ Node ModelBlocker::getModelBlocker(const std::vector& assertions, Node blockTriv = nm->mkConst(false); Trace("model-blocker") << "...model blocker is (trivially) " << blockTriv << std::endl; - return blockTriv.toExpr(); + return blockTriv; } Node formula = asserts.size() > 1 ? nm->mkNode(AND, asserts) : asserts[0]; @@ -152,7 +152,7 @@ Node ModelBlocker::getModelBlocker(const std::vector& assertions, std::vector children; for (const Node& cn : catom) { - Node vn = Node::fromExpr(m->getValue(cn.toExpr())); + Node vn = m->getValue(cn); Assert(vn.isConst()); children.push_back(vn.getConst() ? cn : cn.negate()); } diff --git a/src/smt/model_core_builder.cpp b/src/smt/model_core_builder.cpp index 59dac63e8..cb8494e85 100644 --- a/src/smt/model_core_builder.cpp +++ b/src/smt/model_core_builder.cpp @@ -21,7 +21,7 @@ using namespace CVC4::kind; namespace CVC4 { bool ModelCoreBuilder::setModelCore(const std::vector& assertions, - Model* m, + theory::TheoryModel* m, options::ModelCoresMode mode) { if (Trace.isOn("model-core")) @@ -53,7 +53,7 @@ bool ModelCoreBuilder::setModelCore(const std::vector& assertions, visited.insert(cur); if (cur.isVar()) { - Node vcur = Node::fromExpr(m->getValue(cur.toExpr())); + Node vcur = m->getValue(cur); Trace("model-core") << " " << cur << " -> " << vcur << std::endl; vars.push_back(cur); subs.push_back(vcur); @@ -95,7 +95,7 @@ bool ModelCoreBuilder::setModelCore(const std::vector& assertions, for (const Node& cv : coreVars) { - m->recordModelCoreSymbol(cv.toExpr()); + m->recordModelCoreSymbol(cv); } return true; } diff --git a/src/smt/model_core_builder.h b/src/smt/model_core_builder.h index 984c61d04..7a28c47f2 100644 --- a/src/smt/model_core_builder.h +++ b/src/smt/model_core_builder.h @@ -21,7 +21,7 @@ #include "expr/expr.h" #include "options/smt_options.h" -#include "smt/model.h" +#include "theory/theory_model.h" namespace CVC4 { @@ -55,7 +55,7 @@ class ModelCoreBuilder * left unchanged. */ static bool setModelCore(const std::vector& assertions, - Model* m, + theory::TheoryModel* m, options::ModelCoresMode mode); }; /* class TheoryModelCoreBuilder */ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 205865e16..2a771ce76 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -127,6 +127,7 @@ SmtEngine::SmtEngine(ExprManager* em, Options* optr) d_snmListener(new SmtNodeManagerListener(*d_dumpm.get(), d_outMgr)), d_smtSolver(nullptr), d_proofManager(nullptr), + d_model(nullptr), d_pfManager(nullptr), d_rewriter(new theory::Rewriter()), d_definedFunctions(nullptr), @@ -271,6 +272,15 @@ void SmtEngine::finishInit() Trace("smt-debug") << "SmtEngine::finishInit" << std::endl; d_smtSolver->finishInit(const_cast(d_logic)); + // now can construct the SMT-level model object + TheoryEngine* te = d_smtSolver->getTheoryEngine(); + Assert(te != nullptr); + TheoryModel* tm = te->getModel(); + if (tm != nullptr) + { + d_model.reset(new Model(*this, tm)); + } + // global push/pop around everything, to ensure proper destruction // of context-dependent data structures d_state->setup(); @@ -839,7 +849,7 @@ Result SmtEngine::quickCheck() { Result::ENTAILMENT_UNKNOWN, Result::REQUIRES_FULL_CHECK, filename); } -theory::TheoryModel* SmtEngine::getAvailableModel(const char* c) const +Model* SmtEngine::getAvailableModel(const char* c) const { if (!options::assignFunctionValues()) { @@ -878,7 +888,7 @@ theory::TheoryModel* SmtEngine::getAvailableModel(const char* c) const throw RecoverableModalException(ss.str().c_str()); } - return m; + return d_model.get(); } void SmtEngine::notifyPushPre() { d_smtSolver->processAssertions(*d_asserts); } @@ -1210,11 +1220,9 @@ Node SmtEngine::getValue(const Node& ex) const } Trace("smt") << "--- getting value of " << n << endl; - TheoryModel* m = getAvailableModel("get-value"); - Node resultNode; - if(m != NULL) { - resultNode = m->getValue(n); - } + Model* m = getAvailableModel("get-value"); + Assert(m != nullptr); + Node resultNode = m->getValue(n); Trace("smt") << "--- got value " << n << " = " << resultNode << endl; Trace("smt") << "--- type " << resultNode.getType() << endl; Trace("smt") << "--- expected type " << expectedType << endl; @@ -1301,7 +1309,7 @@ vector> SmtEngine::getAssignment() // Get the model here, regardless of whether d_assignments is null, since // we should throw errors related to model availability whether or not // assignments is null. - TheoryModel* m = getAvailableModel("get assignment"); + Model* m = getAvailableModel("get assignment"); vector> res; if (d_assignments != nullptr) @@ -1354,7 +1362,7 @@ Model* SmtEngine::getModel() { getOutputManager().getDumpOut()); } - TheoryModel* m = getAvailableModel("get model"); + Model* m = getAvailableModel("get model"); // Since model m is being returned to the user, we must ensure that this // model object remains valid with future check-sat calls. Hence, we set @@ -1368,8 +1376,11 @@ Model* SmtEngine::getModel() { // If we enabled model cores, we compute a model core for m based on our // (expanded) assertions using the model core builder utility std::vector eassertsProc = getExpandedAssertions(); - ModelCoreBuilder::setModelCore(eassertsProc, m, options::modelCoresMode()); + ModelCoreBuilder::setModelCore( + eassertsProc, m->getTheoryModel(), options::modelCoresMode()); } + // set the information on the SMT-level model + Assert(m != nullptr); m->d_inputName = d_state->getFilename(); m->d_isKnownSat = (d_state->getMode() == SmtMode::SAT); return m; @@ -1388,19 +1399,19 @@ Result SmtEngine::blockModel() getOutputManager().getDumpOut()); } - TheoryModel* m = getAvailableModel("block model"); + Model* m = getAvailableModel("block model"); if (options::blockModelsMode() == options::BlockModelsMode::NONE) { std::stringstream ss; ss << "Cannot block model when block-models is set to none."; - throw ModalException(ss.str().c_str()); + throw RecoverableModalException(ss.str().c_str()); } // get expanded assertions std::vector eassertsProc = getExpandedAssertions(); Node eblocker = ModelBlocker::getModelBlocker( - eassertsProc, m, options::blockModelsMode()); + eassertsProc, m->getTheoryModel(), options::blockModelsMode()); return assertFormula(eblocker); } @@ -1417,13 +1428,16 @@ Result SmtEngine::blockModelValues(const std::vector& exprs) getOutputManager().getDumpOut(), exprs); } - TheoryModel* m = getAvailableModel("block model values"); + Model* m = getAvailableModel("block model values"); // get expanded assertions std::vector eassertsProc = getExpandedAssertions(); // we always do block model values mode here - Node eblocker = ModelBlocker::getModelBlocker( - eassertsProc, m, options::BlockModelsMode::VALUES, exprs); + Node eblocker = + ModelBlocker::getModelBlocker(eassertsProc, + m->getTheoryModel(), + options::BlockModelsMode::VALUES, + exprs); return assertFormula(eblocker); } @@ -1437,16 +1451,18 @@ std::pair SmtEngine::getSepHeapAndNilExpr(void) throw RecoverableModalException(msg); } NodeManagerScope nms(d_nodeManager); - Expr heap; - Expr nil; + Node heap; + Node nil; Model* m = getAvailableModel("get separation logic heap and nil"); - if (!m->getHeapModel(heap, nil)) + TheoryModel* tm = m->getTheoryModel(); + if (!tm->getHeapModel(heap, nil)) { - InternalError() - << "SmtEngine::getSepHeapAndNilExpr(): failed to obtain heap/nil " - "expressions from theory model."; + const char* msg = + "Failed to obtain heap/nil " + "expressions from theory model."; + throw RecoverableModalException(msg); } - return std::make_pair(Node::fromExpr(heap), Node::fromExpr(nil)); + return std::make_pair(heap, nil); } std::vector SmtEngine::getExpandedAssertions() @@ -1544,7 +1560,8 @@ void SmtEngine::checkModel(bool hardFailure) { // and if Notice() is on, the user gave --verbose (or equivalent). Notice() << "SmtEngine::checkModel(): generating model" << endl; - TheoryModel* m = getAvailableModel("check model"); + Model* m = getAvailableModel("check model"); + Assert(m != nullptr); // check-model is not guaranteed to succeed if approximate values were used. // Thus, we intentionally abort here. diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 62e54a0c1..da12d336b 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -60,7 +60,6 @@ class TheoryEngine; class ProofManager; -class Model; class LogicRequest; class StatisticsRegistry; @@ -95,6 +94,7 @@ namespace prop { namespace smt { /** Utilities */ +class Model; class SmtEngineState; class AbstractValues; class Assertions; @@ -280,7 +280,7 @@ class CVC4_PUBLIC SmtEngine * Get the model (only if immediately preceded by a SAT or NOT_ENTAILED * query). Only permitted if produce-models is on. */ - Model* getModel(); + smt::Model* getModel(); /** * Block the current model. Can be called only if immediately preceded by @@ -969,16 +969,17 @@ class CVC4_PUBLIC SmtEngine Result quickCheck(); /** - * Get the model, if it is available and return a pointer to it + * Get the (SMT-level) model pointer, if we are in SAT mode. Otherwise, + * return nullptr. * - * This ensures that the model is currently available, which means that - * CVC4 is producing models, and is in "SAT mode", otherwise an exception - * is thrown. + * This ensures that the underlying theory model of the SmtSolver maintained + * by this class is currently available, which means that CVC4 is producing + * models, and is in "SAT mode", otherwise a recoverable exception is thrown. * * The flag c is used for giving an error message to indicate the context * this method was called. */ - theory::TheoryModel* getAvailableModel(const char* c) const; + smt::Model* getAvailableModel(const char* c) const; // --------------------------------------- callbacks from the state /** @@ -1088,6 +1089,12 @@ class CVC4_PUBLIC SmtEngine /** The (old) proof manager TODO (project #37): delete this */ std::unique_ptr d_proofManager; + /** + * The SMT-level model object, which contains information about how to + * print the model, as well as a pointer to the underlying TheoryModel + * implementation maintained by the SmtSolver. + */ + std::unique_ptr d_model; /** * The proof manager, which manages all things related to checking, -- cgit v1.2.3