diff options
Diffstat (limited to 'src/theory/quantifiers/quantifiers_rewriter.cpp')
-rw-r--r-- | src/theory/quantifiers/quantifiers_rewriter.cpp | 48 |
1 files changed, 25 insertions, 23 deletions
diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp index 0039ec845..8d65523e1 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.cpp +++ b/src/theory/quantifiers/quantifiers_rewriter.cpp @@ -14,9 +14,11 @@ #include "theory/quantifiers/quantifiers_rewriter.h" +#include "expr/dtype.h" #include "expr/node_algorithm.h" #include "options/quantifiers_options.h" #include "theory/arith/arith_msum.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/bv_inverter.h" #include "theory/quantifiers/ematching/trigger.h" #include "theory/quantifiers/quantifiers_attributes.h" @@ -308,8 +310,8 @@ void QuantifiersRewriter::computeDtTesterIteSplit( Node n, std::map< Node, Node Trace("quantifiers-rewrite-ite-debug") << "...condition already set " << itp->second << std::endl; computeDtTesterIteSplit( n[ itp->second==n[0] ? 1 : 2 ], pcons, ncons, conj ); }else{ - Expr testerExpr = n[0].getOperator().toExpr(); - int index = Datatype::indexOf( testerExpr ); + Node tester = n[0].getOperator(); + int index = datatypes::utils::indexOf(tester); std::map< int, Node >::iterator itn = ncons[x].find( index ); if( itn!=ncons[x].end() ){ Trace("quantifiers-rewrite-ite-debug") << "...condition negated " << itn->second << std::endl; @@ -328,6 +330,7 @@ void QuantifiersRewriter::computeDtTesterIteSplit( Node n, std::map< Node, Node } } }else{ + NodeManager* nm = NodeManager::currentNM(); Trace("quantifiers-rewrite-ite-debug") << "Return value : " << n << std::endl; std::vector< Node > children; children.push_back( n ); @@ -343,7 +346,7 @@ void QuantifiersRewriter::computeDtTesterIteSplit( Node n, std::map< Node, Node //only if we haven't settled on a positive tester if( std::find( vars.begin(), vars.end(), x )==vars.end() ){ //check if we have exhausted all options but one - const Datatype& dt = DatatypeType(x.getType().toType()).getDatatype(); + const DType& dt = x.getType().getDType(); std::vector< Node > nchildren; int pos_cons = -1; for( int i=0; i<(int)dt.getNumConstructors(); i++ ){ @@ -355,9 +358,8 @@ void QuantifiersRewriter::computeDtTesterIteSplit( Node n, std::map< Node, Node } } if( pos_cons>=0 ){ - const DatatypeConstructor& c = dt[pos_cons]; - Expr tester = c.getTester(); - children.push_back( NodeManager::currentNM()->mkNode( kind::APPLY_TESTER, Node::fromExpr( tester ), x ).negate() ); + Node tester = dt[pos_cons].getTester(); + children.push_back(nm->mkNode(APPLY_TESTER, tester, x).negate()); }else{ children.insert( children.end(), nchildren.begin(), nchildren.end() ); } @@ -454,20 +456,21 @@ void setEntailedCond( Node n, bool pol, std::map< Node, bool >& currCond, std::v } if( addEntailedCond( n, pol, currCond, new_cond, conflict ) ){ if( n.getKind()==APPLY_TESTER ){ - const Datatype& dt = Datatype::datatypeOf(n.getOperator().toExpr()); - unsigned index = Datatype::indexOf(n.getOperator().toExpr()); + NodeManager* nm = NodeManager::currentNM(); + const DType& dt = datatypes::utils::datatypeOf(n.getOperator()); + unsigned index = datatypes::utils::indexOf(n.getOperator()); Assert(dt.getNumConstructors() > 1); if( pol ){ for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ if( i!=index ){ - Node t = NodeManager::currentNM()->mkNode( APPLY_TESTER, Node::fromExpr( dt[i].getTester() ), n[0] ); + Node t = nm->mkNode(APPLY_TESTER, dt[i].getTester(), n[0]); addEntailedCond( t, false, currCond, new_cond, conflict ); } } }else{ if( dt.getNumConstructors()==2 ){ int oindex = 1-index; - Node t = NodeManager::currentNM()->mkNode( APPLY_TESTER, Node::fromExpr( dt[oindex].getTester() ), n[0] ); + Node t = nm->mkNode(APPLY_TESTER, dt[oindex].getTester(), n[0]); addEntailedCond( t, true, currCond, new_cond, conflict ); } } @@ -1011,16 +1014,16 @@ bool QuantifiersRewriter::getVarElimLit(Node lit, if (ita != args.end()) { vars.push_back(lit[0]); - Expr testerExpr = lit.getOperator().toExpr(); - int index = Datatype::indexOf(testerExpr); - const Datatype& dt = Datatype::datatypeOf(testerExpr); - const DatatypeConstructor& c = dt[index]; + Node tester = lit.getOperator(); + int index = datatypes::utils::indexOf(tester); + const DType& dt = datatypes::utils::datatypeOf(tester); + const DTypeConstructor& c = dt[index]; std::vector<Node> newChildren; - newChildren.push_back(Node::fromExpr(c.getConstructor())); + newChildren.push_back(c.getConstructor()); std::vector<Node> newVars; for (unsigned j = 0, nargs = c.getNumArgs(); j < nargs; j++) { - TypeNode tn = TypeNode::fromType(c[j].getRangeType()); + TypeNode tn = c[j].getRangeType(); Node v = nm->mkBoundVar(tn); newChildren.push_back(v); newVars.push_back(v); @@ -1081,8 +1084,8 @@ bool QuantifiersRewriter::getVarElimLit(Node lit, { Trace("var-elim-dt") << "Expand datatype variable based on : " << lit << std::endl; - Expr testerExpr = lit.getOperator().toExpr(); - unsigned index = Datatype::indexOf(testerExpr); + Node tester = lit.getOperator(); + unsigned index = datatypes::utils::indexOf(tester); Node s = datatypeExpand(index, lit[0], args); if (!s.isNull()) { @@ -1179,16 +1182,15 @@ Node QuantifiersRewriter::datatypeExpand(unsigned index, { return Node::null(); } - const Datatype& dt = - static_cast<DatatypeType>(v.getType().toType()).getDatatype(); + const DType& dt = v.getType().getDType(); Assert(index < dt.getNumConstructors()); - const DatatypeConstructor& c = dt[index]; + const DTypeConstructor& c = dt[index]; std::vector<Node> newChildren; - newChildren.push_back(Node::fromExpr(c.getConstructor())); + newChildren.push_back(c.getConstructor()); std::vector<Node> newVars; for (unsigned j = 0, nargs = c.getNumArgs(); j < nargs; j++) { - TypeNode tn = TypeNode::fromType(c.getArgType(j)); + TypeNode tn = c.getArgType(j); Node vn = NodeManager::currentNM()->mkBoundVar(tn); newChildren.push_back(vn); newVars.push_back(vn); |