diff options
Diffstat (limited to 'src/preprocessing')
-rw-r--r-- | src/preprocessing/passes/bv_to_int.cpp | 11 | ||||
-rw-r--r-- | src/preprocessing/passes/real_to_int.cpp | 29 | ||||
-rw-r--r-- | src/preprocessing/passes/real_to_int.h | 7 | ||||
-rw-r--r-- | src/preprocessing/passes/sygus_inference.cpp | 18 | ||||
-rw-r--r-- | src/preprocessing/preprocessing_pass_context.cpp | 9 | ||||
-rw-r--r-- | src/preprocessing/preprocessing_pass_context.h | 13 |
6 files changed, 53 insertions, 34 deletions
diff --git a/src/preprocessing/passes/bv_to_int.cpp b/src/preprocessing/passes/bv_to_int.cpp index 6fe676e30..c725081c2 100644 --- a/src/preprocessing/passes/bv_to_int.cpp +++ b/src/preprocessing/passes/bv_to_int.cpp @@ -840,8 +840,15 @@ void BVToInt::defineBVUFAsIntUF(Node bvUF, Node intUF) } // If the result is BV, it needs to be casted back. result = castToType(result, resultType); - // add the function definition to the smt engine. - d_preprocContext->getSmt()->defineFunction(bvUF, args, result, true); + // add the substitution to the preprocessing context, which ensures the + // model for bvUF is correct, as well as substituting it in the input + // assertions when necessary. + if (!args.empty()) + { + result = d_nm->mkNode( + kind::LAMBDA, d_nm->mkNode(kind::BOUND_VAR_LIST, args), result); + } + d_preprocContext->addSubstitution(bvUF, result); } bool BVToInt::childrenTypesChanged(Node n) diff --git a/src/preprocessing/passes/real_to_int.cpp b/src/preprocessing/passes/real_to_int.cpp index 7c4097564..9e84dc851 100644 --- a/src/preprocessing/passes/real_to_int.cpp +++ b/src/preprocessing/passes/real_to_int.cpp @@ -26,12 +26,17 @@ #include "theory/rewriter.h" #include "theory/theory_model.h" +using namespace cvc5::theory; + namespace cvc5 { namespace preprocessing { namespace passes { -using namespace std; -using namespace cvc5::theory; +RealToInt::RealToInt(PreprocessingPassContext* preprocContext) + : PreprocessingPass(preprocContext, "real-to-int"), + d_cache(preprocContext->getUserContext()) +{ +} Node RealToInt::realToIntInternal(TNode n, NodeMap& cache, std::vector<Node>& var_eq) { @@ -181,15 +186,13 @@ Node RealToInt::realToIntInternal(TNode n, NodeMap& cache, std::vector<Node>& va } else if (n.isVar()) { - ret = sm->mkDummySkolem( - "__realToIntInternal_var", - nm->integerType(), - "Variable introduced in realToIntInternal pass"); + Node toIntN = nm->mkNode(kind::TO_INTEGER, n); + ret = sm->mkPurifySkolem(toIntN, "__realToIntInternal_var"); var_eq.push_back(n.eqNode(ret)); - // ensure that the original variable is defined to be the returned - // one, which is important for models and for incremental solving. - std::vector<Node> args; - d_preprocContext->getSmt()->defineFunction(n, args, ret); + // add the substitution to the preprocessing context, which ensures + // the model for n is correct, as well as substituting it in the input + // assertions when necessary. + d_preprocContext->addSubstitution(n, ret); } } } @@ -198,18 +201,14 @@ Node RealToInt::realToIntInternal(TNode n, NodeMap& cache, std::vector<Node>& va } } -RealToInt::RealToInt(PreprocessingPassContext* preprocContext) - : PreprocessingPass(preprocContext, "real-to-int"){}; - PreprocessingPassResult RealToInt::applyInternal( AssertionPipeline* assertionsToPreprocess) { - unordered_map<Node, Node, NodeHashFunction> cache; std::vector<Node> var_eq; for (unsigned i = 0, size = assertionsToPreprocess->size(); i < size; ++i) { assertionsToPreprocess->replace( - i, realToIntInternal((*assertionsToPreprocess)[i], cache, var_eq)); + i, realToIntInternal((*assertionsToPreprocess)[i], d_cache, var_eq)); } return PreprocessingPassResult::NO_CONFLICT; } diff --git a/src/preprocessing/passes/real_to_int.h b/src/preprocessing/passes/real_to_int.h index 9f0eb529f..d26547372 100644 --- a/src/preprocessing/passes/real_to_int.h +++ b/src/preprocessing/passes/real_to_int.h @@ -22,6 +22,7 @@ #include <vector> +#include "context/cdhashmap.h" #include "expr/node.h" #include "preprocessing/preprocessing_pass.h" @@ -29,10 +30,10 @@ namespace cvc5 { namespace preprocessing { namespace passes { -using NodeMap = std::unordered_map<Node, Node, NodeHashFunction>; - class RealToInt : public PreprocessingPass { + using NodeMap = context::CDHashMap<Node, Node, NodeHashFunction>; + public: RealToInt(PreprocessingPassContext* preprocContext); @@ -42,6 +43,8 @@ class RealToInt : public PreprocessingPass private: Node realToIntInternal(TNode n, NodeMap& cache, std::vector<Node>& var_eq); + /** Cache for the above method */ + NodeMap d_cache; }; } // namespace passes diff --git a/src/preprocessing/passes/sygus_inference.cpp b/src/preprocessing/passes/sygus_inference.cpp index 870ad6625..b15d5a377 100644 --- a/src/preprocessing/passes/sygus_inference.cpp +++ b/src/preprocessing/passes/sygus_inference.cpp @@ -47,23 +47,13 @@ PreprocessingPassResult SygusInference::applyInternal( // see if we can succesfully solve the input as a sygus problem if (solveSygus(assertionsToPreprocess->ref(), funs, sols)) { + Trace("sygus-infer") << "...Solved:" << std::endl; Assert(funs.size() == sols.size()); - // if so, sygus gives us function definitions - SmtEngine* master_smte = d_preprocContext->getSmt(); + // if so, sygus gives us function definitions, which we add as substitutions for (unsigned i = 0, size = funs.size(); i < size; i++) { - std::vector<Node> args; - Node sol = sols[i]; - // if it is a non-constant function - if (sol.getKind() == LAMBDA) - { - for (const Node& v : sol[0]) - { - args.push_back(v); - } - sol = sol[1]; - } - master_smte->defineFunction(funs[i], args, sol); + Trace("sygus-infer") << funs[i] << " -> " << sols[i] << std::endl; + d_preprocContext->addSubstitution(funs[i], sols[i]); } // apply substitution to everything, should result in SAT diff --git a/src/preprocessing/preprocessing_pass_context.cpp b/src/preprocessing/preprocessing_pass_context.cpp index 22cc15a97..4be6b4aac 100644 --- a/src/preprocessing/preprocessing_pass_context.cpp +++ b/src/preprocessing/preprocessing_pass_context.cpp @@ -63,6 +63,15 @@ void PreprocessingPassContext::addModelSubstitution(const Node& lhs, lhs, d_smt->expandDefinitions(rhs, false)); } +void PreprocessingPassContext::addSubstitution(const Node& lhs, + const Node& rhs, + ProofGenerator* pg) +{ + d_topLevelSubstitutions.addSubstitution(lhs, rhs, pg); + // also add as a model substitution + addModelSubstitution(lhs, rhs); +} + ProofNodeManager* PreprocessingPassContext::getProofNodeManager() { return d_pnm; diff --git a/src/preprocessing/preprocessing_pass_context.h b/src/preprocessing/preprocessing_pass_context.h index f1d92e864..a7e9b0deb 100644 --- a/src/preprocessing/preprocessing_pass_context.h +++ b/src/preprocessing/preprocessing_pass_context.h @@ -81,11 +81,22 @@ class PreprocessingPassContext void recordSymbolsInAssertions(const std::vector<Node>& assertions); /** - * Add substitution to theory model. + * Add substitution to theory model. This method should only be called if + * we have already added the substitution to the top-level substitutions + * class. Otherwise, addSubstitution should be called instead. * @param lhs The node replaced by node 'rhs' * @param rhs The node to substitute node 'lhs' */ void addModelSubstitution(const Node& lhs, const Node& rhs); + /** + * Add substitution to the top-level substitutions and to the theory model. + * @param lhs The node replaced by node 'rhs' + * @param rhs The node to substitute node 'lhs' + * @param pg The proof generator that can provide a proof of lhs == rhs. + */ + void addSubstitution(const Node& lhs, + const Node& rhs, + ProofGenerator* pg = nullptr); /** The the proof node manager associated with this context, if it exists */ ProofNodeManager* getProofNodeManager(); |