From c0a7095f13547ac0c0d4c92670000ca875b7c349 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 13 Dec 2019 11:07:16 -0600 Subject: Eliminate Expr-level calls in TypeNode (#3562) --- src/expr/type.cpp | 23 ++++++-- src/expr/type_node.cpp | 69 +++++++++++----------- src/expr/type_node.h | 30 ---------- src/preprocessing/passes/synth_rew_rules.cpp | 2 +- src/printer/cvc/cvc_printer.cpp | 19 ++++-- src/theory/datatypes/kinds | 12 ++-- src/theory/datatypes/theory_datatypes.cpp | 5 +- src/theory/datatypes/theory_datatypes_type_rules.h | 6 +- src/theory/datatypes/theory_datatypes_utils.cpp | 4 +- .../quantifiers/sygus/sygus_grammar_cons.cpp | 32 +++++----- .../quantifiers/sygus/sygus_grammar_norm.cpp | 2 +- 11 files changed, 101 insertions(+), 103 deletions(-) diff --git a/src/expr/type.cpp b/src/expr/type.cpp index 31f21667a..99fe73c22 100644 --- a/src/expr/type.cpp +++ b/src/expr/type.cpp @@ -98,7 +98,12 @@ bool Type::isFunctionLike() const Expr Type::mkGroundTerm() const { NodeManagerScope nms(d_nodeManager); - return d_typeNode->mkGroundTerm().toExpr(); + Expr ret = d_typeNode->mkGroundTerm().toExpr(); + if (ret.isNull()) + { + IllegalArgument(this, "Cannot construct ground term!"); + } + return ret; } Expr Type::mkGroundValue() const @@ -326,7 +331,8 @@ bool Type::isTuple() const { /** Is this a record type? */ bool Type::isRecord() const { NodeManagerScope nms(d_nodeManager); - return d_typeNode->isRecord(); + return d_typeNode->getKind() == kind::DATATYPE_TYPE + && DatatypeType(*this).getDatatype().isRecord(); } /** Is this a symbolic expression type? */ @@ -566,7 +572,14 @@ std::vector ConstructorType::getArgTypes() const { const Datatype& DatatypeType::getDatatype() const { NodeManagerScope nms(d_nodeManager); - return d_typeNode->getDatatype(); + Assert(isDatatype()); + if (d_typeNode->getKind() == kind::DATATYPE_TYPE) + { + DatatypeIndexConstant dic = d_typeNode->getConst(); + return d_nodeManager->getDatatypeForIndex(dic.getIndex()); + } + Assert(d_typeNode->getKind() == kind::PARAMETRIC_DATATYPE); + return DatatypeType((*d_typeNode)[0].toType()).getDatatype(); } bool DatatypeType::isParametric() const { @@ -636,7 +649,9 @@ std::vector DatatypeType::getTupleTypes() const { /** Get the description of the record type */ const Record& DatatypeType::getRecord() const { NodeManagerScope nms(d_nodeManager); - return d_typeNode->getRecord(); + Assert(isRecord()); + const Datatype& dt = getDatatype(); + return *(dt.getRecord()); } DatatypeType SelectorType::getDomain() const { diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index 8cc10b5b2..abca1e3ed 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -141,9 +141,8 @@ bool TypeNode::isFiniteInternal(bool usortFinite) if (isDatatype()) { TypeNode tn = *this; - const Datatype& dt = getDatatype(); - ret = usortFinite ? dt.isInterpretedFinite(tn.toType()) - : dt.isFinite(tn.toType()); + const DType& dt = getDType(); + ret = usortFinite ? dt.isInterpretedFinite(tn) : dt.isFinite(tn); } else if (isArray()) { @@ -250,12 +249,12 @@ bool TypeNode::isClosedEnumerable() setAttribute(IsClosedEnumerableAttr(), ret); setAttribute(IsClosedEnumerableComputedAttr(), true); TypeNode tn = *this; - const Datatype& dt = getDatatype(); + const DType& dt = getDType(); for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) { for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) { - TypeNode ctn = TypeNode::fromType(dt[i][j].getRangeType()); + TypeNode ctn = dt[i][j].getRangeType(); if (tn != ctn && !ctn.isClosedEnumerable()) { ret = false; @@ -351,12 +350,11 @@ TypeNode TypeNode::getBaseType() const { if (isSubtypeOf(realt)) { return realt; } else if (isParametricDatatype()) { - vector v; + std::vector v; for(size_t i = 1; i < getNumChildren(); ++i) { - v.push_back((*this)[i].getBaseType().toType()); + v.push_back((*this)[i].getBaseType()); } - TypeNode tn = TypeNode::fromType((*this)[0].getDatatype().getDatatypeType(v)); - return tn; + return (*this)[0].getDType().getTypeNode().instantiateParametricDatatype(v); } return *this; } @@ -387,39 +385,27 @@ std::vector TypeNode::getParamTypes() const { /** Is this a tuple type? */ bool TypeNode::isTuple() const { - return ( getKind() == kind::DATATYPE_TYPE && getDatatype().isTuple() ); -} - -/** Is this a record type? */ -bool TypeNode::isRecord() const { - return ( getKind() == kind::DATATYPE_TYPE && getDatatype().isRecord() ); + return (getKind() == kind::DATATYPE_TYPE && getDType().isTuple()); } size_t TypeNode::getTupleLength() const { Assert(isTuple()); - const Datatype& dt = getDatatype(); + const DType& dt = getDType(); Assert(dt.getNumConstructors() == 1); return dt[0].getNumArgs(); } vector TypeNode::getTupleTypes() const { Assert(isTuple()); - const Datatype& dt = getDatatype(); + const DType& dt = getDType(); Assert(dt.getNumConstructors() == 1); vector types; for(unsigned i = 0; i < dt[0].getNumArgs(); ++i) { - types.push_back(TypeNode::fromType(dt[0][i].getRangeType())); + types.push_back(dt[0][i].getRangeType()); } return types; } -const Record& TypeNode::getRecord() const { - Assert(isRecord()); - const Datatype & dt = getDatatype(); - return *(dt.getRecord()); - //return getAttribute(expr::DatatypeRecordAttr()).getConst(); -} - vector TypeNode::getSExprTypes() const { Assert(isSExpr()); vector types; @@ -437,11 +423,12 @@ bool TypeNode::isInstantiatedDatatype() const { if(getKind() != kind::PARAMETRIC_DATATYPE) { return false; } - const Datatype& dt = (*this)[0].getDatatype(); + const DType& dt = (*this)[0].getDType(); unsigned n = dt.getNumParameters(); Assert(n < getNumChildren()); for(unsigned i = 0; i < n; ++i) { - if(TypeNode::fromType(dt.getParameter(i)) == (*this)[i + 1]) { + if (dt.getParameter(i) == (*this)[i + 1]) + { return false; } } @@ -473,9 +460,9 @@ TypeNode TypeNode::instantiateSortConstructor( /** Is this an instantiated datatype parameter */ bool TypeNode::isParameterInstantiatedDatatype(unsigned n) const { AssertArgument(getKind() == kind::PARAMETRIC_DATATYPE, *this); - const Datatype& dt = (*this)[0].getDatatype(); + const DType& dt = (*this)[0].getDType(); AssertArgument(n < dt.getNumParameters(), *this); - return TypeNode::fromType(dt.getParameter(n)) != (*this)[n + 1]; + return dt.getParameter(n) != (*this)[n + 1]; } TypeNode TypeNode::leastCommonTypeNode(TypeNode t0, TypeNode t1){ @@ -601,13 +588,15 @@ Node TypeNode::getEnsureTypeCondition( Node n, TypeNode tn ) { } }else if( tn.isDatatype() && ntn.isDatatype() ){ if( tn.isTuple() && ntn.isTuple() ){ - const Datatype& dt1 = tn.getDatatype(); - const Datatype& dt2 = ntn.getDatatype(); + const DType& dt1 = tn.getDType(); + const DType& dt2 = ntn.getDType(); + NodeManager* nm = NodeManager::currentNM(); if( dt1[0].getNumArgs()==dt2[0].getNumArgs() ){ std::vector< Node > conds; for( unsigned i=0; imkNode( kind::APPLY_SELECTOR_TOTAL, Node::fromExpr( dt2[0][i].getSelector() ), n ); - Node etc = getEnsureTypeCondition( s, TypeNode::fromType( dt1[0][i].getRangeType() ) ); + Node s = nm->mkNode( + kind::APPLY_SELECTOR_TOTAL, dt2[0][i].getSelector(), n); + Node etc = getEnsureTypeCondition(s, dt1[0][i].getRangeType()); if( etc.isNull() ){ return Node::null(); }else{ @@ -615,11 +604,11 @@ Node TypeNode::getEnsureTypeCondition( Node n, TypeNode tn ) { } } if( conds.empty() ){ - return NodeManager::currentNM()->mkConst( true ); + return nm->mkConst(true); }else if( conds.size()==1 ){ return conds[0]; }else{ - return NodeManager::currentNM()->mkNode( kind::AND, conds ); + return nm->mkNode(kind::AND, conds); } } } @@ -640,6 +629,16 @@ bool TypeNode::isSortConstructor() const { return getKind() == kind::SORT_TYPE && hasAttribute(expr::SortArityAttr()); } +/** Is this a codatatype type */ +bool TypeNode::isCodatatype() const +{ + if (isDatatype()) + { + return getDType().isCodatatype(); + } + return false; +} + std::string TypeNode::toString() const { std::stringstream ss; OutputLanguage outlang = (this == &s_null) ? language::output::LANG_AUTO : options::outputLanguage(); diff --git a/src/expr/type_node.h b/src/expr/type_node.h index b1c4da026..017ffe3dd 100644 --- a/src/expr/type_node.h +++ b/src/expr/type_node.h @@ -597,12 +597,6 @@ public: /** Get the constituent types of a tuple type */ std::vector getTupleTypes() const; - /** Is this a record type? */ - bool isRecord() const; - - /** Get the description of the record type */ - const Record& getRecord() const; - /** Is this a symbolic expression type? */ bool isSExpr() const; @@ -659,9 +653,6 @@ public: /** Is this a tester type */ bool isTester() const; - /** Get the Datatype specification from a datatype type */ - const Datatype& getDatatype() const; - /** Get the internal Datatype specification from a datatype type */ const DType& getDType() const; @@ -1027,15 +1018,6 @@ inline bool TypeNode::isParametricDatatype() const { return getKind() == kind::PARAMETRIC_DATATYPE; } -/** Is this a codatatype type */ -inline bool TypeNode::isCodatatype() const { - if( isDatatype() ){ - return getDatatype().isCodatatype(); - }else{ - return false; - } -} - /** Is this a constructor type */ inline bool TypeNode::isConstructor() const { return getKind() == kind::CONSTRUCTOR_TYPE; @@ -1066,18 +1048,6 @@ inline bool TypeNode::isBitVector(unsigned size) const { ( getKind() == kind::BITVECTOR_TYPE && getConst() == size ); } -/** Get the datatype specification from a datatype type */ -inline const Datatype& TypeNode::getDatatype() const { - Assert(isDatatype()); - if( getKind() == kind::DATATYPE_TYPE ){ - DatatypeIndexConstant dic = getConst(); - return NodeManager::currentNM()->getDatatypeForIndex( dic.getIndex() ); - }else{ - Assert(getKind() == kind::PARAMETRIC_DATATYPE); - return (*this)[0].getDatatype(); - } -} - /** Get the exponent size of this floating-point type */ inline unsigned TypeNode::getFloatingPointExponentSize() const { Assert(isFloatingPoint()); diff --git a/src/preprocessing/passes/synth_rew_rules.cpp b/src/preprocessing/passes/synth_rew_rules.cpp index 47e64b2e4..f3ca65b79 100644 --- a/src/preprocessing/passes/synth_rew_rules.cpp +++ b/src/preprocessing/passes/synth_rew_rules.cpp @@ -417,7 +417,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( 0); Trace("srs-input-debug") << "Grammar for subterm " << n << " is: " << std::endl; - Trace("srs-input-debug") << subtermTypes[n].getDatatype() << std::endl; + Trace("srs-input-debug") << subtermTypes[n].getDType() << std::endl; } // set that this is a sygus datatype sdttl.initializeDatatype(t, sygusVarList, false, false); diff --git a/src/printer/cvc/cvc_printer.cpp b/src/printer/cvc/cvc_printer.cpp index b11ee77a7..22a491af5 100644 --- a/src/printer/cvc/cvc_printer.cpp +++ b/src/printer/cvc/cvc_printer.cpp @@ -373,8 +373,10 @@ void CvcPrinter::toStream( if( n.getNumChildren()==1 ){ out << "TUPLE"; } - }else if( t.isRecord() ){ - const Record& rec = t.getRecord(); + } + else if (t.toType().isRecord()) + { + const Record& rec = static_cast(t.toType()).getRecord(); out << "(# "; TNode::iterator i = n.begin(); bool first = true; @@ -389,7 +391,9 @@ void CvcPrinter::toStream( } out << " #)"; return; - }else{ + } + else + { toStream(op, n.getOperator(), depth, types, false); if (n.getNumChildren() == 0) { @@ -404,7 +408,12 @@ void CvcPrinter::toStream( case kind::APPLY_SELECTOR_TOTAL: { TypeNode t = n[0].getType(); Node opn = n.getOperator(); - if (t.isTuple() || t.isRecord()) + if (!t.isDatatype()) + { + toStream(op, opn, depth, types, false); + } + else if (t.isTuple() + || DatatypeType(t.toType()).isRecord()) { toStream(out, n[0], depth, types, true); out << '.'; @@ -434,7 +443,7 @@ void CvcPrinter::toStream( } break; case kind::APPLY_TESTER: { - Assert(!n.getType().isTuple() && !n.getType().isRecord()); + Assert(!n.getType().isTuple() && !n.getType().toType().isRecord()); op << "is_"; unsigned cindex = Datatype::indexOf(n.getOperator().toExpr()); const Datatype& dt = Datatype::datatypeOf(n.getOperator().toExpr()); diff --git a/src/theory/datatypes/kinds b/src/theory/datatypes/kinds index 22d13da0c..e3c09b635 100644 --- a/src/theory/datatypes/kinds +++ b/src/theory/datatypes/kinds @@ -44,11 +44,11 @@ constant DATATYPE_TYPE \ "expr/datatype.h" \ "a datatype type index" cardinality DATATYPE_TYPE \ - "%TYPE%.getDatatype().getCardinality(%TYPE%.toType())" \ + "%TYPE%.getDType().getCardinality(%TYPE%)" \ "expr/datatype.h" well-founded DATATYPE_TYPE \ - "%TYPE%.getDatatype().isWellFounded()" \ - "%TYPE%.getDatatype().mkGroundTerm(%TYPE%.toType())" \ + "%TYPE%.getDType().isWellFounded()" \ + "%TYPE%.getDType().mkGroundTerm(%TYPE%)" \ "expr/datatype.h" enumerator DATATYPE_TYPE \ @@ -57,11 +57,11 @@ enumerator DATATYPE_TYPE \ operator PARAMETRIC_DATATYPE 1: "parametric datatype" cardinality PARAMETRIC_DATATYPE \ - "DatatypeType(%TYPE%.toType()).getDatatype().getCardinality(%TYPE%.toType())" \ + "%TYPE%.getDType().getCardinality(%TYPE%)" \ "expr/datatype.h" well-founded PARAMETRIC_DATATYPE \ - "DatatypeType(%TYPE%.toType()).getDatatype().isWellFounded()" \ - "DatatypeType(%TYPE%.toType()).getDatatype().mkGroundTerm(%TYPE%.toType())" \ + "%TYPE%.getDType().isWellFounded()" \ + "%TYPE%.getDType().mkGroundTerm(%TYPE%)" \ "expr/datatype.h" enumerator PARAMETRIC_DATATYPE \ diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 5e071c85c..cf07bc0c1 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -631,8 +631,9 @@ Node TheoryDatatypes::expandDefinition(LogicRequest &logicRequest, Node n) { } else { - Assert(t.isRecord()); - const Record& record = t.getRecord(); + Assert(t.toType().isRecord()); + const Record& record = + DatatypeType(t.toType()).getRecord(); size = record.getNumFields(); updateIndex = record.getIndex( n.getOperator().getConst().getField()); diff --git a/src/theory/datatypes/theory_datatypes_type_rules.h b/src/theory/datatypes/theory_datatypes_type_rules.h index 97e67e7fa..e11ac67f1 100644 --- a/src/theory/datatypes/theory_datatypes_type_rules.h +++ b/src/theory/datatypes/theory_datatypes_type_rules.h @@ -287,11 +287,13 @@ struct RecordUpdateTypeRule { TypeNode recordType = n[0].getType(check); TypeNode newValue = n[1].getType(check); if (check) { - if (!recordType.isRecord()) { + if (!recordType.toType().isRecord()) + { throw TypeCheckingExceptionPrivate( n, "Record-update expression formed over non-record"); } - const Record& rec = recordType.getRecord(); + const Record& rec = + DatatypeType(recordType.toType()).getRecord(); if (!rec.contains(ru.getField())) { std::stringstream ss; ss << "Record-update field `" << ru.getField() diff --git a/src/theory/datatypes/theory_datatypes_utils.cpp b/src/theory/datatypes/theory_datatypes_utils.cpp index 2fe8a99fe..cb9ab1e30 100644 --- a/src/theory/datatypes/theory_datatypes_utils.cpp +++ b/src/theory/datatypes/theory_datatypes_utils.cpp @@ -489,7 +489,7 @@ Node sygusToBuiltinEval(Node n, const std::vector& args) if (it == visited.end()) { TypeNode tn = cur.getType(); - if (!tn.isDatatype() || !tn.getDatatype().isSygus()) + if (!tn.isDatatype() || !tn.getDType().isSygus()) { visited[cur] = cur; } @@ -502,7 +502,7 @@ Node sygusToBuiltinEval(Node n, const std::vector& args) { svarsInit = true; TypeNode tn = cur.getType(); - Node varList = Node::fromExpr(tn.getDatatype().getSygusVarList()); + Node varList = tn.getDType().getSygusVarList(); for (const Node& v : varList) { svars.push_back(v); diff --git a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp index 4727ab0b0..8c005bd3c 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp @@ -49,7 +49,7 @@ bool CegGrammarConstructor::hasSyntaxRestrictions(Node q) if (!gv.isNull()) { TypeNode tn = gv.getType(); - if (tn.isDatatype() && tn.getDatatype().isSygus()) + if (tn.isDatatype() && tn.getDType().isSygus()) { return true; } @@ -137,9 +137,9 @@ Node CegGrammarConstructor::process(Node q, std::stringstream ss; ss << sf; Node sfvl; - if (preGrammarType.isDatatype() && preGrammarType.getDatatype().isSygus()) + if (preGrammarType.isDatatype() && preGrammarType.getDType().isSygus()) { - sfvl = Node::fromExpr(preGrammarType.getDatatype().getSygusVarList()); + sfvl = preGrammarType.getDType().getSygusVarList(); tn = preGrammarType; }else{ sfvl = getSygusVarList(sf); @@ -260,7 +260,7 @@ Node CegGrammarConstructor::process(Node q, } tds->registerSygusType(tn); Assert(tn.isDatatype()); - const Datatype& dt = tn.getDatatype(); + const DType& dt = tn.getDType(); Assert(dt.isSygus()); if( !dt.getSygusAllowAll() ){ d_is_syntax_restricted = true; @@ -427,13 +427,13 @@ void CegGrammarConstructor::collectSygusGrammarTypesFor( Trace("sygus-grammar-def") << "...will make grammar for " << range << std::endl; types.push_back( range ); if( range.isDatatype() ){ - const Datatype& dt = range.getDatatype(); + const DType& dt = range.getDType(); for (unsigned i = 0, size = dt.getNumConstructors(); i < size; ++i) { for (unsigned j = 0, size_args = dt[i].getNumArgs(); j < size_args; ++j) { - TypeNode tn = TypeNode::fromType(dt[i][j].getRangeType()); + TypeNode tn = dt[i][j].getRangeType(); collectSygusGrammarTypesFor(tn, types); } } @@ -817,11 +817,11 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( else if (types[i].isDatatype()) { Trace("sygus-grammar-def") << "...add for constructors" << std::endl; - const Datatype& dt = types[i].getDatatype(); + const DType& dt = types[i].getDType(); for (unsigned k = 0, size_k = dt.getNumConstructors(); k < size_k; ++k) { Trace("sygus-grammar-def") << "...for " << dt[k].getName() << std::endl; - Node cop = Node::fromExpr(dt[k].getConstructor()); + Node cop = dt[k].getConstructor(); if (dt[k].getNumArgs() == 0) { // Nullary constructors are interpreted as terms, not operators. @@ -834,7 +834,7 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( { Trace("sygus-grammar-def") << "...for " << dt[k][j].getName() << std::endl; - TypeNode crange = TypeNode::fromType(dt[k][j].getRangeType()); + TypeNode crange = dt[k][j].getRangeType(); Assert(type_to_unres.find(crange) != type_to_unres.end()); cargsCons.push_back(type_to_unres[crange]); // add to the selector type the selector operator @@ -842,12 +842,12 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( Assert(std::find(types.begin(), types.end(), crange) != types.end()); unsigned i_selType = std::distance( types.begin(), std::find(types.begin(), types.end(), crange)); - TypeNode arg_type = TypeNode::fromType(dt[k][j].getType()); + TypeNode arg_type = dt[k][j].getType(); arg_type = arg_type.getSelectorDomainType(); Assert(type_to_unres.find(arg_type) != type_to_unres.end()); std::vector cargsSel; cargsSel.push_back(type_to_unres[arg_type]); - Node sel = Node::fromExpr(dt[k][j].getSelector()); + Node sel = dt[k][j].getSelector(); sdts[i_selType].addConstructor(sel, dt[k][j].getName(), cargsSel); } sdts[i].addConstructor(cop, dt[k].getName(), cargsCons); @@ -1175,14 +1175,16 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( { //add for testers Trace("sygus-grammar-def") << "...add for testers" << std::endl; - const Datatype& dt = types[i].getDatatype(); + const DType& dt = types[i].getDType(); std::vector cargsTester; cargsTester.push_back(unres_types[iuse]); for (unsigned k = 0, size_k = dt.getNumConstructors(); k < size_k; ++k) { - Trace("sygus-grammar-def") << "...for " << dt[k].getTesterName() << std::endl; - sdtBool.addConstructor( - dt[k].getTester(), dt[k].getTesterName(), cargsTester); + Trace("sygus-grammar-def") + << "...for " << dt[k].getTester() << std::endl; + std::stringstream sst; + sst << dt[k].getTester(); + sdtBool.addConstructor(dt[k].getTester(), sst.str(), cargsTester); } } } diff --git a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp index 019abc28a..c7c1d820f 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp @@ -599,7 +599,7 @@ TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn) return tn; } /* Collect all operators for normalization */ - const Datatype& dt = tn.getDatatype(); + const Datatype& dt = DatatypeType(tn.toType()).getDatatype(); if (!dt.isSygus()) { // datatype but not sygus datatype case -- cgit v1.2.3