diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2018-10-10 20:44:02 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-10 20:44:02 -0500 |
commit | 7d70b721f43157e01bc6166a822df79250df632a (patch) | |
tree | 6d92615b265f02e237717f49017d81b4646ae714 /src/preprocessing/passes | |
parent | aa84926fb81001cc86661dead2ac5b856dd45ba3 (diff) |
Synthesize rewrite rules from inputs (#2608)
Diffstat (limited to 'src/preprocessing/passes')
-rw-r--r-- | src/preprocessing/passes/synth_rew_rules.cpp | 468 | ||||
-rw-r--r-- | src/preprocessing/passes/synth_rew_rules.h | 39 |
2 files changed, 413 insertions, 94 deletions
diff --git a/src/preprocessing/passes/synth_rew_rules.cpp b/src/preprocessing/passes/synth_rew_rules.cpp index 2ff525285..14dea3908 100644 --- a/src/preprocessing/passes/synth_rew_rules.cpp +++ b/src/preprocessing/passes/synth_rew_rules.cpp @@ -18,143 +18,433 @@ #include "options/base_options.h" #include "options/quantifiers_options.h" #include "printer/printer.h" +#include "printer/sygus_print_callback.h" #include "theory/quantifiers/candidate_rewrite_database.h" +#include "theory/quantifiers/quantifiers_attributes.h" +#include "theory/quantifiers/sygus/sygus_grammar_cons.h" +#include "theory/quantifiers/term_canonize.h" +#include "theory/quantifiers/term_util.h" using namespace std; +using namespace CVC4::kind; namespace CVC4 { namespace preprocessing { namespace passes { -// Attribute for whether we have computed rewrite rules for a given term. -// Notice that this currently must be a global attribute, since if -// we've computed rewrites for a term, we should not compute rewrites for the -// same term in a subcall to another SmtEngine (for instance, when using -// "exact" equivalence checking). -struct SynthRrComputedAttributeId -{ -}; -typedef expr::Attribute<SynthRrComputedAttributeId, bool> - SynthRrComputedAttribute; - SynthRewRulesPass::SynthRewRulesPass(PreprocessingPassContext* preprocContext) : PreprocessingPass(preprocContext, "synth-rr"){}; PreprocessingPassResult SynthRewRulesPass::applyInternal( AssertionPipeline* assertionsToPreprocess) { - Trace("synth-rr-pass") << "Synthesize rewrite rules from assertions..." - << std::endl; + Trace("srs-input") << "Synthesize rewrite rules from assertions..." 
+ << std::endl; std::vector<Node>& assertions = assertionsToPreprocess->ref(); + if (assertions.empty()) + { + return PreprocessingPassResult::NO_CONFLICT; + } - // compute the variables we will be sampling - std::vector<Node> vars; - unsigned nsamples = options::sygusSamples(); - - Options& nodeManagerOptions = NodeManager::currentNM()->getOptions(); - - // attribute to mark processed terms - SynthRrComputedAttribute srrca; + NodeManager* nm = NodeManager::currentNM(); // initialize the candidate rewrite - std::unique_ptr<theory::quantifiers::CandidateRewriteDatabaseGen> crdg; std::unordered_map<TNode, bool, TNodeHashFunction> visited; std::unordered_map<TNode, bool, TNodeHashFunction>::iterator it; std::vector<TNode> visit; - // two passes: the first collects the variables, the second registers the - // terms - for (unsigned r = 0; r < 2; r++) + // Get all usable terms from the input. A term is usable if it does not + // contain a quantified subterm + std::vector<Node> terms; + // all variables (free constants) appearing in the input + std::vector<Node> vars; + + // We will generate a fixed number of variables per type. These are the + // variables that appear as free variables in the rewrites we generate. + unsigned nvars = options::sygusRewSynthInputNVars(); + // must have at least one variable per type + nvars = nvars < 1 ? 1 : nvars; + std::map<TypeNode, std::vector<Node> > tvars; + std::vector<TypeNode> allVarTypes; + std::vector<Node> allVars; + unsigned varCounter = 0; + // standard constants for each type (e.g. true, false for Bool) + std::map<TypeNode, std::vector<Node> > consts; + + TNode cur; + Trace("srs-input") << "Collect terms in assertions..." 
<< std::endl; + for (const Node& a : assertions) { - visited.clear(); - visit.clear(); - TNode cur; - for (const Node& a : assertions) + Trace("srs-input-debug") << "Assertion : " << a << std::endl; + visit.push_back(a); + do { - visit.push_back(a); - do + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + if (it == visited.end()) { - cur = visit.back(); - visit.pop_back(); - it = visited.find(cur); - // if already processed, ignore - if (cur.getAttribute(SynthRrComputedAttribute())) + Trace("srs-input-debug") << "...preprocess " << cur << std::endl; + visited[cur] = false; + Kind k = cur.getKind(); + bool isQuant = k == FORALL || k == EXISTS || k == LAMBDA || k == CHOICE; + // we recurse on this node if it is not a quantified formula + if (!isQuant) { - Trace("synth-rr-pass-debug") - << "...already processed " << cur << std::endl; + visit.push_back(cur); + for (const Node& cc : cur) + { + visit.push_back(cc); + } } - else if (it == visited.end()) + } + else if (!it->second) + { + Trace("srs-input-debug") << "...postprocess " << cur << std::endl; + // check if all of the children are valid + // this ensures we do not register terms that have e.g. quantified + // formulas as subterms + bool childrenValid = true; + for (const Node& cc : cur) { - Trace("synth-rr-pass-debug") << "...preprocess " << cur << std::endl; - visited[cur] = false; - Kind k = cur.getKind(); - bool isQuant = k == kind::FORALL || k == kind::EXISTS - || k == kind::LAMBDA || k == kind::CHOICE; - // we recurse on this node if it is not a quantified formula - if (!isQuant) + Assert(visited.find(cc) != visited.end()); + if (!visited[cc]) { - visit.push_back(cur); - for (const Node& cc : cur) - { - visit.push_back(cc); - } + childrenValid = false; } } - else if (!it->second) + if (childrenValid) { - Trace("synth-rr-pass-debug") << "...postprocess " << cur << std::endl; - // check if all of the children are valid - // this ensures we do not register terms that have e.g. 
quantified - // formulas as subterms - bool childrenValid = true; - for (const Node& cc : cur) + Trace("srs-input-debug") << "...children are valid" << std::endl; + Trace("srs-input-debug") << "Add term " << cur << std::endl; + if (cur.isVar()) { - Assert(visited.find(cc) != visited.end()); - if (!visited[cc]) - { - childrenValid = false; - } + vars.push_back(cur); } - if (childrenValid) + // register type information + TypeNode tn = cur.getType(); + if (tvars.find(tn) == tvars.end()) { - Trace("synth-rr-pass-debug") - << "...children are valid, check rewrites..." << std::endl; - if (r == 0) + // Only make one Boolean variable unless option is set. This ensures + // we do not compute purely Boolean rewrites by default. + unsigned useNVars = + (options::sygusRewSynthInputUseBool() || !tn.isBoolean()) + ? nvars + : 1; + for (unsigned i = 0; i < useNVars; i++) { - if (cur.isVar()) + // We must have a good name for these variables, these are + // the ones output in rewrite rules. We choose + // a,b,c,...,y,z,x1,x2,... + std::stringstream ssv; + if (varCounter < 26) { - vars.push_back(cur); + ssv << String::convertUnsignedIntToChar(varCounter + 32); } + else + { + ssv << "x" << (varCounter - 26); + } + varCounter++; + Node v = nm->mkBoundVar(ssv.str(), tn); + tvars[tn].push_back(v); + allVars.push_back(v); + allVarTypes.push_back(tn); } - else - { - Trace("synth-rr-pass-debug") << "Add term " << cur << std::endl; - // mark as processed - cur.setAttribute(srrca, true); - bool ret = crdg->addTerm(cur, *nodeManagerOptions.getOut()); - Trace("synth-rr-pass-debug") << "...return " << ret << std::endl; - // if we want only rewrites of minimal size terms, we would set - // childrenValid to false if ret is false here. 
- } + // also add the standard constants for this type + theory::quantifiers::CegGrammarConstructor::mkSygusConstantsForType( + tn, consts[tn]); + visit.insert(visit.end(), consts[tn].begin(), consts[tn].end()); } - visited[cur] = childrenValid; + terms.push_back(cur); } - } while (!visit.empty()); + visited[cur] = childrenValid; + } + } while (!visit.empty()); + } + Trace("srs-input") << "...finished." << std::endl; + + Trace("srs-input") << "Convert subterms to free variable form..." + << std::endl; + // Replace all free variables with bound variables. This ensures that + // we can perform term canonization on subterms. + std::vector<Node> vsubs; + for (const Node& v : vars) + { + TypeNode tnv = v.getType(); + Node vs = nm->mkBoundVar(tnv); + vsubs.push_back(vs); + } + if (!vars.empty()) + { + for (unsigned i = 0, nterms = terms.size(); i < nterms; i++) + { + terms[i] = terms[i].substitute( + vars.begin(), vars.end(), vsubs.begin(), vsubs.end()); } - if (r == 0) + } + Trace("srs-input") << "...finished." << std::endl; + + Trace("srs-input") << "Process " << terms.size() << " subterms..." + << std::endl; + // We've collected all terms in the input. We construct a sygus grammar in + // following which generates terms that correspond to abstractions of the + // terms in the input. + + // We map terms to a canonical (ordered variable) form. This ensures that + // we don't generate distinct grammar types for distinct alpha-equivalent + // terms, which would produce grammars of identical shape. 
+ std::map<Node, Node> term_to_cterm; + std::map<Node, Node> cterm_to_term; + std::vector<Node> cterms; + // canonical terms for each type + std::map<TypeNode, std::vector<Node> > t_cterms; + theory::quantifiers::TermCanonize tcanon; + for (unsigned i = 0, nterms = terms.size(); i < nterms; i++) + { + Node n = terms[i]; + Node cn = tcanon.getCanonicalTerm(n); + term_to_cterm[n] = cn; + Trace("srs-input-debug") << "Canon : " << n << " -> " << cn << std::endl; + std::map<Node, Node>::iterator itc = cterm_to_term.find(cn); + if (itc == cterm_to_term.end()) { - Trace("synth-rr-pass-debug") - << "Initialize with " << nsamples - << " samples and variables : " << vars << std::endl; - crdg = std::unique_ptr<theory::quantifiers::CandidateRewriteDatabaseGen>( - new theory::quantifiers::CandidateRewriteDatabaseGen(vars, nsamples)); + cterm_to_term[cn] = n; + cterms.push_back(cn); + t_cterms[cn.getType()].push_back(cn); } } + Trace("srs-input") << "...finished." << std::endl; + // the sygus variable list + Node sygusVarList = nm->mkNode(BOUND_VAR_LIST, allVars); + Expr sygusVarListE = sygusVarList.toExpr(); + Trace("srs-input") << "Have " << cterms.size() << " canonical subterms." + << std::endl; + + Trace("srs-input") << "Construct unresolved types..." << std::endl; + // each canonical subterm corresponds to a grammar type + std::set<Type> unres; + std::vector<Datatype> datatypes; + // make unresolved types for each canonical term + std::map<Node, TypeNode> cterm_to_utype; + for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++) + { + Node ct = cterms[i]; + std::stringstream ss; + ss << "T" << i; + std::string tname = ss.str(); + TypeNode tnu = nm->mkSort(tname, ExprManager::SORT_FLAG_PLACEHOLDER); + cterm_to_utype[ct] = tnu; + unres.insert(tnu.toType()); + datatypes.push_back(Datatype(tname)); + } + Trace("srs-input") << "...finished." << std::endl; + + Trace("srs-input") << "Construct datatypes..." 
<< std::endl; + for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++) + { + Node ct = cterms[i]; + Node t = cterm_to_term[ct]; + + // add the variables for the type + TypeNode ctt = ct.getType(); + Assert(tvars.find(ctt) != tvars.end()); + std::vector<Type> argList; + for (const Node& v : tvars[ctt]) + { + std::stringstream ssc; + ssc << "C_" << i << "_" << v; + datatypes[i].addSygusConstructor(v.toExpr(), ssc.str(), argList); + } + // add the constructor for the operator if it is not a variable + if (ct.getKind() != BOUND_VARIABLE) + { + Assert(!ct.isVar()); + Node op = ct.hasOperator() ? ct.getOperator() : ct; + // iterate over the original term + for (const Node& tc : t) + { + // map its arguments back to canonical + Assert(term_to_cterm.find(tc) != term_to_cterm.end()); + Node ctc = term_to_cterm[tc]; + Assert(cterm_to_utype.find(ctc) != cterm_to_utype.end()); + // get the type + argList.push_back(cterm_to_utype[ctc].toType()); + } + // check if we should chain + bool do_chain = false; + if (argList.size() > 2) + { + Kind k = NodeManager::operatorToKind(op); + do_chain = theory::quantifiers::TermUtil::isAssoc(k) + && theory::quantifiers::TermUtil::isComm(k); + // eliminate duplicate child types + std::vector<Type> argListTmp = argList; + argList.clear(); + std::map<Type, bool> hasArgType; + for (unsigned j = 0, size = argListTmp.size(); j < size; j++) + { + Type t = argListTmp[j]; + if (hasArgType.find(t) == hasArgType.end()) + { + hasArgType[t] = true; + argList.push_back(t); + } + } + } + if (do_chain) + { + // we make one type per child + // the operator of each constructor is a no-op + Node tbv = nm->mkBoundVar(ctt); + Expr lambdaOp = + nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv).toExpr(); + std::vector<Type> argListc; + // the following construction admits any number of repeated factors, + // so for instance, t1+t2+t3, we generate the grammar: + // T_{t1+t2+t3} -> + // +( T_{t1+t2+t3}, T_{t1+t2+t3} ) | T_{t1} | T_{t2} | T_{t3} + // 
where we write T_t to denote "the type that abstracts term t". + // Notice this construction allows to abstract subsets of the factors + // of t1+t2+t3. This is particularly helpful for terms t1+...+tn for + // large n, where we would like to consider binary applications of +. + for (unsigned j = 0, size = argList.size(); j < size; j++) + { + argListc.clear(); + argListc.push_back(argList[j]); + std::stringstream sscs; + sscs << "C_factor_" << i << "_" << j; + // ID function is not printed and does not count towards weight + datatypes[i].addSygusConstructor( + lambdaOp, + sscs.str(), + argListc, + printer::SygusEmptyPrintCallback::getEmptyPC(), + 0); + } + // recursive apply + Type recType = cterm_to_utype[ct].toType(); + argListc.clear(); + argListc.push_back(recType); + argListc.push_back(recType); + std::stringstream ssc; + ssc << "C_" << i << "_rec_" << op; + datatypes[i].addSygusConstructor(op.toExpr(), ssc.str(), argListc); + } + else + { + std::stringstream ssc; + ssc << "C_" << i << "_" << op; + datatypes[i].addSygusConstructor(op.toExpr(), ssc.str(), argList); + } + } + datatypes[i].setSygus(ctt.toType(), sygusVarListE, false, false); + } + Trace("srs-input") << "...finished." << std::endl; + + Trace("srs-input") << "Make mutual datatype types for subterms..." + << std::endl; + std::vector<DatatypeType> types = nm->toExprManager()->mkMutualDatatypeTypes( + datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER); + Trace("srs-input") << "...finished." << std::endl; + Assert(types.size() == unres.size()); + std::map<Node, DatatypeType> subtermTypes; + for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++) + { + subtermTypes[cterms[i]] = types[i]; + } + + Trace("srs-input") << "Construct the top-level types..." 
<< std::endl; + // we now are ready to create the "top-level" types + std::map<TypeNode, TypeNode> tlGrammarTypes; + for (std::pair<const TypeNode, std::vector<Node> >& tcp : t_cterms) + { + TypeNode t = tcp.first; + std::stringstream ss; + ss << "T_" << t; + Datatype dttl(ss.str()); + Node tbv = nm->mkBoundVar(t); + // the operator of each constructor is a no-op + Expr lambdaOp = + nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv).toExpr(); + Trace("srs-input") << " We have " << tcp.second.size() + << " subterms of type " << t << std::endl; + for (unsigned i = 0, size = tcp.second.size(); i < size; i++) + { + Node n = tcp.second[i]; + // add constructor that encodes abstractions of this subterm + std::vector<Type> argList; + Assert(subtermTypes.find(n) != subtermTypes.end()); + argList.push_back(subtermTypes[n]); + std::stringstream ssc; + ssc << "Ctl_" << i; + // the no-op should not be printed, hence we pass an empty callback + dttl.addSygusConstructor(lambdaOp, + ssc.str(), + argList, + printer::SygusEmptyPrintCallback::getEmptyPC(), + 0); + Trace("srs-input-debug") + << "Grammar for subterm " << n << " is: " << std::endl; + Trace("srs-input-debug") << subtermTypes[n].getDatatype() << std::endl; + } + // set that this is a sygus datatype + dttl.setSygus(t.toType(), sygusVarListE, false, false); + DatatypeType tlt = nm->toExprManager()->mkDatatypeType( + dttl, ExprManager::DATATYPE_FLAG_PLACEHOLDER); + tlGrammarTypes[t] = TypeNode::fromType(tlt); + Trace("srs-input") << "Grammar is: " << std::endl; + Trace("srs-input") << tlt.getDatatype() << std::endl; + } + Trace("srs-input") << "...finished." << std::endl; + + // sygus attribute to mark the conjecture as a sygus conjecture + Trace("srs-input") << "Make sygus conjecture..." 
<< std::endl; + Node iaVar = nm->mkSkolem("sygus", nm->booleanType()); + // the attribute to mark the conjecture as being a sygus conjecture + theory::SygusAttribute ca; + iaVar.setAttribute(ca, true); + Node instAttr = nm->mkNode(INST_ATTRIBUTE, iaVar); + Node instAttrList = nm->mkNode(INST_PATTERN_LIST, instAttr); + // we are "synthesizing" functions for each type of subterm + std::vector<Node> synthConj; + unsigned fCounter = 1; + theory::SygusSynthGrammarAttribute ssg; + for (std::pair<const TypeNode, TypeNode> ttp : tlGrammarTypes) + { + Node gvar = nm->mkBoundVar("sfproxy", ttp.second); + TypeNode ft = nm->mkFunctionType(allVarTypes, ttp.first); + // likewise, it is helpful if these have good names, we choose f1, f2, ... + std::stringstream ssf; + ssf << "f" << fCounter; + fCounter++; + Node sfun = nm->mkBoundVar(ssf.str(), ft); + // this marks that the grammar used for solutions for sfun is the type of + // gvar, which is the sygus datatype type constructed above. + sfun.setAttribute(ssg, gvar); + Node fvarBvl = nm->mkNode(BOUND_VAR_LIST, sfun); + + Node body = nm->mkConst(false); + body = nm->mkNode(FORALL, fvarBvl, body, instAttrList); + synthConj.push_back(body); + } + Node trueNode = nm->mkConst(true); + Node res = + synthConj.empty() + ? trueNode + : (synthConj.size() == 1 ? synthConj[0] : nm->mkNode(AND, synthConj)); + + Trace("srs-input") << "got : " << res << std::endl; + Trace("srs-input") << "...finished." 
<< std::endl; + + assertionsToPreprocess->replace(0, res); + for (unsigned i = 1, size = assertionsToPreprocess->size(); i < size; ++i) + { + assertionsToPreprocess->replace(i, trueNode); + } - Trace("synth-rr-pass") << "...finished " << std::endl; return PreprocessingPassResult::NO_CONFLICT; } - } // namespace passes } // namespace preprocessing } // namespace CVC4 diff --git a/src/preprocessing/passes/synth_rew_rules.h b/src/preprocessing/passes/synth_rew_rules.h index cf0b491fb..2b05bbf00 100644 --- a/src/preprocessing/passes/synth_rew_rules.h +++ b/src/preprocessing/passes/synth_rew_rules.h @@ -24,12 +24,41 @@ namespace preprocessing { namespace passes { /** - * This class computes candidate rewrite rules of the form t1 = t2, where - * t1 and t2 are subterms of assertionsToPreprocess. It prints - * "candidate-rewrite" messages on the output stream of options. + * This class rewrites the input assertions into a sygus conjecture over a + * grammar whose terms are "abstractions" of the subterms of + * assertionsToPreprocess. In detail, assume our input was + * bvadd( bvlshr( bvadd( a, 4 ), 1 ), b ) = 1 + * This class constructs this grammar: + * A -> T1 | T2 | T3 | T4 | Tv + * T1 -> bvadd( T2, Tv ) | x | y + * T2 -> bvlshr( T3, T4 ) | x | y + * T3 -> bvadd( Tv, T5 ) | x | y + * T4 -> 1 | x | y + * T5 -> 4 | x | y + * Tv -> x | y + * Notice that this grammar generates all subterms of the input where leaves + * are replaced by the variables x and/or y. The number of variable constructors + * (x and y in this example) used in this construction is configurable via + * sygus-rr-synth-input-nvars. The default for this value is 3, the + * justification is that this covers most of the interesting rewrites while + * not being too inefficient. * - * In contrast to other preprocessing passes, this pass does not modify - * the set of assertions. + * Also notice that currently, this grammar construction admits terms that + * do not necessarily match any in the input. 
For example, the above grammar + * admits bvlshr( x, x ), which is not matchable with a subterm of the input. + * + * Notice that Booleans are treated specially unless the option + * --sygus-rr-synth-input-bool is enabled, since we do not by default want to + * generate purely propositional rewrites. In particular, we allocate only + * one Boolean variable (to ensure that no sygus type is empty). + * + * It then rewrites the input into the negated sygus conjecture + * forall x : ( BV_n x BV_n ) -> BV_n. false + * where x has the sygus grammar restriction A from above. This conjecture can + * then be processed using --sygus-rr-synth in the standard way, which will + * cause candidate rewrites to be printed on the output stream. If multiple + * types are present, then we generate a conjunction of multiple synthesis + * conjectures, which we enumerate terms for in parallel. */ class SynthRewRulesPass : public PreprocessingPass { |