summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/expr/node.h3
-rw-r--r--src/parser/smt2/Smt2.g176
-rw-r--r--src/printer/smt2/smt2_printer.cpp27
-rw-r--r--src/theory/datatypes/datatypes_rewriter.cpp90
-rw-r--r--src/theory/datatypes/kinds19
-rw-r--r--src/theory/datatypes/theory_datatypes_type_rules.h160
6 files changed, 376 insertions, 99 deletions
diff --git a/src/expr/node.h b/src/expr/node.h
index f0ee7a56c..b8a665f0c 100644
--- a/src/expr/node.h
+++ b/src/expr/node.h
@@ -468,7 +468,8 @@ public:
inline bool isClosure() const {
assertTNodeNotExpired();
return getKind() == kind::LAMBDA || getKind() == kind::FORALL
- || getKind() == kind::EXISTS || getKind() == kind::CHOICE;
+ || getKind() == kind::EXISTS || getKind() == kind::CHOICE
+ || getKind() == kind::MATCH_BIND_CASE;
}
/**
diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g
index 21e09317d..9a8232df9 100644
--- a/src/parser/smt2/Smt2.g
+++ b/src/parser/smt2/Smt2.g
@@ -1810,15 +1810,14 @@ termNonVariable[CVC4::Expr& expr, CVC4::Expr& expr2]
std::string attr;
Expr attexpr;
std::vector<Expr> patexprs;
- std::vector<Expr> patconds;
+ std::vector<Expr> matchcases;
std::unordered_set<std::string> names;
std::vector< std::pair<std::string, Expr> > binders;
- int match_vindex = -1;
- std::vector<Type> match_ptypes;
Type type;
Type type2;
api::Term atomTerm;
ParseOp p;
+ std::vector<Type> argTypes;
}
: LPAREN_TOK quantOp[kind]
LPAREN_TOK sortedVarList[sortedVarNames] RPAREN_TOK
@@ -1912,107 +1911,92 @@ termNonVariable[CVC4::Expr& expr, CVC4::Expr& expr2]
}
LPAREN_TOK
(
- /* match cases */
- LPAREN_TOK INDEX_TOK term[f, f2] {
- if( match_vindex==-1 ){
- match_vindex = (int)patexprs.size();
+ // case with non-nullary pattern
+ LPAREN_TOK LPAREN_TOK term[f, f2] {
+ args.clear();
+ PARSER_STATE->pushScope(true);
+ // f should be a constructor
+ type = f.getType();
+ Debug("parser-dt") << "Pattern head : " << f << " " << type << std::endl;
+ if (!type.isConstructor())
+ {
+ PARSER_STATE->parseError("Pattern must be application of a constructor or a variable.");
+ }
+ if (Datatype::datatypeOf(f).isParametric())
+ {
+ type = Datatype::datatypeOf(f)[Datatype::indexOf(f)].getSpecializedConstructorType(expr.getType());
+ }
+ argTypes = static_cast<ConstructorType>(type).getArgTypes();
+ }
+ // arguments of the pattern
+ ( symbol[name,CHECK_NONE,SYM_VARIABLE] {
+ if (args.size() >= argTypes.size())
+ {
+ PARSER_STATE->parseError("Too many arguments for pattern.");
+ }
+ //make of proper type
+ Expr arg = PARSER_STATE->mkBoundVar(name, argTypes[args.size()]);
+ args.push_back( arg );
}
- patexprs.push_back( f );
- patconds.push_back(MK_CONST(bool(true)));
+ )*
+ RPAREN_TOK term[f3, f2] {
+ // make the match case
+ std::vector<Expr> cargs;
+ cargs.push_back(f);
+ cargs.insert(cargs.end(),args.begin(),args.end());
+ Expr c = MK_EXPR(kind::APPLY_CONSTRUCTOR,cargs);
+ Expr bvl = MK_EXPR(kind::BOUND_VAR_LIST,args);
+ Expr mc = MK_EXPR(kind::MATCH_BIND_CASE, bvl, c, f3);
+ matchcases.push_back(mc);
+ // now, pop the scope
+ PARSER_STATE->popScope();
}
RPAREN_TOK
- | LPAREN_TOK LPAREN_TOK term[f, f2] {
- args.clear();
- PARSER_STATE->pushScope(true);
- //f should be a constructor
- type = f.getType();
- Debug("parser-dt") << "Pattern head : " << f << " " << f.getType() << std::endl;
- if( !type.isConstructor() ){
- PARSER_STATE->parseError("Pattern must be application of a constructor or a variable.");
- }
- if( Datatype::datatypeOf(f).isParametric() ){
- type = Datatype::datatypeOf(f)[Datatype::indexOf(f)].getSpecializedConstructorType(expr.getType());
- }
- match_ptypes = ((ConstructorType)type).getArgTypes();
- }
- //arguments
- ( symbol[name,CHECK_NONE,SYM_VARIABLE] {
- if( args.size()>=match_ptypes.size() ){
- PARSER_STATE->parseError("Too many arguments for pattern.");
- }
- //make of proper type
- Expr arg = PARSER_STATE->mkBoundVar(name, match_ptypes[args.size()]);
- args.push_back( arg );
- }
- )*
- RPAREN_TOK
- term[f3, f2] {
- const DatatypeConstructor& dtc = Datatype::datatypeOf(f)[Datatype::indexOf(f)];
- if( args.size()!=dtc.getNumArgs() ){
- PARSER_STATE->parseError("Bad number of arguments for application of constructor in pattern.");
- }
- //FIXME: make MATCH a kind and make this a rewrite
- // build a lambda
- std::vector<Expr> largs;
- largs.push_back( MK_EXPR( CVC4::kind::BOUND_VAR_LIST, args ) );
- largs.push_back( f3 );
- std::vector< Expr > aargs;
- aargs.push_back( MK_EXPR( CVC4::kind::LAMBDA, largs ) );
- for( unsigned i=0; i<dtc.getNumArgs(); i++ ){
- //can apply total version since we will be guarded by ITE condition
- // however, we need to apply partial version since we don't have the internal selector available
- aargs.push_back( MK_EXPR( CVC4::kind::APPLY_SELECTOR, dtc[i].getSelector(), expr ) );
- }
- patexprs.push_back( MK_EXPR( CVC4::kind::APPLY_UF, aargs ) );
- patconds.push_back( MK_EXPR( CVC4::kind::APPLY_TESTER, dtc.getTester(), expr ) );
- }
- RPAREN_TOK
- { PARSER_STATE->popScope(); }
- | LPAREN_TOK symbol[name,CHECK_DECLARED,SYM_VARIABLE] {
- f = PARSER_STATE->getVariable(name);
- type = f.getType();
- if( !type.isConstructor() || !((ConstructorType)type).getArgTypes().empty() ){
- PARSER_STATE->parseError("Must apply constructors of arity greater than 0 to arguments in pattern.");
- }
- }
- term[f3, f2] {
- const DatatypeConstructor& dtc = Datatype::datatypeOf(f)[Datatype::indexOf(f)];
- patexprs.push_back( f3 );
- patconds.push_back( MK_EXPR( CVC4::kind::APPLY_TESTER, dtc.getTester(), expr ) );
- }
- RPAREN_TOK
- )+
- RPAREN_TOK RPAREN_TOK {
- if( match_vindex==-1 ){
- const Datatype& dt = ((DatatypeType)expr.getType()).getDatatype();
- std::map< unsigned, bool > processed;
- unsigned count = 0;
- //ensure that all datatype constructors are matched (to ensure exhaustiveness)
- for( unsigned i=0; i<patconds.size(); i++ ){
- unsigned curr_index = Datatype::indexOf(patconds[i].getOperator());
- if( curr_index<0 && curr_index>=dt.getNumConstructors() ){
- PARSER_STATE->parseError("Pattern is not legal for the head of a match.");
+ // case with nullary or variable pattern
+ | LPAREN_TOK symbol[name,CHECK_NONE,SYM_VARIABLE] {
+ if (PARSER_STATE->isDeclared(name,SYM_VARIABLE))
+ {
+ f = PARSER_STATE->getVariable(name);
+ type = f.getType();
+ if (!type.isConstructor() ||
+ !((ConstructorType)type).getArgTypes().empty())
+ {
+ PARSER_STATE->parseError("Must apply constructors of arity greater than 0 to arguments in pattern.");
+ }
+ // make nullary constructor application
+ f = MK_EXPR(kind::APPLY_CONSTRUCTOR, f);
}
- if( processed.find( curr_index )==processed.end() ){
- processed[curr_index] = true;
- count++;
+ else
+ {
+ // it has the type of the head expr
+ f = PARSER_STATE->mkBoundVar(name, expr.getType());
}
}
- if( count!=dt.getNumConstructors() ){
- PARSER_STATE->parseError("Patterns are not exhaustive in a match construct.");
- }
- }
- //now, make the ITE
- int end_index = match_vindex==-1 ? patexprs.size()-1 : match_vindex;
- bool first_time = true;
- for( int index = end_index; index>=0; index-- ){
- if( first_time ){
- expr = patexprs[index];
- first_time = false;
- }else{
- expr = MK_EXPR( CVC4::kind::ITE, patconds[index], patexprs[index], expr );
+ term[f3, f2] {
+ Expr mc;
+ if (f.getKind() == kind::BOUND_VARIABLE)
+ {
+ Expr bvl = MK_EXPR(kind::BOUND_VAR_LIST, f);
+ mc = MK_EXPR(kind::MATCH_BIND_CASE, bvl, f, f3);
+ }
+ else
+ {
+ mc = MK_EXPR(kind::MATCH_CASE, f, f3);
+ }
+ matchcases.push_back(mc);
}
+ RPAREN_TOK
+ )+
+ RPAREN_TOK RPAREN_TOK {
+ //now, make the match
+ if (matchcases.empty())
+ {
+ PARSER_STATE->parseError("Must have at least one case in match.");
}
+ std::vector<Expr> mchildren;
+ mchildren.push_back(expr);
+ mchildren.insert(mchildren.end(), matchcases.begin(), matchcases.end());
+ expr = MK_EXPR(kind::MATCH, mchildren);
}
/* attributed expressions */
diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp
index df9bee981..013288880 100644
--- a/src/printer/smt2/smt2_printer.cpp
+++ b/src/printer/smt2/smt2_printer.cpp
@@ -540,7 +540,31 @@ void Smt2Printer::toStream(std::ostream& out,
}
return;
- case kind::LAMBDA:
+ case kind::LAMBDA: out << smtKindString(k, d_variant) << " "; break;
+ case kind::MATCH:
+ out << smtKindString(k, d_variant) << " ";
+ toStream(out, n[0], toDepth, types, TypeNode::null());
+ out << " (";
+ for (size_t i = 1, nchild = n.getNumChildren(); i < nchild; i++)
+ {
+ if (i > 1)
+ {
+ out << " ";
+ }
+ toStream(out, n[i], toDepth, types, TypeNode::null());
+ }
+ out << "))";
+ return;
+ case kind::MATCH_BIND_CASE:
+ // ignore the binder
+ toStream(out, n[1], toDepth, types, TypeNode::null());
+ out << " ";
+ toStream(out, n[2], toDepth, types, TypeNode::null());
+ out << ")";
+ return;
+ case kind::MATCH_CASE:
+ // do nothing
+ break;
case kind::CHOICE: out << smtKindString(k, d_variant) << " "; break;
// arith theory
@@ -1030,6 +1054,7 @@ static string smtKindString(Kind k, Variant v)
case kind::LAMBDA:
return "lambda";
+ case kind::MATCH: return "match";
case kind::CHOICE: return "choice";
// arith theory
diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp
index ac3bff21b..802dedcbd 100644
--- a/src/theory/datatypes/datatypes_rewriter.cpp
+++ b/src/theory/datatypes/datatypes_rewriter.cpp
@@ -148,6 +148,96 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
return RewriteResponse(REWRITE_AGAIN_FULL, ret);
}
}
+ else if (k == MATCH)
+ {
+ Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl;
+ Node h = in[0];
+ std::vector<Node> cases;
+ std::vector<Node> rets;
+ TypeNode t = h.getType();
+ const Datatype& dt = t.getDatatype();
+ for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++)
+ {
+ Node c = in[k];
+ Node cons;
+ Kind ck = c.getKind();
+ if (ck == MATCH_CASE)
+ {
+ Assert(c[0].getKind() == APPLY_CONSTRUCTOR);
+ cons = c[0].getOperator();
+ }
+ else if (ck == MATCH_BIND_CASE)
+ {
+ if (c[1].getKind() == APPLY_CONSTRUCTOR)
+ {
+ cons = c[1].getOperator();
+ }
+ }
+ else
+ {
+ AlwaysAssert(false);
+ }
+ size_t cindex = 0;
+ // cons is null in the default case
+ if (!cons.isNull())
+ {
+ cindex = Datatype::indexOf(cons.toExpr());
+ }
+ Node body;
+ if (ck == MATCH_CASE)
+ {
+ body = c[1];
+ }
+ else if (ck == MATCH_BIND_CASE)
+ {
+ std::vector<Node> vars;
+ std::vector<Node> subs;
+ if (cons.isNull())
+ {
+ Assert(c[1].getKind() == BOUND_VARIABLE);
+ vars.push_back(c[1]);
+ subs.push_back(h);
+ }
+ else
+ {
+ for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++)
+ {
+ vars.push_back(c[0][i]);
+ Node sc = nm->mkNode(
+ APPLY_SELECTOR_TOTAL,
+ Node::fromExpr(dt[cindex].getSelectorInternal(t.toType(), i)),
+ h);
+ subs.push_back(sc);
+ }
+ }
+ body =
+ c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+ }
+ if (!cons.isNull())
+ {
+ cases.push_back(mkTester(h, cindex, dt));
+ }
+ else
+ {
+ // variables have no constraints
+ cases.push_back(nm->mkConst(true));
+ }
+ rets.push_back(body);
+ }
+ Assert(!cases.empty());
+ // now make the ITE
+ std::reverse(cases.begin(), cases.end());
+ std::reverse(rets.begin(), rets.end());
+ Node ret = rets[0];
+ AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors());
+ for (unsigned i = 1, ncases = cases.size(); i < ncases; i++)
+ {
+ ret = nm->mkNode(ITE, cases[i], rets[i], ret);
+ }
+ Trace("dt-rewrite-match")
+ << "Rewrite match: " << in << " ... " << ret << std::endl;
+ return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+ }
if (k == kind::EQUAL)
{
diff --git a/src/theory/datatypes/kinds b/src/theory/datatypes/kinds
index a0b00bcb0..22d13da0c 100644
--- a/src/theory/datatypes/kinds
+++ b/src/theory/datatypes/kinds
@@ -119,4 +119,23 @@ typerule DT_SYGUS_BOUND ::CVC4::theory::datatypes::DtSygusBoundTypeRule
operator DT_SYGUS_EVAL 1: "datatypes sygus evaluation function"
typerule DT_SYGUS_EVAL ::CVC4::theory::datatypes::DtSyguEvalTypeRule
+
+# Kinds for match terms. For example, the match term
+# (match l (((cons h t) h) (nil 0)))
+# is represented by the AST
+# (MATCH l
+# (MATCH_BIND_CASE (BOUND_VAR_LIST h t) (cons h t) h)
+# (MATCH_CASE nil 0)
+# )
+# where notice that patterns with free variables use MATCH_BIND_CASE whereas
+# patterns with no free variables use MATCH_CASE.
+
+operator MATCH 2: "match construct"
+operator MATCH_CASE 2 "a match case"
+operator MATCH_BIND_CASE 3 "a match case with bound variables"
+
+typerule MATCH ::CVC4::theory::datatypes::MatchTypeRule
+typerule MATCH_CASE ::CVC4::theory::datatypes::MatchCaseTypeRule
+typerule MATCH_BIND_CASE ::CVC4::theory::datatypes::MatchBindCaseTypeRule
+
endtheory
diff --git a/src/theory/datatypes/theory_datatypes_type_rules.h b/src/theory/datatypes/theory_datatypes_type_rules.h
index 22ac074f0..c8c16f368 100644
--- a/src/theory/datatypes/theory_datatypes_type_rules.h
+++ b/src/theory/datatypes/theory_datatypes_type_rules.h
@@ -427,7 +427,165 @@ class DtSyguEvalTypeRule
}
return TypeNode::fromType(dt.getSygusType());
}
-}; /* class DtSygusBoundTypeRule */
+}; /* class DtSyguEvalTypeRule */
+
+class MatchTypeRule
+{
+ public:
+ static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
+ {
+ Assert(n.getKind() == kind::MATCH);
+
+ TypeNode retType;
+
+ TypeNode headType = n[0].getType(check);
+ if (!headType.isDatatype())
+ {
+ throw TypeCheckingExceptionPrivate(n, "expecting datatype head in match");
+ }
+ const Datatype& hdt = headType.getDatatype();
+
+ std::unordered_set<unsigned> patIndices;
+ bool patHasVariable = false;
+ // the type of a match case list is the least common type of its cases
+ for (unsigned i = 1, nchildren = n.getNumChildren(); i < nchildren; i++)
+ {
+ Node nc = n[i];
+ if (check)
+ {
+ Kind nck = nc.getKind();
+ std::unordered_set<Node, NodeHashFunction> bvs;
+ if (nck == kind::MATCH_BIND_CASE)
+ {
+ for (const Node& v : nc[0])
+ {
+ Assert(v.getKind() == kind::BOUND_VARIABLE);
+ bvs.insert(v);
+ }
+ }
+ else if (nck != kind::MATCH_CASE)
+ {
+ throw TypeCheckingExceptionPrivate(
+ n, "expected a match case in match expression");
+ }
+ // get the pattern type
+ unsigned pindex = nck == kind::MATCH_CASE ? 0 : 1;
+ TypeNode patType = nc[pindex].getType();
+ // should be caught in the above call
+ if (!patType.isDatatype())
+ {
+ throw TypeCheckingExceptionPrivate(
+ n, "expecting datatype pattern in match");
+ }
+ Kind ncpk = nc[pindex].getKind();
+ if (ncpk == kind::APPLY_CONSTRUCTOR)
+ {
+ for (const Node& arg : nc[pindex])
+ {
+ if (bvs.find(arg) == bvs.end())
+ {
+ throw TypeCheckingExceptionPrivate(
+ n,
+ "expecting distinct bound variable as argument to "
+ "constructor in pattern of match");
+ }
+ bvs.erase(arg);
+ }
+ unsigned ci = Datatype::indexOf(nc[pindex].getOperator().toExpr());
+ patIndices.insert(ci);
+ }
+ else if (ncpk == kind::BOUND_VARIABLE)
+ {
+ patHasVariable = true;
+ }
+ else
+ {
+ throw TypeCheckingExceptionPrivate(
+ n, "unexpected kind of term in pattern in match");
+ }
+ const Datatype& pdt = patType.getDatatype();
+ // compare datatypes instead of the types to catch parametric case,
+ // where the pattern has parametric type.
+ if (hdt != pdt)
+ {
+ std::stringstream ss;
+ ss << "pattern of a match case does not match the head type in match";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ }
+ TypeNode currType = nc.getType(check);
+ if (i == 1)
+ {
+ retType = currType;
+ }
+ else
+ {
+ retType = TypeNode::leastCommonTypeNode(retType, currType);
+ if (retType.isNull())
+ {
+ throw TypeCheckingExceptionPrivate(
+ n, "incomparable types in match case list");
+ }
+ }
+ }
+ if (check)
+ {
+ if (!patHasVariable && patIndices.size() < hdt.getNumConstructors())
+ {
+ throw TypeCheckingExceptionPrivate(
+ n, "cases for match term are not exhaustive");
+ }
+ }
+ return retType;
+ }
+}; /* class MatchTypeRule */
+
+class MatchCaseTypeRule
+{
+ public:
+ inline static TypeNode computeType(NodeManager* nodeManager,
+ TNode n,
+ bool check)
+ {
+ Assert(n.getKind() == kind::MATCH_CASE);
+ if (check)
+ {
+ TypeNode patType = n[0].getType(check);
+ if (!patType.isDatatype())
+ {
+ throw TypeCheckingExceptionPrivate(
+ n, "expecting datatype pattern in match case");
+ }
+ }
+ return n[1].getType(check);
+ }
+}; /* class MatchCaseTypeRule */
+
+class MatchBindCaseTypeRule
+{
+ public:
+ inline static TypeNode computeType(NodeManager* nodeManager,
+ TNode n,
+ bool check)
+ {
+ Assert(n.getKind() == kind::MATCH_BIND_CASE);
+ if (check)
+ {
+ if (n[0].getKind() != kind::BOUND_VAR_LIST)
+ {
+ throw TypeCheckingExceptionPrivate(
+ n, "expected a bound variable list in match bind case");
+ }
+ TypeNode patType = n[1].getType(check);
+ if (!patType.isDatatype())
+ {
+ throw TypeCheckingExceptionPrivate(
+ n, "expecting datatype pattern in match bind case");
+ }
+ }
+ return n[2].getType(check);
+ }
+}; /* class MatchBindCaseTypeRule */
} /* CVC4::theory::datatypes namespace */
} /* CVC4::theory namespace */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback