diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2018-08-20 12:21:37 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-20 12:21:37 -0500 |
commit | 991af9a7a73adaa84712e93af72980ba977b1155 (patch) | |
tree | 4efe10ae3cb7f76acd25f8749859d76a2c8c1e80 | |
parent | b7dfffd3fe57ab8bf2b6f8aed35f8c3bb459a117 (diff) |
Make sygus inference a preprocessing pass (#2334)
-rw-r--r-- | src/Makefile.am | 4 | ||||
-rw-r--r-- | src/preprocessing/passes/sygus_inference.cpp (renamed from src/theory/quantifiers/sygus_inference.cpp) | 117 | ||||
-rw-r--r-- | src/preprocessing/passes/sygus_inference.h (renamed from src/theory/quantifiers/sygus_inference.h) | 52 | ||||
-rw-r--r-- | src/smt/smt_engine.cpp | 13 |
4 files changed, 109 insertions, 77 deletions
diff --git a/src/Makefile.am b/src/Makefile.am index 5e52186b9..40aa1a5af 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -99,6 +99,8 @@ libcvc4_la_SOURCES = \ preprocessing/passes/sort_infer.h \ preprocessing/passes/static_learning.cpp \ preprocessing/passes/static_learning.h \ + preprocessing/passes/sygus_inference.cpp \ + preprocessing/passes/sygus_inference.h \ preprocessing/passes/symmetry_breaker.cpp \ preprocessing/passes/symmetry_breaker.h \ preprocessing/passes/symmetry_detect.cpp \ @@ -542,8 +544,6 @@ libcvc4_la_SOURCES = \ theory/quantifiers/sygus/sygus_unif_strat.h \ theory/quantifiers/sygus/term_database_sygus.cpp \ theory/quantifiers/sygus/term_database_sygus.h \ - theory/quantifiers/sygus_inference.cpp \ - theory/quantifiers/sygus_inference.h \ theory/quantifiers/sygus_sampler.cpp \ theory/quantifiers/sygus_sampler.h \ theory/quantifiers/term_database.cpp \ diff --git a/src/theory/quantifiers/sygus_inference.cpp b/src/preprocessing/passes/sygus_inference.cpp index 6232de6fe..eb8835623 100644 --- a/src/theory/quantifiers/sygus_inference.cpp +++ b/src/preprocessing/passes/sygus_inference.cpp @@ -9,28 +9,78 @@ ** All rights reserved. See the file COPYING in the top-level source ** directory for licensing information.\endverbatim ** - ** \brief Implementation of sygus_inference + ** \brief Sygus inference module **/ -#include "theory/quantifiers/sygus_inference.h" +#include "preprocessing/passes/sygus_inference.h" + #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" #include "smt/smt_statistics_registry.h" #include "theory/quantifiers/quantifiers_attributes.h" #include "theory/quantifiers/quantifiers_rewriter.h" +using namespace std; using namespace CVC4::kind; namespace CVC4 { -namespace theory { -namespace quantifiers { +namespace preprocessing { +namespace passes { -SygusInference::SygusInference() {} +SygusInference::SygusInference(PreprocessingPassContext* preprocContext) + : PreprocessingPass(preprocContext, "sygus-infer"){}; -bool SygusInference::simplify(std::vector<Node>& assertions) +PreprocessingPassResult SygusInference::applyInternal( + AssertionPipeline* assertionsToPreprocess) { Trace("sygus-infer") << "Run sygus inference..." << std::endl; + std::vector<Node> funs; + std::vector<Node> sols; + // see if we can succesfully solve the input as a sygus problem + if (solveSygus(assertionsToPreprocess->ref(), funs, sols)) + { + Assert(funs.size() == sols.size()); + // if so, sygus gives us function definitions + SmtEngine* master_smte = smt::currentSmtEngine(); + for (unsigned i = 0, size = funs.size(); i < size; i++) + { + std::vector<Expr> args; + Node sol = sols[i]; + // if it is a non-constant function + if (sol.getKind() == LAMBDA) + { + for (const Node& v : sol[0]) + { + args.push_back(v.toExpr()); + } + sol = sol[1]; + } + master_smte->defineFunction(funs[i].toExpr(), args, sol.toExpr()); + } + // apply substitution to everything, should result in SAT + for (unsigned i = 0, size = assertionsToPreprocess->ref().size(); i < size; + i++) + { + Node prev = (*assertionsToPreprocess)[i]; + Node curr = + prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end()); + if (curr != prev) + { + curr = theory::Rewriter::rewrite(curr); + Trace("sygus-infer-debug") + << "...rewrote " << prev << " to " << curr << std::endl; + assertionsToPreprocess->replace(i, curr); + } + } + } + return PreprocessingPassResult::NO_CONFLICT; +} + +bool SygusInference::solveSygus(std::vector<Node>& assertions, + std::vector<Node>& funs, + std::vector<Node>& sols) +{ if (assertions.empty()) { Trace("sygus-infer") << "...fail: empty assertions." << std::endl; @@ -78,19 +128,19 @@ bool SygusInference::simplify(std::vector<Node>& assertions) std::map<TypeNode, unsigned> type_count; Node pas = as; // rewrite - pas = Rewriter::rewrite(pas); + pas = theory::Rewriter::rewrite(pas); Trace("sygus-infer") << "assertion : " << pas << std::endl; if (pas.getKind() == FORALL) { // preprocess the quantified formula - pas = quantifiers::QuantifiersRewriter::preprocess(pas); + pas = theory::quantifiers::QuantifiersRewriter::preprocess(pas); Trace("sygus-infer-debug") << " ...preprocessed to " << pas << std::endl; } if (pas.getKind() == FORALL) { // it must be a standard quantifier - QAttributes qa; - QuantAttributes::computeQuantAttributes(pas, qa); + theory::quantifiers::QAttributes qa; + theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa); if (!qa.isStandard()) { Trace("sygus-infer") @@ -215,7 +265,7 @@ bool SygusInference::simplify(std::vector<Node>& assertions) // sygus attribute to mark the conjecture as a sygus conjecture Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl; Node sygusVar = nm->mkSkolem("sygus", nm->booleanType()); - SygusAttribute ca; + theory::SygusAttribute ca; sygusVar.setAttribute(ca, true); Node instAttr = nm->mkNode(INST_ATTRIBUTE, sygusVar); Node instAttrList = nm->mkNode(INST_PATTERN_LIST, instAttr); @@ -227,7 +277,6 @@ bool SygusInference::simplify(std::vector<Node>& assertions) Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl; // make a separate smt call - SmtEngine* master_smte = smt::currentSmtEngine(); SmtEngine rrSygus(nm->toExprManager()); rrSygus.setLogic(smt::currentSmtEngine()->getLogicInfo()); rrSygus.assertFormula(body.toExpr()); @@ -249,7 +298,6 @@ bool SygusInference::simplify(std::vector<Node>& assertions) it != synth_sols.end(); ++it) { - Node lambda = Node::fromExpr(it->second); Trace("sygus-infer") << " synth sol : " << it->first << " -> " << it->second << std::endl; Node ffv = Node::fromExpr(it->first); @@ -259,44 +307,15 @@ bool SygusInference::simplify(std::vector<Node>& assertions) if (itffv != ff_var_to_ff.end()) { Node ff = itffv->second; - Expr body = it->second; - std::vector<Expr> args; - // if it is a non-constant function - if (lambda.getKind() == LAMBDA) - { - for (const Node& v : lambda[0]) - { - args.push_back(v.toExpr()); - } - body = it->second[1]; - } - Trace("sygus-infer") << "Define " << ff << " as " << it->second - << std::endl; - final_ff.push_back(ff); - final_ff_sol.push_back(it->second); - master_smte->defineFunction(ff.toExpr(), args, body); - } - } - - // apply substitution to everything, should result in SAT - for (unsigned i = 0, size = assertions.size(); i < size; i++) - { - Node prev = assertions[i]; - Node curr = assertions[i].substitute(final_ff.begin(), - final_ff.end(), - final_ff_sol.begin(), - final_ff_sol.end()); - if (curr != prev) - { - curr = Rewriter::rewrite(curr); - Trace("sygus-infer-debug") - << "...rewrote " << prev << " to " << curr << std::endl; - assertions[i] = curr; + Node body = Node::fromExpr(it->second); + Trace("sygus-infer") << "Define " << ff << " as " << body << std::endl; + funs.push_back(ff); + sols.push_back(body); } } return true; } -} /* CVC4::theory::quantifiers namespace */ -} /* CVC4::theory namespace */ -} /* CVC4 namespace */ +} // namespace passes +} // namespace preprocessing +} // namespace CVC4 diff --git a/src/theory/quantifiers/sygus_inference.h b/src/preprocessing/passes/sygus_inference.h index 414103fc7..5e7c6f7d0 100644 --- a/src/theory/quantifiers/sygus_inference.h +++ b/src/preprocessing/passes/sygus_inference.h @@ -9,20 +9,23 @@ ** All rights reserved. See the file COPYING in the top-level source ** directory for licensing information.\endverbatim ** - ** \brief sygus_inference + ** \brief SygusInference **/ -#include "cvc4_private.h" - -#ifndef __CVC4__THEORY__QUANTIFIERS__SYGUS_INFERENCE_H -#define __CVC4__THEORY__QUANTIFIERS__SYGUS_INFERENCE_H +#ifndef __CVC4__PREPROCESSING__PASSES__SYGUS_INFERENCE_H_ +#define __CVC4__PREPROCESSING__PASSES__SYGUS_INFERENCE_H_ +#include <map> +#include <string> #include <vector> #include "expr/node.h" +#include "preprocessing/preprocessing_pass.h" +#include "preprocessing/preprocessing_pass_context.h" + namespace CVC4 { -namespace theory { -namespace quantifiers { +namespace preprocessing { +namespace passes { /** SygusInference * @@ -33,25 +36,36 @@ namespace quantifiers { * problem, thus obtaining a set of model substitutions under which the * assertions should simplify to true. */ -class SygusInference +class SygusInference : public PreprocessingPass { public: - SygusInference(); - ~SygusInference() {} - /** simplify assertions - * + SygusInference(PreprocessingPassContext* preprocContext); + + protected: + /** * Either replaces all uninterpreted functions in assertions by their - * interpretation in the solution found by a separate call to an SMT engine - * and returns true, or leaves the assertions unmodified and returns false. + * interpretation in a sygus solution, or leaves the assertions unmodified. + */ + PreprocessingPassResult applyInternal( + AssertionPipeline* assertionsToPreprocess) override; + /** solve sygus + * + * Returns true if we can recast the input problem assertions as a sygus + * problem and successfully solve it using a separate call to an SMT engine. * * We fail if either a sygus conjecture that corresponds to assertions cannot * be inferred, or the sygus conjecture we infer is infeasible. + * + * If this function returns true, then we add all uninterpreted symbols s in + * assertions to funs and their corresponding solution to sols. */ - bool simplify(std::vector<Node>& assertions); + bool solveSygus(std::vector<Node>& assertions, + std::vector<Node>& funs, + std::vector<Node>& sols); }; -} /* CVC4::theory::quantifiers namespace */ -} /* CVC4::theory namespace */ -} /* CVC4 namespace */ +} // namespace passes +} // namespace preprocessing +} // namespace CVC4 -#endif /* __CVC4__THEORY__QUANTIFIERS__SYGUS_INFERENCE_H */ +#endif /* __CVC4__PREPROCESSING__PASSES__SYGUS_INFERENCE_H_ */ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 418028d09..1e8ae4033 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -88,6 +88,7 @@ #include "preprocessing/passes/sep_skolem_emp.h" #include "preprocessing/passes/sort_infer.h" #include "preprocessing/passes/static_learning.h" +#include "preprocessing/passes/sygus_inference.h" #include "preprocessing/passes/symmetry_breaker.h" #include "preprocessing/passes/symmetry_detect.h" #include "preprocessing/passes/synth_rew_rules.h" @@ -119,7 +120,6 @@ #include "theory/quantifiers/quantifiers_rewriter.h" #include "theory/quantifiers/single_inv_partition.h" #include "theory/quantifiers/sygus/ce_guided_instantiation.h" -#include "theory/quantifiers/sygus_inference.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" #include "theory/sort_inference.h" @@ -2676,6 +2676,8 @@ void SmtEnginePrivate::finishInit() d_smt.d_theoryEngine->getSortInference())); std::unique_ptr<StaticLearning> staticLearning( new StaticLearning(d_preprocessingPassContext.get())); + std::unique_ptr<SygusInference> sygusInfer( + new SygusInference(d_preprocessingPassContext.get())); std::unique_ptr<SymBreakerPass> sbProc( new SymBreakerPass(d_preprocessingPassContext.get())); std::unique_ptr<SynthRewRulesPass> srrProc( @@ -2713,6 +2715,8 @@ void SmtEnginePrivate::finishInit() std::move(sortInfer)); d_preprocessingPassRegistry.registerPass("static-learning", std::move(staticLearning)); + d_preprocessingPassRegistry.registerPass("sygus-infer", + std::move(sygusInfer)); d_preprocessingPassRegistry.registerPass("sym-break", std::move(sbProc)); d_preprocessingPassRegistry.registerPass("synth-rr", std::move(srrProc)); } @@ -4243,12 +4247,7 @@ void SmtEnginePrivate::processAssertions() { } if (options::sygusInference()) { - // try recast as sygus - quantifiers::SygusInference si; - if (si.simplify(d_assertions.ref())) - { - Trace("smt-proc") << "...converted to sygus conjecture." << std::endl; - } + d_preprocessingPassRegistry.getPass("sygus-infer")->apply(&d_assertions); } Trace("smt-proc") << "SmtEnginePrivate::processAssertions() : post-quant-preprocess" << endl; } |