summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2019-11-18 13:52:18 -0600
committerGitHub <noreply@github.com>2019-11-18 13:52:18 -0600
commit6fe7877d82721e453d5d928a8fe9dbad2099dac1 (patch)
tree3e4c6ef5d0b802db5f76d7cbac66a792f28c6dc7
parent357e81dfc393d9e2ea80f66cddc837564494a34c (diff)
Use standard sygus interface for abduction and rewrite rule synthesis (#3471)
-rw-r--r--src/preprocessing/passes/synth_rew_rules.cpp77
-rw-r--r--src/theory/quantifiers/sygus/sygus_abduct.cpp29
2 files changed, 58 insertions, 48 deletions
diff --git a/src/preprocessing/passes/synth_rew_rules.cpp b/src/preprocessing/passes/synth_rew_rules.cpp
index 6d6e8fb27..47e64b2e4 100644
--- a/src/preprocessing/passes/synth_rew_rules.cpp
+++ b/src/preprocessing/passes/synth_rew_rules.cpp
@@ -15,6 +15,7 @@
#include "preprocessing/passes/synth_rew_rules.h"
+#include "expr/sygus_datatype.h"
#include "expr/term_canonize.h"
#include "options/base_options.h"
#include "options/quantifiers_options.h"
@@ -236,14 +237,13 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
Trace("srs-input") << "...finished." << std::endl;
// the sygus variable list
Node sygusVarList = nm->mkNode(BOUND_VAR_LIST, allVars);
- Expr sygusVarListE = sygusVarList.toExpr();
Trace("srs-input") << "Have " << cterms.size() << " canonical subterms."
<< std::endl;
Trace("srs-input") << "Construct unresolved types..." << std::endl;
// each canonical subterm corresponds to a grammar type
std::set<Type> unres;
- std::vector<Datatype> datatypes;
+ std::vector<SygusDatatype> sdts;
// make unresolved types for each canonical term
std::map<Node, TypeNode> cterm_to_utype;
for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
@@ -255,11 +255,11 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
TypeNode tnu = nm->mkSort(tname, ExprManager::SORT_FLAG_PLACEHOLDER);
cterm_to_utype[ct] = tnu;
unres.insert(tnu.toType());
- datatypes.push_back(Datatype(tname));
+ sdts.push_back(SygusDatatype(tname));
}
Trace("srs-input") << "...finished." << std::endl;
- Trace("srs-input") << "Construct datatypes..." << std::endl;
+ Trace("srs-input") << "Construct sygus datatypes..." << std::endl;
for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
{
Node ct = cterms[i];
@@ -268,7 +268,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
// add the variables for the type
TypeNode ctt = ct.getType();
Assert(tvars.find(ctt) != tvars.end());
- std::vector<Type> argList;
+ std::vector<TypeNode> argList;
// we add variable constructors if we are not Boolean, we are interested
// in purely propositional rewrites (via the option), or this term is
// a Boolean variable.
@@ -279,7 +279,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
{
std::stringstream ssc;
ssc << "C_" << i << "_" << v;
- datatypes[i].addSygusConstructor(v.toExpr(), ssc.str(), argList);
+ sdts[i].addConstructor(v, ssc.str(), argList);
}
}
// add the constructor for the operator if it is not a variable
@@ -295,7 +295,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
Node ctc = term_to_cterm[tc];
Assert(cterm_to_utype.find(ctc) != cterm_to_utype.end());
// get the type
- argList.push_back(cterm_to_utype[ctc].toType());
+ argList.push_back(cterm_to_utype[ctc]);
}
// check if we should chain
bool do_chain = false;
@@ -305,12 +305,12 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
do_chain = theory::quantifiers::TermUtil::isAssoc(k)
&& theory::quantifiers::TermUtil::isComm(k);
// eliminate duplicate child types
- std::vector<Type> argListTmp = argList;
+ std::vector<TypeNode> argListTmp = argList;
argList.clear();
- std::map<Type, bool> hasArgType;
+ std::map<TypeNode, bool> hasArgType;
for (unsigned j = 0, size = argListTmp.size(); j < size; j++)
{
- Type t = argListTmp[j];
+ TypeNode t = argListTmp[j];
if (hasArgType.find(t) == hasArgType.end())
{
hasArgType[t] = true;
@@ -323,9 +323,9 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
// we make one type per child
// the operator of each constructor is a no-op
Node tbv = nm->mkBoundVar(ctt);
- Expr lambdaOp =
- nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv).toExpr();
- std::vector<Type> argListc;
+ Node lambdaOp =
+ nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
+ std::vector<TypeNode> argListc;
// the following construction admits any number of repeated factors,
// so for instance, t1+t2+t3, we generate the grammar:
// T_{t1+t2+t3} ->
@@ -341,44 +341,49 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
std::stringstream sscs;
sscs << "C_factor_" << i << "_" << j;
// ID function is not printed and does not count towards weight
- datatypes[i].addSygusConstructor(
- lambdaOp,
- sscs.str(),
- argListc,
- printer::SygusEmptyPrintCallback::getEmptyPC(),
- 0);
+ sdts[i].addConstructor(lambdaOp,
+ sscs.str(),
+ argListc,
+ printer::SygusEmptyPrintCallback::getEmptyPC(),
+ 0);
}
// recursive apply
- Type recType = cterm_to_utype[ct].toType();
+ TypeNode recType = cterm_to_utype[ct];
argListc.clear();
argListc.push_back(recType);
argListc.push_back(recType);
std::stringstream ssc;
ssc << "C_" << i << "_rec_" << op;
- datatypes[i].addSygusConstructor(op.toExpr(), ssc.str(), argListc);
+ sdts[i].addConstructor(op, ssc.str(), argListc);
}
else
{
std::stringstream ssc;
ssc << "C_" << i << "_" << op;
- datatypes[i].addSygusConstructor(op.toExpr(), ssc.str(), argList);
+ sdts[i].addConstructor(op, ssc.str(), argList);
}
}
- Assert(datatypes[i].getNumConstructors() > 0);
- datatypes[i].setSygus(ctt.toType(), sygusVarListE, false, false);
+ Assert(sdts[i].getNumConstructors() > 0);
+ sdts[i].initializeDatatype(ctt, sygusVarList, false, false);
}
Trace("srs-input") << "...finished." << std::endl;
Trace("srs-input") << "Make mutual datatype types for subterms..."
<< std::endl;
+ // extract the datatypes
+ std::vector<Datatype> datatypes;
+ for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
+ {
+ datatypes.push_back(sdts[i].getDatatype());
+ }
std::vector<DatatypeType> types = nm->toExprManager()->mkMutualDatatypeTypes(
datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
Trace("srs-input") << "...finished." << std::endl;
Assert(types.size() == unres.size());
- std::map<Node, DatatypeType> subtermTypes;
+ std::map<Node, TypeNode> subtermTypes;
for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
{
- subtermTypes[cterms[i]] = types[i];
+ subtermTypes[cterms[i]] = TypeNode::fromType(types[i]);
}
Trace("srs-input") << "Construct the top-level types..." << std::endl;
@@ -389,34 +394,34 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
TypeNode t = tcp.first;
std::stringstream ss;
ss << "T_" << t;
- Datatype dttl(ss.str());
+ SygusDatatype sdttl(ss.str());
Node tbv = nm->mkBoundVar(t);
// the operator of each constructor is a no-op
- Expr lambdaOp =
- nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv).toExpr();
+ Node lambdaOp = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
Trace("srs-input") << " We have " << tcp.second.size()
<< " subterms of type " << t << std::endl;
for (unsigned i = 0, size = tcp.second.size(); i < size; i++)
{
Node n = tcp.second[i];
// add constructor that encodes abstractions of this subterm
- std::vector<Type> argList;
+ std::vector<TypeNode> argList;
Assert(subtermTypes.find(n) != subtermTypes.end());
argList.push_back(subtermTypes[n]);
std::stringstream ssc;
ssc << "Ctl_" << i;
// the no-op should not be printed, hence we pass an empty callback
- dttl.addSygusConstructor(lambdaOp,
- ssc.str(),
- argList,
- printer::SygusEmptyPrintCallback::getEmptyPC(),
- 0);
+ sdttl.addConstructor(lambdaOp,
+ ssc.str(),
+ argList,
+ printer::SygusEmptyPrintCallback::getEmptyPC(),
+ 0);
Trace("srs-input-debug")
<< "Grammar for subterm " << n << " is: " << std::endl;
Trace("srs-input-debug") << subtermTypes[n].getDatatype() << std::endl;
}
// set that this is a sygus datatype
- dttl.setSygus(t.toType(), sygusVarListE, false, false);
+ sdttl.initializeDatatype(t, sygusVarList, false, false);
+ Datatype dttl = sdttl.getDatatype();
DatatypeType tlt = nm->toExprManager()->mkDatatypeType(
dttl, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
tlGrammarTypes[t] = TypeNode::fromType(tlt);
diff --git a/src/theory/quantifiers/sygus/sygus_abduct.cpp b/src/theory/quantifiers/sygus/sygus_abduct.cpp
index 529ef037f..0396aba86 100644
--- a/src/theory/quantifiers/sygus/sygus_abduct.cpp
+++ b/src/theory/quantifiers/sygus/sygus_abduct.cpp
@@ -17,6 +17,7 @@
#include "expr/datatype.h"
#include "expr/node_algorithm.h"
+#include "expr/sygus_datatype.h"
#include "printer/sygus_print_callback.h"
#include "theory/quantifiers/quantifiers_attributes.h"
#include "theory/quantifiers/quantifiers_rewriter.h"
@@ -86,7 +87,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
{
Assert(abdGType.isDatatype() && abdGType.getDatatype().isSygus());
// must convert all constructors to version with bound variables in "vars"
- std::vector<Datatype> datatypes;
+ std::vector<SygusDatatype> sdts;
std::set<Type> unres;
Trace("sygus-abduct-debug") << "Process abduction type:" << std::endl;
@@ -129,9 +130,9 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
const Datatype& dtc = curr.getDatatype();
std::stringstream ssdtn;
ssdtn << dtc.getName() << "_s";
- datatypes.push_back(Datatype(ssdtn.str()));
+ sdts.push_back(SygusDatatype(ssdtn.str()));
Trace("sygus-abduct-debug")
- << "Process datatype " << datatypes.back().getName() << "..."
+ << "Process datatype " << sdts.back().getName() << "..."
<< std::endl;
for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
{
@@ -141,7 +142,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
syms.begin(), syms.end(), varlist.begin(), varlist.end());
Trace("sygus-abduct-debug") << " Process constructor " << op << " / "
<< ops << "..." << std::endl;
- std::vector<Type> cargs;
+ std::vector<TypeNode> cargs;
for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
{
TypeNode argt = TypeNode::fromType(dtc[j].getArgType(k));
@@ -167,7 +168,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
}
Trace("sygus-abduct-debug")
<< " Arg #" << k << ": " << argtNew << std::endl;
- cargs.push_back(argtNew.toType());
+ cargs.push_back(argtNew);
}
// callback prints as the expression
std::shared_ptr<SygusPrintCallback> spc;
@@ -191,22 +192,26 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
ss << ops.getKind();
Trace("sygus-abduct-debug")
<< "Add constructor : " << ops << std::endl;
- datatypes.back().addSygusConstructor(
- ops.toExpr(), ss.str(), cargs, spc);
+ sdts.back().addConstructor(ops, ss.str(), cargs, spc);
}
Trace("sygus-abduct-debug")
<< "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl;
- datatypes.back().setSygus(dtc.getSygusType(),
- abvl.toExpr(),
- dtc.getSygusAllowConst(),
- dtc.getSygusAllowAll());
+ TypeNode stn = TypeNode::fromType(dtc.getSygusType());
+ sdts.back().initializeDatatype(
+ stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll());
}
dtToProcess.clear();
dtToProcess.insert(
dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end());
}
Trace("sygus-abduct-debug")
- << "Make " << datatypes.size() << " datatype types..." << std::endl;
+ << "Make " << sdts.size() << " datatype types..." << std::endl;
+ // extract the datatypes
+ std::vector<Datatype> datatypes;
+ for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
+ {
+ datatypes.push_back(sdts[i].getDatatype());
+ }
// make the datatype types
std::vector<DatatypeType> datatypeTypes =
nm->toExprManager()->mkMutualDatatypeTypes(
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback