diff options
Diffstat (limited to 'src/preprocessing/passes/synth_rew_rules.cpp')
-rw-r--r-- | src/preprocessing/passes/synth_rew_rules.cpp | 77 |
1 files changed, 41 insertions, 36 deletions
diff --git a/src/preprocessing/passes/synth_rew_rules.cpp b/src/preprocessing/passes/synth_rew_rules.cpp index 6d6e8fb27..47e64b2e4 100644 --- a/src/preprocessing/passes/synth_rew_rules.cpp +++ b/src/preprocessing/passes/synth_rew_rules.cpp @@ -15,6 +15,7 @@ #include "preprocessing/passes/synth_rew_rules.h" +#include "expr/sygus_datatype.h" #include "expr/term_canonize.h" #include "options/base_options.h" #include "options/quantifiers_options.h" @@ -236,14 +237,13 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( Trace("srs-input") << "...finished." << std::endl; // the sygus variable list Node sygusVarList = nm->mkNode(BOUND_VAR_LIST, allVars); - Expr sygusVarListE = sygusVarList.toExpr(); Trace("srs-input") << "Have " << cterms.size() << " canonical subterms." << std::endl; Trace("srs-input") << "Construct unresolved types..." << std::endl; // each canonical subterm corresponds to a grammar type std::set<Type> unres; - std::vector<Datatype> datatypes; + std::vector<SygusDatatype> sdts; // make unresolved types for each canonical term std::map<Node, TypeNode> cterm_to_utype; for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++) @@ -255,11 +255,11 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( TypeNode tnu = nm->mkSort(tname, ExprManager::SORT_FLAG_PLACEHOLDER); cterm_to_utype[ct] = tnu; unres.insert(tnu.toType()); - datatypes.push_back(Datatype(tname)); + sdts.push_back(SygusDatatype(tname)); } Trace("srs-input") << "...finished." << std::endl; - Trace("srs-input") << "Construct datatypes..." << std::endl; + Trace("srs-input") << "Construct sygus datatypes..." << std::endl; for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++) { Node ct = cterms[i]; @@ -268,7 +268,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( // add the variables for the type TypeNode ctt = ct.getType(); Assert(tvars.find(ctt) != tvars.end()); - std::vector<Type> argList; + std::vector<TypeNode> argList; // we add variable constructors if we are not Boolean, we are interested // in purely propositional rewrites (via the option), or this term is // a Boolean variable. @@ -279,7 +279,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( { std::stringstream ssc; ssc << "C_" << i << "_" << v; - datatypes[i].addSygusConstructor(v.toExpr(), ssc.str(), argList); + sdts[i].addConstructor(v, ssc.str(), argList); } } // add the constructor for the operator if it is not a variable @@ -295,7 +295,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( Node ctc = term_to_cterm[tc]; Assert(cterm_to_utype.find(ctc) != cterm_to_utype.end()); // get the type - argList.push_back(cterm_to_utype[ctc].toType()); + argList.push_back(cterm_to_utype[ctc]); } // check if we should chain bool do_chain = false; @@ -305,12 +305,12 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( do_chain = theory::quantifiers::TermUtil::isAssoc(k) && theory::quantifiers::TermUtil::isComm(k); // eliminate duplicate child types - std::vector<Type> argListTmp = argList; + std::vector<TypeNode> argListTmp = argList; argList.clear(); - std::map<Type, bool> hasArgType; + std::map<TypeNode, bool> hasArgType; for (unsigned j = 0, size = argListTmp.size(); j < size; j++) { - Type t = argListTmp[j]; + TypeNode t = argListTmp[j]; if (hasArgType.find(t) == hasArgType.end()) { hasArgType[t] = true; @@ -323,9 +323,9 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( // we make one type per child // the operator of each constructor is a no-op Node tbv = nm->mkBoundVar(ctt); - Expr lambdaOp = - nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv).toExpr(); - std::vector<Type> argListc; + Node lambdaOp = + nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv); + std::vector<TypeNode> argListc; // the following construction admits any number of repeated factors, // so for instance, t1+t2+t3, we generate the grammar: // T_{t1+t2+t3} -> @@ -341,44 +341,49 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( std::stringstream sscs; sscs << "C_factor_" << i << "_" << j; // ID function is not printed and does not count towards weight - datatypes[i].addSygusConstructor( - lambdaOp, - sscs.str(), - argListc, - printer::SygusEmptyPrintCallback::getEmptyPC(), - 0); + sdts[i].addConstructor(lambdaOp, + sscs.str(), + argListc, + printer::SygusEmptyPrintCallback::getEmptyPC(), + 0); } // recursive apply - Type recType = cterm_to_utype[ct].toType(); + TypeNode recType = cterm_to_utype[ct]; argListc.clear(); argListc.push_back(recType); argListc.push_back(recType); std::stringstream ssc; ssc << "C_" << i << "_rec_" << op; - datatypes[i].addSygusConstructor(op.toExpr(), ssc.str(), argListc); + sdts[i].addConstructor(op, ssc.str(), argListc); } else { std::stringstream ssc; ssc << "C_" << i << "_" << op; - datatypes[i].addSygusConstructor(op.toExpr(), ssc.str(), argList); + sdts[i].addConstructor(op, ssc.str(), argList); } } - Assert(datatypes[i].getNumConstructors() > 0); - datatypes[i].setSygus(ctt.toType(), sygusVarListE, false, false); + Assert(sdts[i].getNumConstructors() > 0); + sdts[i].initializeDatatype(ctt, sygusVarList, false, false); } Trace("srs-input") << "...finished." << std::endl; Trace("srs-input") << "Make mutual datatype types for subterms..." << std::endl; + // extract the datatypes + std::vector<Datatype> datatypes; + for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++) + { + datatypes.push_back(sdts[i].getDatatype()); + } std::vector<DatatypeType> types = nm->toExprManager()->mkMutualDatatypeTypes( datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER); Trace("srs-input") << "...finished." << std::endl; Assert(types.size() == unres.size()); - std::map<Node, DatatypeType> subtermTypes; + std::map<Node, TypeNode> subtermTypes; for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++) { - subtermTypes[cterms[i]] = types[i]; + subtermTypes[cterms[i]] = TypeNode::fromType(types[i]); } Trace("srs-input") << "Construct the top-level types..." << std::endl; @@ -389,34 +394,34 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( TypeNode t = tcp.first; std::stringstream ss; ss << "T_" << t; - Datatype dttl(ss.str()); + SygusDatatype sdttl(ss.str()); Node tbv = nm->mkBoundVar(t); // the operator of each constructor is a no-op - Expr lambdaOp = - nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv).toExpr(); + Node lambdaOp = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv); Trace("srs-input") << " We have " << tcp.second.size() << " subterms of type " << t << std::endl; for (unsigned i = 0, size = tcp.second.size(); i < size; i++) { Node n = tcp.second[i]; // add constructor that encodes abstractions of this subterm - std::vector<Type> argList; + std::vector<TypeNode> argList; Assert(subtermTypes.find(n) != subtermTypes.end()); argList.push_back(subtermTypes[n]); std::stringstream ssc; ssc << "Ctl_" << i; // the no-op should not be printed, hence we pass an empty callback - dttl.addSygusConstructor(lambdaOp, - ssc.str(), - argList, - printer::SygusEmptyPrintCallback::getEmptyPC(), - 0); + sdttl.addConstructor(lambdaOp, + ssc.str(), + argList, + printer::SygusEmptyPrintCallback::getEmptyPC(), + 0); Trace("srs-input-debug") << "Grammar for subterm " << n << " is: " << std::endl; Trace("srs-input-debug") << subtermTypes[n].getDatatype() << std::endl; } // set that this is a sygus datatype - dttl.setSygus(t.toType(), sygusVarListE, false, false); + sdttl.initializeDatatype(t, sygusVarList, false, false); + Datatype dttl = sdttl.getDatatype(); DatatypeType tlt = nm->toExprManager()->mkDatatypeType( dttl, ExprManager::DATATYPE_FLAG_PLACEHOLDER); tlGrammarTypes[t] = TypeNode::fromType(tlt); |