diff options
author | yoni206 <yoni206@gmail.com> | 2019-06-13 00:28:03 -0700 |
---|---|---|
committer | yoni206 <yoni206@gmail.com> | 2019-06-17 13:16:22 -0700 |
commit | 0b6676d97e32c8f340f0da03b0e4dc15bc913c98 (patch) | |
tree | df46ddb721ba3265551b10843d1a2fab0cfee813 | |
parent | 41639dee25fb5e03d8f48021637b5b38e9a30285 (diff) |
better interface
-rw-r--r-- | src/smt/command.cpp | 26 | ||||
-rw-r--r-- | src/smt/model_blocker.cpp | 8 | ||||
-rw-r--r-- | src/smt/model_blocker.h | 2 | ||||
-rw-r--r-- | src/smt/smt_engine.cpp | 35 | ||||
-rw-r--r-- | src/smt/smt_engine.h | 18 |
5 files changed, 58 insertions, 31 deletions
diff --git a/src/smt/command.cpp b/src/smt/command.cpp index 6b5c14f3f..f648dcdea 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -1662,16 +1662,22 @@ void GetValueCommand::invoke(SmtEngine* smtEngine) { try { - vector<Expr> result; ExprManager* em = smtEngine->getExprManager(); NodeManager* nm = NodeManager::fromExprManager(em); - for (const Expr& e : d_terms) + smt::SmtScope scope(smtEngine); + vector<Node> termNodes; + for (Expr e : d_terms) { + termNodes.push_back(Node::fromExpr(e)); + } + vector<Node> result = smtEngine->getValues(termNodes); + Assert(result.size() == d_terms.size()); + for (int i=0; i < d_terms.size(); i++) { + Expr e = d_terms[i]; Assert(nm == NodeManager::fromExprManager(e.getExprManager())); - smt::SmtScope scope(smtEngine); - Node request = Node::fromExpr( - options::expandDefinitions() ? smtEngine->expandDefinitions(e) : e); - Node value = Node::fromExpr(smtEngine->getValue(e, true)); + Node request = Node::fromExpr( options::expandDefinitions() + ? smtEngine->expandDefinitions(e) : e); + Node value = result[i]; if (value.getType().isInteger() && request.getType() == nm->realType()) { // Need to wrap in division-by-one so that output printers know this @@ -1679,9 +1685,13 @@ void GetValueCommand::invoke(SmtEngine* smtEngine) // a rational. Necessary for SMT-LIB standards compliance. value = nm->mkNode(kind::DIVISION, value, nm->mkConst(Rational(1))); } - result.push_back(nm->mkNode(kind::SEXPR, request, value).toExpr()); + result[i] = nm->mkNode(kind::SEXPR, request, value); + } + std::vector<Expr> resultExpr; + for (Node n : result) { + resultExpr.push_back(n.toExpr()); } - d_result = em->mkExpr(kind::SEXPR, result); + d_result = em->mkExpr(kind::SEXPR, resultExpr); d_commandStatus = CommandSuccess::instance(); } catch (RecoverableModalException& e) diff --git a/src/smt/model_blocker.cpp b/src/smt/model_blocker.cpp index 0e7c76f89..b87933f53 100644 --- a/src/smt/model_blocker.cpp +++ b/src/smt/model_blocker.cpp @@ -25,7 +25,7 @@ namespace CVC4 { Expr ModelBlocker::getModelBlocker(const std::vector<Expr>& assertions, theory::TheoryModel* m, BlockModelsMode mode, - std::vector<Node> getValueNodes) + const std::vector<Node>* nodesToBlock) { NodeManager* nm = NodeManager::currentNM(); // convert to nodes @@ -38,7 +38,7 @@ Expr ModelBlocker::getModelBlocker(const std::vector<Expr>& assertions, Trace("model-blocker") << "Compute model blocker, assertions:" << std::endl; Node blocker; if (mode == BLOCK_MODELS_LITERALS) { - Assert(getValueNodes.size() == 0); + Assert(nodesToBlock = NULL); // optimization: filter to only top-level disjunctions unsigned counter = 0; std::vector<Node> asserts; @@ -231,7 +231,7 @@ Expr ModelBlocker::getModelBlocker(const std::vector<Expr>& assertions, Assert(mode == BLOCK_MODELS_VALUES); std::vector<Node> blockers; //if specific terms were not specified in get-value, block all variables of the model - if (getValueNodes.size() == 0) { + if (nodesToBlock == NULL) { Trace("model-blocker") << "no get-value recognized" << std::endl; std::unordered_set<Node, NodeHashFunction> symbols; for (Node n: tlAsserts) { @@ -248,7 +248,7 @@ Expr ModelBlocker::getModelBlocker(const std::vector<Expr>& assertions, //otherwise, block all terms that were specified in get-value else { std::unordered_set<Node, NodeHashFunction> terms; - for (Node n : getValueNodes) { + for (Node n : *nodesToBlock) { Node v = m->getValue(n); Node a = nm->mkNode(DISTINCT, n, v); blockers.push_back(a); diff --git a/src/smt/model_blocker.h b/src/smt/model_blocker.h index 1e728b220..9507e07f3 100644 --- a/src/smt/model_blocker.h +++ b/src/smt/model_blocker.h @@ -52,7 +52,7 @@ class ModelBlocker * our input. In other words, we do not return ~(x < 0) V ~(w < 0) since the * left disjunct is always false. */ - static Expr getModelBlocker(const std::vector<Expr>& assertions, theory::TheoryModel* m, BlockModelsMode mode, std::vector<Node> getValueNodes); + static Expr getModelBlocker(const std::vector<Expr>& assertions, theory::TheoryModel* m, BlockModelsMode mode, const std::vector<Node>* nodesToBlock = NULL); }; /* class TheoryModelCoreBuilder */ } // namespace CVC4 diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 9acfa09e5..c79e6165f 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -876,7 +876,6 @@ SmtEngine::SmtEngine(ExprManager* em) d_assignments(NULL), d_modelGlobalCommands(), d_modelCommands(NULL), - d_getValueNodes(), d_dumpCommands(), d_defineCommands(), d_logic(), @@ -1075,7 +1074,6 @@ SmtEngine::~SmtEngine() delete d_dumpCommands[i]; d_dumpCommands[i] = NULL; } - d_getValueNodes.clear(); d_dumpCommands.clear(); DeleteAndClearCommandVector(d_modelGlobalCommands); @@ -4133,7 +4131,7 @@ Expr SmtEngine::expandDefinitions(const Expr& ex) } // TODO(#1108): Simplify the error reporting of this method. -Expr SmtEngine::getValue(const Expr& ex, bool isCommand) const +Expr SmtEngine::getValue(const Expr& ex) const { Assert(ex.getExprManager() == d_exprManager); SmtScope smts(this); @@ -4207,11 +4205,33 @@ Expr SmtEngine::getValue(const Expr& ex, bool isCommand) const Trace("smt") << "--- abstract value >> " << resultNode << endl; } - if (options::blockModelsMode() != BLOCK_MODELS_NONE && isCommand) + return resultNode.toExpr(); +} + +vector<Node> SmtEngine::getValues(const vector<Node> nodes) { + vector<Node> result; + for (Node n : nodes) { + Node value = Node::fromExpr(getValue(n.toExpr())); + result.push_back(value); + } + + if (options::blockModelsMode() != BLOCK_MODELS_NONE) { - d_getValueNodes.push_back(n); + TheoryModel* m = d_theoryEngine->getBuiltModel(); + std::vector<Expr> easserts = getAssertions(); + // must expand definitions + std::vector<Expr> eassertsProc; + std::unordered_map<Node, Node, NodeHashFunction> cache; + for (unsigned i = 0, nasserts = easserts.size(); i < nasserts; i++) + { + Node ea = Node::fromExpr(easserts[i]); + Node eae = d_private->expandDefinitions(ea, cache); + eassertsProc.push_back(eae.toExpr()); + } + Expr eblocker = ModelBlocker::getModelBlocker(eassertsProc, m, options::blockModelsMode(), &nodes); + assertFormula(eblocker); } - return resultNode.toExpr(); + return result; } bool SmtEngine::addToAssignment(const Expr& ex) { @@ -4399,9 +4419,8 @@ Model* SmtEngine::getModel() { } if (options::blockModelsMode() != BLOCK_MODELS_NONE) { - Expr eblocker = ModelBlocker::getModelBlocker(eassertsProc, m, options::blockModelsMode(), d_getValueNodes); + Expr eblocker = ModelBlocker::getModelBlocker(eassertsProc, m, options::blockModelsMode()); assertFormula(eblocker); - d_getValueNodes.clear(); } } m->d_inputName = d_filename; diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 7eb5656e7..3e6bdf9a6 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -185,15 +185,6 @@ class CVC4_PUBLIC SmtEngine { */ smt::CommandList* d_modelCommands; - - /** - * If there is a (get-value (t1 t2 ... tn)) command in the current - * context, then this vector includes t1,...,tn. - * This field is changed in getValue function, which is const, - * therefore it is declared mutable. - */ - mutable std::vector<Node> d_getValueNodes; - /** * A vector of declaration commands waiting to be dumped out. * Once the SmtEngine is fully initialized, we'll dump them. @@ -747,10 +738,17 @@ class CVC4_PUBLIC SmtEngine { * set to operate interactively and produce-models is on. * if isCommand == true then this call came from a get-value command */ - Expr getValue(const Expr& e, bool isCommand = false) const + Expr getValue(const Expr& e) const /* throw(ModalException, TypeCheckingException, LogicException, UnsafeInterruptException) */ ; + + /** + * Same as getValue but for a vector of expressions + */ + std::vector<Node> getValues(const std::vector<Node> nodes); + + /** * Add a function to the set of expressions whose value is to be * later returned by a call to getAssignment(). The expression |