summaryrefslogtreecommitdiff
path: root/src/preprocessing
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2018-08-20 12:21:37 -0500
committerGitHub <noreply@github.com>2018-08-20 12:21:37 -0500
commit991af9a7a73adaa84712e93af72980ba977b1155 (patch)
tree4efe10ae3cb7f76acd25f8749859d76a2c8c1e80 /src/preprocessing
parentb7dfffd3fe57ab8bf2b6f8aed35f8c3bb459a117 (diff)
Make sygus inference a preprocessing pass (#2334)
Diffstat (limited to 'src/preprocessing')
-rw-r--r--src/preprocessing/passes/sygus_inference.cpp321
-rw-r--r--src/preprocessing/passes/sygus_inference.h71
2 files changed, 392 insertions, 0 deletions
diff --git a/src/preprocessing/passes/sygus_inference.cpp b/src/preprocessing/passes/sygus_inference.cpp
new file mode 100644
index 000000000..eb8835623
--- /dev/null
+++ b/src/preprocessing/passes/sygus_inference.cpp
@@ -0,0 +1,321 @@
+/********************* */
+/*! \file sygus_inference.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Sygus inference module
+ **/
+
+#include "preprocessing/passes/sygus_inference.h"
+
+#include "smt/smt_engine.h"
+#include "smt/smt_engine_scope.h"
+#include "smt/smt_statistics_registry.h"
+#include "theory/quantifiers/quantifiers_attributes.h"
+#include "theory/quantifiers/quantifiers_rewriter.h"
+
+using namespace std;
+using namespace CVC4::kind;
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+SygusInference::SygusInference(PreprocessingPassContext* preprocContext)
+ : PreprocessingPass(preprocContext, "sygus-infer"){};
+
+PreprocessingPassResult SygusInference::applyInternal(
+ AssertionPipeline* assertionsToPreprocess)
+{
+ Trace("sygus-infer") << "Run sygus inference..." << std::endl;
+ std::vector<Node> funs;
+ std::vector<Node> sols;
+ // see if we can succesfully solve the input as a sygus problem
+ if (solveSygus(assertionsToPreprocess->ref(), funs, sols))
+ {
+ Assert(funs.size() == sols.size());
+ // if so, sygus gives us function definitions
+ SmtEngine* master_smte = smt::currentSmtEngine();
+ for (unsigned i = 0, size = funs.size(); i < size; i++)
+ {
+ std::vector<Expr> args;
+ Node sol = sols[i];
+ // if it is a non-constant function
+ if (sol.getKind() == LAMBDA)
+ {
+ for (const Node& v : sol[0])
+ {
+ args.push_back(v.toExpr());
+ }
+ sol = sol[1];
+ }
+ master_smte->defineFunction(funs[i].toExpr(), args, sol.toExpr());
+ }
+
+ // apply substitution to everything, should result in SAT
+ for (unsigned i = 0, size = assertionsToPreprocess->ref().size(); i < size;
+ i++)
+ {
+ Node prev = (*assertionsToPreprocess)[i];
+ Node curr =
+ prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end());
+ if (curr != prev)
+ {
+ curr = theory::Rewriter::rewrite(curr);
+ Trace("sygus-infer-debug")
+ << "...rewrote " << prev << " to " << curr << std::endl;
+ assertionsToPreprocess->replace(i, curr);
+ }
+ }
+ }
+ return PreprocessingPassResult::NO_CONFLICT;
+}
+
+bool SygusInference::solveSygus(std::vector<Node>& assertions,
+ std::vector<Node>& funs,
+ std::vector<Node>& sols)
+{
+ if (assertions.empty())
+ {
+ Trace("sygus-infer") << "...fail: empty assertions." << std::endl;
+ return false;
+ }
+
+ NodeManager* nm = NodeManager::currentNM();
+
+ // collect free variables in all assertions
+ std::vector<Node> qvars;
+ std::map<TypeNode, std::vector<Node> > qtvars;
+ std::vector<Node> free_functions;
+
+ std::vector<TNode> visit;
+ std::unordered_set<TNode, TNodeHashFunction> visited;
+
+ // add top-level conjuncts to eassertions
+ std::vector<Node> assertions_proc = assertions;
+ std::vector<Node> eassertions;
+ unsigned index = 0;
+ while (index < assertions_proc.size())
+ {
+ Node ca = assertions_proc[index];
+ if (ca.getKind() == AND)
+ {
+ for (const Node& ai : ca)
+ {
+ assertions_proc.push_back(ai);
+ }
+ }
+ else
+ {
+ eassertions.push_back(ca);
+ }
+ index++;
+ }
+
+ // process eassertions
+ std::vector<Node> processed_assertions;
+ for (const Node& as : eassertions)
+ {
+ // substitution for this assertion
+ std::vector<Node> vars;
+ std::vector<Node> subs;
+ std::map<TypeNode, unsigned> type_count;
+ Node pas = as;
+ // rewrite
+ pas = theory::Rewriter::rewrite(pas);
+ Trace("sygus-infer") << "assertion : " << pas << std::endl;
+ if (pas.getKind() == FORALL)
+ {
+ // preprocess the quantified formula
+ pas = theory::quantifiers::QuantifiersRewriter::preprocess(pas);
+ Trace("sygus-infer-debug") << " ...preprocessed to " << pas << std::endl;
+ }
+ if (pas.getKind() == FORALL)
+ {
+ // it must be a standard quantifier
+ theory::quantifiers::QAttributes qa;
+ theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa);
+ if (!qa.isStandard())
+ {
+ Trace("sygus-infer")
+ << "...fail: non-standard top-level quantifier." << std::endl;
+ return false;
+ }
+ // infer prefix
+ for (const Node& v : pas[0])
+ {
+ TypeNode tnv = v.getType();
+ unsigned vnum = type_count[tnv];
+ type_count[tnv]++;
+ if (vnum < qtvars[tnv].size())
+ {
+ vars.push_back(v);
+ subs.push_back(qtvars[tnv][vnum]);
+ }
+ else
+ {
+ Assert(vnum == qtvars[tnv].size());
+ qtvars[tnv].push_back(v);
+ qvars.push_back(v);
+ }
+ }
+ pas = pas[1];
+ if (!vars.empty())
+ {
+ pas =
+ pas.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+ }
+ }
+ Trace("sygus-infer-debug") << " ...substituted to " << pas << std::endl;
+
+ // collect free functions, ensure no quantified formulas
+ TNode cur = pas;
+ // compute free variables
+ visit.push_back(cur);
+ do
+ {
+ cur = visit.back();
+ visit.pop_back();
+ if (visited.find(cur) == visited.end())
+ {
+ visited.insert(cur);
+ if (cur.getKind() == APPLY_UF)
+ {
+ Node op = cur.getOperator();
+ if (std::find(free_functions.begin(), free_functions.end(), op)
+ == free_functions.end())
+ {
+ free_functions.push_back(op);
+ }
+ }
+ else if (cur.getKind() == VARIABLE)
+ {
+ // a free variable is a 0-argument function to synthesize
+ Assert(std::find(free_functions.begin(), free_functions.end(), cur)
+ == free_functions.end());
+ free_functions.push_back(cur);
+ }
+ else if (cur.getKind() == FORALL)
+ {
+ Trace("sygus-infer")
+ << "...fail: non-top-level quantifier." << std::endl;
+ return false;
+ }
+ for (const TNode& cn : cur)
+ {
+ visit.push_back(cn);
+ }
+ }
+ } while (!visit.empty());
+ processed_assertions.push_back(pas);
+ }
+
+ // no functions to synthesize
+ if (free_functions.empty())
+ {
+ Trace("sygus-infer") << "...fail: no free function symbols." << std::endl;
+ return false;
+ }
+
+ Assert(!processed_assertions.empty());
+ // conjunction of the assertions
+ Trace("sygus-infer") << "Construct body..." << std::endl;
+ Node body;
+ if (processed_assertions.size() == 1)
+ {
+ body = processed_assertions[0];
+ }
+ else
+ {
+ body = nm->mkNode(AND, processed_assertions);
+ }
+
+ // for each free function symbol, make a bound variable of the same type
+ Trace("sygus-infer") << "Do free function substitution..." << std::endl;
+ std::vector<Node> ff_vars;
+ std::map<Node, Node> ff_var_to_ff;
+ for (const Node& ff : free_functions)
+ {
+ Node ffv = nm->mkBoundVar(ff.getType());
+ ff_vars.push_back(ffv);
+ Trace("sygus-infer") << " synth-fun: " << ff << " as " << ffv << std::endl;
+ ff_var_to_ff[ffv] = ff;
+ }
+ // substitute free functions -> variables
+ body = body.substitute(free_functions.begin(),
+ free_functions.end(),
+ ff_vars.begin(),
+ ff_vars.end());
+ Trace("sygus-infer-debug") << "...got : " << body << std::endl;
+
+ // quantify the body
+ Trace("sygus-infer") << "Make inner sygus conjecture..." << std::endl;
+ if (!qvars.empty())
+ {
+ Node bvl = nm->mkNode(BOUND_VAR_LIST, qvars);
+ body = nm->mkNode(EXISTS, bvl, body.negate());
+ }
+
+ // sygus attribute to mark the conjecture as a sygus conjecture
+ Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl;
+ Node sygusVar = nm->mkSkolem("sygus", nm->booleanType());
+ theory::SygusAttribute ca;
+ sygusVar.setAttribute(ca, true);
+ Node instAttr = nm->mkNode(INST_ATTRIBUTE, sygusVar);
+ Node instAttrList = nm->mkNode(INST_PATTERN_LIST, instAttr);
+
+ Node fbvl = nm->mkNode(BOUND_VAR_LIST, ff_vars);
+
+ body = nm->mkNode(FORALL, fbvl, body, instAttrList);
+
+ Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl;
+
+ // make a separate smt call
+ SmtEngine rrSygus(nm->toExprManager());
+ rrSygus.setLogic(smt::currentSmtEngine()->getLogicInfo());
+ rrSygus.assertFormula(body.toExpr());
+ Trace("sygus-infer") << "*** Check sat..." << std::endl;
+ Result r = rrSygus.checkSat();
+ Trace("sygus-infer") << "...result : " << r << std::endl;
+ if (r.asSatisfiabilityResult().isSat() != Result::UNSAT)
+ {
+ // failed, conjecture was infeasible
+ return false;
+ }
+ // get the synthesis solutions
+ std::map<Expr, Expr> synth_sols;
+ rrSygus.getSynthSolutions(synth_sols);
+
+ std::vector<Node> final_ff;
+ std::vector<Node> final_ff_sol;
+ for (std::map<Expr, Expr>::iterator it = synth_sols.begin();
+ it != synth_sols.end();
+ ++it)
+ {
+ Trace("sygus-infer") << " synth sol : " << it->first << " -> "
+ << it->second << std::endl;
+ Node ffv = Node::fromExpr(it->first);
+ std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv);
+ // all synthesis solutions should correspond to a variable we introduced
+ Assert(itffv != ff_var_to_ff.end());
+ if (itffv != ff_var_to_ff.end())
+ {
+ Node ff = itffv->second;
+ Node body = Node::fromExpr(it->second);
+ Trace("sygus-infer") << "Define " << ff << " as " << body << std::endl;
+ funs.push_back(ff);
+ sols.push_back(body);
+ }
+ }
+ return true;
+}
+
+} // namespace passes
+} // namespace preprocessing
+} // namespace CVC4
diff --git a/src/preprocessing/passes/sygus_inference.h b/src/preprocessing/passes/sygus_inference.h
new file mode 100644
index 000000000..5e7c6f7d0
--- /dev/null
+++ b/src/preprocessing/passes/sygus_inference.h
@@ -0,0 +1,71 @@
+/********************* */
+/*! \file sygus_inference.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief SygusInference
+ **/
+
+#ifndef __CVC4__PREPROCESSING__PASSES__SYGUS_INFERENCE_H_
+#define __CVC4__PREPROCESSING__PASSES__SYGUS_INFERENCE_H_
+
+#include <map>
+#include <string>
+#include <vector>
+#include "expr/node.h"
+
+#include "preprocessing/preprocessing_pass.h"
+#include "preprocessing/preprocessing_pass_context.h"
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+/** SygusInference
+ *
+ * A preprocessing utility that turns a set of (quantified) assertions into a
+ * single SyGuS conjecture. If this is possible, we solve for this single Sygus
+ * conjecture using a separate copy of the SMT engine. If sygus successfully
+ * solves the conjecture, we plug the synthesis solutions back into the original
+ * problem, thus obtaining a set of model substitutions under which the
+ * assertions should simplify to true.
+ */
+class SygusInference : public PreprocessingPass
+{
+ public:
+ SygusInference(PreprocessingPassContext* preprocContext);
+
+ protected:
+ /**
+ * Either replaces all uninterpreted functions in assertions by their
+ * interpretation in a sygus solution, or leaves the assertions unmodified.
+ */
+ PreprocessingPassResult applyInternal(
+ AssertionPipeline* assertionsToPreprocess) override;
+ /** solve sygus
+ *
+ * Returns true if we can recast the input problem assertions as a sygus
+ * problem and successfully solve it using a separate call to an SMT engine.
+ *
+ * We fail if either a sygus conjecture that corresponds to assertions cannot
+ * be inferred, or the sygus conjecture we infer is infeasible.
+ *
+ * If this function returns true, then we add all uninterpreted symbols s in
+ * assertions to funs and their corresponding solution to sols.
+ */
+ bool solveSygus(std::vector<Node>& assertions,
+ std::vector<Node>& funs,
+ std::vector<Node>& sols);
+};
+
+} // namespace passes
+} // namespace preprocessing
+} // namespace CVC4
+
+#endif /* __CVC4__PREPROCESSING__PASSES__SYGUS_INFERENCE_H_ */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback