diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2020-10-29 21:51:18 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-29 21:51:18 -0500 |
commit | 21fd193bdaad1a952845326aa1c84654cfce1503 (patch) | |
tree | 5d7732c5442dc73120352eb25ed92af9806c0751 /src/api/cvc4cpp.cpp | |
parent | 3596632eef07dbe28ea4a4f166c18ad9fe26d4e0 (diff) |
Update api::Sort to use TypeNode instead of Type (#5363)
This is work towards removing the old API.
This makes TypeNode the backend for Sort instead of Type.
It also updates a unit test for methods isUninterpretedSortParameterized and getUninterpretedSortParamSorts whose implementation was previously buggy due to the implementation of Type-level SortType.
Diffstat (limited to 'src/api/cvc4cpp.cpp')
-rw-r--r-- | src/api/cvc4cpp.cpp | 409 |
1 files changed, 192 insertions, 217 deletions
diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index 507e270bb..e16d8c519 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -49,6 +49,7 @@ #include "expr/node_manager.h" #include "expr/sequence.h" #include "expr/type.h" +#include "expr/type_node.h" #include "options/main_options.h" #include "options/options.h" #include "options/smt_options.h" @@ -945,13 +946,26 @@ std::ostream& operator<<(std::ostream& out, const Result& r) /* -------------------------------------------------------------------------- */ Sort::Sort(const Solver* slv, const CVC4::Type& t) - : d_solver(slv), d_type(new CVC4::Type(t)) + : d_solver(slv), d_type(new CVC4::TypeNode(TypeNode::fromType(t))) +{ +} +Sort::Sort(const Solver* slv, const CVC4::TypeNode& t) + : d_solver(slv), d_type(new CVC4::TypeNode(t)) { } -Sort::Sort() : d_solver(nullptr), d_type(new CVC4::Type()) {} +Sort::Sort() : d_solver(nullptr), d_type(new CVC4::TypeNode()) {} -Sort::~Sort() {} +Sort::~Sort() +{ + if (d_solver != nullptr) + { + // Ensure that the correct node manager is in scope when the node is + // destroyed. + NodeManagerScope scope(d_solver->getNodeManager()); + d_type.reset(); + } +} /* Helpers */ /* -------------------------------------------------------------------------- */ @@ -996,7 +1010,7 @@ bool Sort::isDatatype() const { return d_type->isDatatype(); } bool Sort::isParametricDatatype() const { if (!d_type->isDatatype()) return false; - return TypeNode::fromType(*d_type).isParametricDatatype(); + return d_type->isParametricDatatype(); } bool Sort::isConstructor() const { return d_type->isConstructor(); } @@ -1015,7 +1029,7 @@ bool Sort::isArray() const { return d_type->isArray(); } bool Sort::isSet() const { return d_type->isSet(); } -bool Sort::isBag() const { return TypeNode::fromType(*d_type).isBag(); } +bool Sort::isBag() const { return d_type->isBag(); } bool Sort::isSequence() const { return d_type->isSequence(); } @@ -1038,7 +1052,7 @@ Datatype Sort::getDatatype() const { NodeManagerScope scope(d_solver->getNodeManager()); CVC4_API_CHECK(isDatatype()) << "Expected datatype sort."; - return Datatype(d_solver, TypeNode::fromType(*d_type).getDType()); + return Datatype(d_solver, d_type->getDType()); } Sort Sort::instantiate(const std::vector<Sort>& params) const @@ -1046,23 +1060,13 @@ Sort Sort::instantiate(const std::vector<Sort>& params) const NodeManagerScope scope(d_solver->getNodeManager()); CVC4_API_CHECK(isParametricDatatype() || isSortConstructor()) << "Expected parametric datatype or sort constructor sort."; - std::vector<TypeNode> tparams; - for (const Sort& s : params) - { - tparams.push_back(TypeNode::fromType(*s.d_type.get())); - } + std::vector<CVC4::TypeNode> tparams = sortVectorToTypeNodes(params); if (d_type->isDatatype()) { - return Sort(d_solver, - TypeNode::fromType(*d_type) - .instantiateParametricDatatype(tparams) - .toType()); + return Sort(d_solver, d_type->instantiateParametricDatatype(tparams)); } Assert(d_type->isSortConstructor()); - return Sort(d_solver, - d_solver->getNodeManager() - ->mkSort(TypeNode::fromType(*d_type), tparams) - .toType()); + return Sort(d_solver, d_solver->getNodeManager()->mkSort(*d_type, tparams)); } std::string Sort::toString() const @@ -1077,27 +1081,32 @@ std::string Sort::toString() const // !!! This is only temporarily available until the parser is fully migrated // to the new API. !!! -CVC4::Type Sort::getType(void) const { return *d_type; } +CVC4::Type Sort::getType(void) const +{ + if (d_type->isNull()) return Type(); + NodeManagerScope scope(d_solver->getNodeManager()); + return d_type->toType(); +} +const CVC4::TypeNode& Sort::getTypeNode(void) const { return *d_type; } /* Constructor sort ------------------------------------------------------- */ size_t Sort::getConstructorArity() const { CVC4_API_CHECK(isConstructor()) << "Not a constructor sort: " << (*this); - return ConstructorType(*d_type).getArity(); + return d_type->getNumChildren() - 1; } std::vector<Sort> Sort::getConstructorDomainSorts() const { CVC4_API_CHECK(isConstructor()) << "Not a constructor sort: " << (*this); - std::vector<CVC4::Type> types = ConstructorType(*d_type).getArgTypes(); - return typeVectorToSorts(d_solver, types); + return typeNodeVectorToSorts(d_solver, d_type->getArgTypes()); } Sort Sort::getConstructorCodomainSort() const { CVC4_API_CHECK(isConstructor()) << "Not a constructor sort: " << (*this); - return Sort(d_solver, ConstructorType(*d_type).getRangeType()); + return Sort(d_solver, d_type->getConstructorRangeType()); } /* Selector sort ------------------------------------------------------- */ @@ -1105,15 +1114,13 @@ Sort Sort::getConstructorCodomainSort() const Sort Sort::getSelectorDomainSort() const { CVC4_API_CHECK(isSelector()) << "Not a selector sort: " << (*this); - TypeNode typeNode = TypeNode::fromType(*d_type); - return Sort(d_solver, typeNode.getSelectorDomainType().toType()); + return Sort(d_solver, d_type->getSelectorDomainType()); } Sort Sort::getSelectorCodomainSort() const { CVC4_API_CHECK(isSelector()) << "Not a selector sort: " << (*this); - TypeNode typeNode = TypeNode::fromType(*d_type); - return Sort(d_solver, typeNode.getSelectorRangeType().toType()); + return Sort(d_solver, d_type->getSelectorRangeType()); } /* Tester sort ------------------------------------------------------- */ @@ -1121,8 +1128,7 @@ Sort Sort::getSelectorCodomainSort() const Sort Sort::getTesterDomainSort() const { CVC4_API_CHECK(isTester()) << "Not a tester sort: " << (*this); - TypeNode typeNode = TypeNode::fromType(*d_type); - return Sort(d_solver, typeNode.getTesterDomainType().toType()); + return Sort(d_solver, d_type->getTesterDomainType()); } Sort Sort::getTesterCodomainSort() const @@ -1136,20 +1142,19 @@ Sort Sort::getTesterCodomainSort() const size_t Sort::getFunctionArity() const { CVC4_API_CHECK(isFunction()) << "Not a function sort: " << (*this); - return FunctionType(*d_type).getArity(); + return d_type->getNumChildren() - 1; } std::vector<Sort> Sort::getFunctionDomainSorts() const { CVC4_API_CHECK(isFunction()) << "Not a function sort: " << (*this); - std::vector<CVC4::Type> types = FunctionType(*d_type).getArgTypes(); - return typeVectorToSorts(d_solver, types); + return typeNodeVectorToSorts(d_solver, d_type->getArgTypes()); } Sort Sort::getFunctionCodomainSort() const { CVC4_API_CHECK(isFunction()) << "Not a function sort" << (*this); - return Sort(d_solver, FunctionType(*d_type).getRangeType()); + return Sort(d_solver, d_type->getRangeType()); } /* Array sort ---------------------------------------------------------- */ @@ -1157,13 +1162,13 @@ Sort Sort::getFunctionCodomainSort() const Sort Sort::getArrayIndexSort() const { CVC4_API_CHECK(isArray()) << "Not an array sort."; - return Sort(d_solver, ArrayType(*d_type).getIndexType()); + return Sort(d_solver, d_type->getArrayIndexType()); } Sort Sort::getArrayElementSort() const { CVC4_API_CHECK(isArray()) << "Not an array sort."; - return Sort(d_solver, ArrayType(*d_type).getConstituentType()); + return Sort(d_solver, d_type->getArrayConstituentType()); } /* Set sort ------------------------------------------------------------ */ @@ -1171,7 +1176,7 @@ Sort Sort::getArrayElementSort() const Sort Sort::getSetElementSort() const { CVC4_API_CHECK(isSet()) << "Not a set sort."; - return Sort(d_solver, SetType(*d_type).getElementType()); + return Sort(d_solver, d_type->getSetElementType()); } /* Bag sort ------------------------------------------------------------ */ @@ -1179,9 +1184,7 @@ Sort Sort::getSetElementSort() const Sort Sort::getBagElementSort() const { CVC4_API_CHECK(isBag()) << "Not a bag sort."; - TypeNode typeNode = TypeNode::fromType(*d_type); - Type type = typeNode.getBagElementType().toType(); - return Sort(d_solver, type); + return Sort(d_solver, d_type->getBagElementType()); } /* Sequence sort ------------------------------------------------------- */ @@ -1189,7 +1192,7 @@ Sort Sort::getBagElementSort() const Sort Sort::getSequenceElementSort() const { CVC4_API_CHECK(isSequence()) << "Not a sequence sort."; - return Sort(d_solver, SequenceType(*d_type).getElementType()); + return Sort(d_solver, d_type->getSequenceElementType()); } /* Uninterpreted sort -------------------------------------------------- */ @@ -1197,20 +1200,28 @@ Sort Sort::getSequenceElementSort() const std::string Sort::getUninterpretedSortName() const { CVC4_API_CHECK(isUninterpretedSort()) << "Not an uninterpreted sort."; - return SortType(*d_type).getName(); + return d_type->getName(); } bool Sort::isUninterpretedSortParameterized() const { CVC4_API_CHECK(isUninterpretedSort()) << "Not an uninterpreted sort."; - return SortType(*d_type).isParameterized(); + // This method is not implemented in the NodeManager, since whether a + // uninterpreted sort is parametrized is irrelevant for solving. + return d_type->getNumChildren() > 0; } std::vector<Sort> Sort::getUninterpretedSortParamSorts() const { CVC4_API_CHECK(isUninterpretedSort()) << "Not an uninterpreted sort."; - std::vector<CVC4::Type> types = SortType(*d_type).getParamTypes(); - return typeVectorToSorts(d_solver, types); + // This method is not implemented in the NodeManager, since whether a + // uninterpreted sort is parametrized is irrelevant for solving. + std::vector<TypeNode> params; + for (size_t i = 0, nchildren = d_type->getNumChildren(); i < nchildren; i++) + { + params.push_back((*d_type)[i]); + } + return typeNodeVectorToSorts(d_solver, params); } /* Sort constructor sort ----------------------------------------------- */ @@ -1218,13 +1229,13 @@ std::vector<Sort> Sort::getUninterpretedSortParamSorts() const std::string Sort::getSortConstructorName() const { CVC4_API_CHECK(isSortConstructor()) << "Not a sort constructor sort."; - return SortConstructorType(*d_type).getName(); + return d_type->getName(); } size_t Sort::getSortConstructorArity() const { CVC4_API_CHECK(isSortConstructor()) << "Not a sort constructor sort."; - return SortConstructorType(*d_type).getArity(); + return d_type->getSortConstructorArity(); } /* Bit-vector sort ----------------------------------------------------- */ @@ -1232,7 +1243,7 @@ size_t Sort::getSortConstructorArity() const uint32_t Sort::getBVSize() const { CVC4_API_CHECK(isBitVector()) << "Not a bit-vector sort."; - return BitVectorType(*d_type).getSize(); + return d_type->getBitVectorSize(); } /* Floating-point sort ------------------------------------------------- */ @@ -1240,13 +1251,13 @@ uint32_t Sort::getBVSize() const uint32_t Sort::getFPExponentSize() const { CVC4_API_CHECK(isFloatingPoint()) << "Not a floating-point sort."; - return FloatingPointType(*d_type).getExponentSize(); + return d_type->getFloatingPointExponentSize(); } uint32_t Sort::getFPSignificandSize() const { CVC4_API_CHECK(isFloatingPoint()) << "Not a floating-point sort."; - return FloatingPointType(*d_type).getSignificandSize(); + return d_type->getFloatingPointSignificandSize(); } /* Datatype sort ------------------------------------------------------- */ @@ -1254,20 +1265,14 @@ uint32_t Sort::getFPSignificandSize() const std::vector<Sort> Sort::getDatatypeParamSorts() const { CVC4_API_CHECK(isParametricDatatype()) << "Not a parametric datatype sort."; - std::vector<CVC4::TypeNode> typeNodes = - TypeNode::fromType(*d_type).getParamTypes(); - std::vector<Sort> sorts; - for (size_t i = 0, tsize = typeNodes.size(); i < tsize; i++) - { - sorts.push_back(Sort(d_solver, typeNodes[i].toType())); - } - return sorts; + std::vector<CVC4::TypeNode> typeNodes = d_type->getParamTypes(); + return typeNodeVectorToSorts(d_solver, typeNodes); } size_t Sort::getDatatypeArity() const { CVC4_API_CHECK(isDatatype()) << "Not a datatype sort."; - return TypeNode::fromType(*d_type).getNumChildren() - 1; + return d_type->getNumChildren() - 1; } /* Tuple sort ---------------------------------------------------------- */ @@ -1275,20 +1280,14 @@ size_t Sort::getDatatypeArity() const size_t Sort::getTupleLength() const { CVC4_API_CHECK(isTuple()) << "Not a tuple sort."; - return TypeNode::fromType(*d_type).getTupleLength(); + return d_type->getTupleLength(); } std::vector<Sort> Sort::getTupleSorts() const { CVC4_API_CHECK(isTuple()) << "Not a tuple sort."; - std::vector<CVC4::TypeNode> typeNodes = - TypeNode::fromType(*d_type).getTupleTypes(); - std::vector<Sort> sorts; - for (size_t i = 0, tsize = typeNodes.size(); i < tsize; i++) - { - sorts.push_back(Sort(d_solver, typeNodes[i].toType())); - } - return sorts; + std::vector<CVC4::TypeNode> typeNodes = d_type->getTupleTypes(); + return typeNodeVectorToSorts(d_solver, typeNodes); } /* --------------------------------------------------------------------- */ @@ -1301,7 +1300,7 @@ std::ostream& operator<<(std::ostream& out, const Sort& s) size_t SortHashFunction::operator()(const Sort& s) const { - return TypeHashFunction()(*s.d_type); + return TypeNodeHashFunction()(*s.d_type); } /* -------------------------------------------------------------------------- */ @@ -1329,7 +1328,7 @@ Op::~Op() { if (d_solver != nullptr) { - // Ensure that the correct node manager is in scope when the node is + // Ensure that the correct node manager is in scope when the type node is // destroyed. NodeManagerScope scope(d_solver->getNodeManager()); d_node.reset(); @@ -1709,7 +1708,7 @@ Sort Term::getSort() const { CVC4_API_CHECK_NOT_NULL; NodeManagerScope scope(d_solver->getNodeManager()); - return Sort(d_solver, d_node->getType().toType()); + return Sort(d_solver, d_node->getType()); } Term Term::substitute(Term e, Term replacement) const @@ -2133,7 +2132,7 @@ void DatatypeConstructorDecl::addSelector(const std::string& name, Sort sort) NodeManagerScope scope(d_solver->getNodeManager()); CVC4_API_ARG_CHECK_EXPECTED(!sort.isNull(), sort) << "non-null range sort for selector"; - d_ctor->addArg(name, TypeNode::fromType(*sort.d_type)); + d_ctor->addArg(name, *sort.d_type); } void DatatypeConstructorDecl::addSelectorSelf(const std::string& name) @@ -2188,9 +2187,7 @@ DatatypeDecl::DatatypeDecl(const Solver* slv, bool isCoDatatype) : d_solver(slv), d_dtype(new CVC4::DType( - name, - std::vector<TypeNode>{TypeNode::fromType(*param.d_type)}, - isCoDatatype)) + name, std::vector<TypeNode>{*param.d_type}, isCoDatatype)) { } @@ -2200,11 +2197,7 @@ DatatypeDecl::DatatypeDecl(const Solver* slv, bool isCoDatatype) : d_solver(slv) { - std::vector<TypeNode> tparams; - for (const Sort& p : params) - { - tparams.push_back(TypeNode::fromType(*p.d_type)); - } + std::vector<TypeNode> tparams = sortVectorToTypeNodes(params); d_dtype = std::shared_ptr<CVC4::DType>( new CVC4::DType(name, tparams, isCoDatatype)); } @@ -2297,7 +2290,7 @@ Term DatatypeSelector::getSelectorTerm() const Sort DatatypeSelector::getRangeSort() const { - return Sort(d_solver, d_stor->getRangeType().toType()); + return Sort(d_solver, d_stor->getRangeType()); } std::string DatatypeSelector::toString() const @@ -2363,13 +2356,11 @@ Term DatatypeConstructor::getSpecializedConstructorTerm(Sort retSort) const CVC4_API_SOLVER_TRY_CATCH_BEGIN; NodeManager* nm = d_solver->getNodeManager(); - Node ret = nm->mkNode( - kind::APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType(d_ctor - ->getSpecializedConstructorType( - TypeNode::fromType(retSort.getType())) - .toType())), - d_ctor->getConstructor()); + Node ret = + nm->mkNode(kind::APPLY_TYPE_ASCRIPTION, + nm->mkConst(AscriptionType( + d_ctor->getSpecializedConstructorType(*retSort.d_type))), + d_ctor->getConstructor()); (void)ret.getType(true); /* kick off type checking */ // apply type ascription to the operator Term sctor = api::Term(d_solver, ret); @@ -2902,7 +2893,7 @@ Sort Grammar::resolve() // make the unresolved type, used for referencing the final version of // the ntsymbol's datatype ntsToUnres[ntsymbol] = - Sort(d_solver, d_solver->getExprManager()->mkSort(ntsymbol.toString())); + Sort(d_solver, d_solver->getNodeManager()->mkSort(ntsymbol.toString())); } std::vector<CVC4::DType> datatypes; @@ -2922,8 +2913,8 @@ Sort Grammar::resolve() if (d_allowVars.find(ntSym) != d_allowVars.cend()) { - addSygusConstructorVariables( - dtDecl, Sort(d_solver, ntSym.d_node->getType().toType())); + addSygusConstructorVariables(dtDecl, + Sort(d_solver, ntSym.d_node->getType())); } bool aci = d_allowConst.find(ntSym) != d_allowConst.end(); @@ -2938,7 +2929,7 @@ Sort Grammar::resolve() << " produced an empty rule list"; datatypes.push_back(*dtDecl.d_dtype); - unresTypes.insert(TypeNode::fromType(*ntsToUnres[ntSym].d_type)); + unresTypes.insert(*ntsToUnres[ntSym].d_type); } std::vector<TypeNode> datatypeTypes = @@ -2946,7 +2937,7 @@ Sort Grammar::resolve() datatypes, unresTypes, NodeManager::DATATYPE_FLAG_PLACEHOLDER); // return is the first datatype - return Sort(d_solver, datatypeTypes[0].toType()); + return Sort(d_solver, datatypeTypes[0]); } void Grammar::addSygusConstructorTerm( @@ -2978,11 +2969,7 @@ void Grammar::addSygusConstructorTerm( d_solver->getExprManager()->mkExpr( CVC4::kind::LAMBDA, {lbvl.d_node->toExpr(), op.d_node->toExpr()})); } - std::vector<TypeNode> cargst; - for (const Sort& s : cargs) - { - cargst.push_back(TypeNode::fromType(s.getType())); - } + std::vector<TypeNode> cargst = sortVectorToTypeNodes(cargs); dt.d_dtype->addSygusConstructor(*op.d_node, ssCName.str(), cargst); } @@ -3044,7 +3031,7 @@ void Grammar::addSygusConstructorVariables(DatatypeDecl& dt, Sort sort) const for (unsigned i = 0, size = d_sygusVars.size(); i < size; i++) { Term v = d_sygusVars[i]; - if (v.d_node->getType().toType() == *sort.d_type) + if (v.d_node->getType() == *sort.d_type) { std::stringstream ss; ss << v; @@ -3320,19 +3307,11 @@ std::vector<Sort> Solver::mkDatatypeSortsInternal( { CVC4_API_SOLVER_CHECK_SORT(sort); } - - std::set<TypeNode> utypes; - for (const Sort& s : unresolvedSorts) - { - utypes.insert(TypeNode::fromType(s.getType())); - } + + std::set<TypeNode> utypes = sortSetToTypeNodes(unresolvedSorts); std::vector<CVC4::TypeNode> dtypes = getNodeManager()->mkMutualDatatypeTypes(datatypes, utypes); - std::vector<Sort> retTypes; - for (CVC4::TypeNode t : dtypes) - { - retTypes.push_back(Sort(this, t.toType())); - } + std::vector<Sort> retTypes = typeNodeVectorToSorts(this, dtypes); return retTypes; CVC4_API_SOLVER_TRY_CATCH_END; @@ -3348,7 +3327,7 @@ std::vector<Type> Solver::sortVectorToTypes( for (const Sort& s : sorts) { CVC4_API_SOLVER_CHECK_SORT(s); - res.push_back(*s.d_type); + res.push_back(s.d_type->toType()); } return res; } @@ -3401,42 +3380,42 @@ bool Solver::supportsFloatingPoint() const Sort Solver::getNullSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, Type()); + return Sort(this, TypeNode()); CVC4_API_SOLVER_TRY_CATCH_END; } Sort Solver::getBooleanSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->booleanType()); + return Sort(this, getNodeManager()->booleanType()); CVC4_API_SOLVER_TRY_CATCH_END; } Sort Solver::getIntegerSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->integerType()); + return Sort(this, getNodeManager()->integerType()); CVC4_API_SOLVER_TRY_CATCH_END; } Sort Solver::getRealSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->realType()); + return Sort(this, getNodeManager()->realType()); CVC4_API_SOLVER_TRY_CATCH_END; } Sort Solver::getRegExpSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->regExpType()); + return Sort(this, getNodeManager()->regExpType()); CVC4_API_SOLVER_TRY_CATCH_END; } Sort Solver::getStringSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->stringType()); + return Sort(this, getNodeManager()->stringType()); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3445,7 +3424,7 @@ Sort Solver::getRoundingModeSort(void) const CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_CHECK(Configuration::isBuiltWithSymFPU()) << "Expected CVC4 to be compiled with SymFPU support"; - return Sort(this, d_exprMgr->roundingModeType()); + return Sort(this, getNodeManager()->roundingModeType()); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3461,8 +3440,8 @@ Sort Solver::mkArraySort(Sort indexSort, Sort elemSort) const CVC4_API_SOLVER_CHECK_SORT(indexSort); CVC4_API_SOLVER_CHECK_SORT(elemSort); - return Sort(this, - d_exprMgr->mkArrayType(*indexSort.d_type, *elemSort.d_type)); + return Sort( + this, getNodeManager()->mkArrayType(*indexSort.d_type, *elemSort.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3472,7 +3451,7 @@ Sort Solver::mkBitVectorSort(uint32_t size) const CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_EXPECTED(size > 0, size) << "size > 0"; - return Sort(this, d_exprMgr->mkBitVectorType(size)); + return Sort(this, getNodeManager()->mkBitVectorType(size)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3485,7 +3464,7 @@ Sort Solver::mkFloatingPointSort(uint32_t exp, uint32_t sig) const CVC4_API_ARG_CHECK_EXPECTED(exp > 0, exp) << "exponent size > 0"; CVC4_API_ARG_CHECK_EXPECTED(sig > 0, sig) << "significand size > 0"; - return Sort(this, d_exprMgr->mkFloatingPointType(exp, sig)); + return Sort(this, getNodeManager()->mkFloatingPointType(exp, sig)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3499,8 +3478,7 @@ Sort Solver::mkDatatypeSort(DatatypeDecl dtypedecl) const CVC4_API_ARG_CHECK_EXPECTED(dtypedecl.getNumConstructors() > 0, dtypedecl) << "a datatype declaration with at least one constructor"; - return Sort(this, - getNodeManager()->mkDatatypeType(*dtypedecl.d_dtype).toType()); + return Sort(this, getNodeManager()->mkDatatypeType(*dtypedecl.d_dtype)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3536,8 +3514,8 @@ Sort Solver::mkFunctionSort(Sort domain, Sort codomain) const << "first-class sort as codomain sort for function sort"; Assert(!codomain.isFunction()); /* A function sort is not first-class. */ - return Sort(this, - d_exprMgr->mkFunctionType(*domain.d_type, *codomain.d_type)); + return Sort( + this, getNodeManager()->mkFunctionType(*domain.d_type, *codomain.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3566,8 +3544,9 @@ Sort Solver::mkFunctionSort(const std::vector<Sort>& sorts, Sort codomain) const << "first-class sort as codomain sort for function sort"; Assert(!codomain.isFunction()); /* A function sort is not first-class. */ - std::vector<Type> argTypes = sortVectorToTypes(sorts); - return Sort(this, d_exprMgr->mkFunctionType(argTypes, *codomain.d_type)); + std::vector<TypeNode> argTypes = sortVectorToTypeNodes(sorts); + return Sort(this, + getNodeManager()->mkFunctionType(argTypes, *codomain.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3575,8 +3554,9 @@ Sort Solver::mkFunctionSort(const std::vector<Sort>& sorts, Sort codomain) const Sort Solver::mkParamSort(const std::string& symbol) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, - d_exprMgr->mkSort(symbol, ExprManager::SORT_FLAG_PLACEHOLDER)); + return Sort( + this, + getNodeManager()->mkSort(symbol, ExprManager::SORT_FLAG_PLACEHOLDER)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3597,9 +3577,9 @@ Sort Solver::mkPredicateSort(const std::vector<Sort>& sorts) const sorts[i].isFirstClass(), "parameter sort", sorts[i], i) << "first-class sort as parameter sort for predicate sort"; } - std::vector<Type> types = sortVectorToTypes(sorts); + std::vector<TypeNode> types = sortVectorToTypeNodes(sorts); - return Sort(this, d_exprMgr->mkPredicateType(types)); + return Sort(this, getNodeManager()->mkPredicateType(types)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3620,10 +3600,10 @@ Sort Solver::mkRecordSort( this == p.second.d_solver, "parameter sort", p.second, i) << "sort associated to this solver object"; i += 1; - f.emplace_back(p.first, *p.second.d_type); + f.emplace_back(p.first, p.second.d_type->toType()); } - return Sort(this, getNodeManager()->mkRecordType(Record(f)).toType()); + return Sort(this, getNodeManager()->mkRecordType(Record(f))); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3635,7 +3615,7 @@ Sort Solver::mkSetSort(Sort elemSort) const << "non-null element sort"; CVC4_API_SOLVER_CHECK_SORT(elemSort); - return Sort(this, d_exprMgr->mkSetType(*elemSort.d_type)); + return Sort(this, getNodeManager()->mkSetType(*elemSort.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3647,9 +3627,7 @@ Sort Solver::mkBagSort(Sort elemSort) const << "non-null element sort"; CVC4_API_SOLVER_CHECK_SORT(elemSort); - TypeNode typeNode = TypeNode::fromType(*elemSort.d_type); - Type type = getNodeManager()->mkBagType(typeNode).toType(); - return Sort(this, type); + return Sort(this, getNodeManager()->mkBagType(*elemSort.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3661,7 +3639,7 @@ Sort Solver::mkSequenceSort(Sort elemSort) const << "non-null element sort"; CVC4_API_SOLVER_CHECK_SORT(elemSort); - return Sort(this, d_exprMgr->mkSequenceType(*elemSort.d_type)); + return Sort(this, getNodeManager()->mkSequenceType(*elemSort.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3669,7 +3647,7 @@ Sort Solver::mkSequenceSort(Sort elemSort) const Sort Solver::mkUninterpretedSort(const std::string& symbol) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->mkSort(symbol)); + return Sort(this, getNodeManager()->mkSort(symbol)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3679,7 +3657,7 @@ Sort Solver::mkSortConstructorSort(const std::string& symbol, CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_EXPECTED(arity > 0, arity) << "an arity > 0"; - return Sort(this, d_exprMgr->mkSortConstructor(symbol, arity)); + return Sort(this, getNodeManager()->mkSortConstructor(symbol, arity)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3699,12 +3677,8 @@ Sort Solver::mkTupleSort(const std::vector<Sort>& sorts) const !sorts[i].isFunctionLike(), "parameter sort", sorts[i], i) << "non-function-like sort as parameter sort for tuple sort"; } - std::vector<TypeNode> typeNodes; - for (const Sort& s : sorts) - { - typeNodes.push_back(TypeNode::fromType(*s.d_type)); - } - return Sort(this, getNodeManager()->mkTupleType(typeNodes).toType()); + std::vector<TypeNode> typeNodes = sortVectorToTypeNodes(sorts); + return Sort(this, getNodeManager()->mkTupleType(typeNodes)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3840,8 +3814,7 @@ Term Solver::mkEmptySet(Sort s) const CVC4_API_ARG_CHECK_EXPECTED(s.isNull() || this == s.d_solver, s) << "set sort associated to this solver object"; - return mkValHelper<CVC4::EmptySet>( - CVC4::EmptySet(TypeNode::fromType(*s.d_type))); + return mkValHelper<CVC4::EmptySet>(CVC4::EmptySet(*s.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3855,8 +3828,7 @@ Term Solver::mkSingleton(Sort s, Term t) const CVC4_API_SOLVER_CHECK_TERM(t); checkMkTerm(SINGLETON, 1); - TypeNode typeNode = TypeNode::fromType(*s.d_type); - Node res = getNodeManager()->mkSingleton(typeNode, *t.d_node); + Node res = getNodeManager()->mkSingleton(*s.d_type, *t.d_node); (void)res.getType(true); /* kick off type checking */ return Term(this, res); @@ -3872,8 +3844,7 @@ Term Solver::mkEmptyBag(Sort s) const CVC4_API_ARG_CHECK_EXPECTED(s.isNull() || this == s.d_solver, s) << "bag sort associated to this solver object"; - return mkValHelper<CVC4::EmptyBag>( - CVC4::EmptyBag(TypeNode::fromType(*s.d_type))); + return mkValHelper<CVC4::EmptyBag>(CVC4::EmptyBag(*s.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3884,7 +3855,8 @@ Term Solver::mkSepNil(Sort sort) const CVC4_API_ARG_CHECK_EXPECTED(!sort.isNull(), sort) << "non-null sort"; CVC4_API_SOLVER_CHECK_SORT(sort); - Expr res = d_exprMgr->mkNullaryOperator(*sort.d_type, CVC4::kind::SEP_NIL); + Node res = + getNodeManager()->mkNullaryOperator(*sort.d_type, CVC4::kind::SEP_NIL); (void)res.getType(true); /* kick off type checking */ return Term(this, res); @@ -3926,8 +3898,7 @@ Term Solver::mkEmptySequence(Sort sort) const CVC4_API_SOLVER_CHECK_SORT(sort); std::vector<Node> seq; - Expr res = - d_exprMgr->mkConst(Sequence(TypeNode::fromType(*sort.d_type), seq)); + Expr res = d_exprMgr->mkConst(Sequence(*sort.d_type, seq)); return Term(this, res); CVC4_API_SOLVER_TRY_CATCH_END; @@ -3939,8 +3910,8 @@ Term Solver::mkUniverseSet(Sort sort) const CVC4_API_ARG_CHECK_EXPECTED(!sort.isNull(), sort) << "non-null sort"; CVC4_API_SOLVER_CHECK_SORT(sort); - Expr res = - d_exprMgr->mkNullaryOperator(*sort.d_type, CVC4::kind::UNIVERSE_SET); + Node res = getNodeManager()->mkNullaryOperator(*sort.d_type, + CVC4::kind::UNIVERSE_SET); // TODO(#2771): Reenable? // (void)res->getType(true); /* kick off type checking */ return Term(this, res); @@ -3990,7 +3961,7 @@ Term Solver::mkConstArray(Sort sort, Term val) const n = n[0]; } Term res = mkValHelper<CVC4::ArrayStoreAll>( - CVC4::ArrayStoreAll(TypeNode::fromType(*sort.d_type), n)); + CVC4::ArrayStoreAll(*sort.d_type, n)); return res; CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4071,7 +4042,7 @@ Term Solver::mkUninterpretedConst(Sort sort, int32_t index) const CVC4_API_SOLVER_CHECK_SORT(sort); return mkValHelper<CVC4::UninterpretedConstant>( - CVC4::UninterpretedConstant(TypeNode::fromType(*sort.d_type), index)); + CVC4::UninterpretedConstant(*sort.d_type, index)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4134,8 +4105,8 @@ Term Solver::mkConst(Sort sort, const std::string& symbol) const CVC4_API_ARG_CHECK_EXPECTED(!sort.isNull(), sort) << "non-null sort"; CVC4_API_SOLVER_CHECK_SORT(sort); - Expr res = symbol.empty() ? d_exprMgr->mkVar(*sort.d_type) - : d_exprMgr->mkVar(symbol, *sort.d_type); + Expr res = symbol.empty() ? d_exprMgr->mkVar(sort.d_type->toType()) + : d_exprMgr->mkVar(symbol, sort.d_type->toType()); (void)res.getType(true); /* kick off type checking */ return Term(this, res); @@ -4151,8 +4122,9 @@ Term Solver::mkVar(Sort sort, const std::string& symbol) const CVC4_API_ARG_CHECK_EXPECTED(!sort.isNull(), sort) << "non-null sort"; CVC4_API_SOLVER_CHECK_SORT(sort); - Expr res = symbol.empty() ? d_exprMgr->mkBoundVar(*sort.d_type) - : d_exprMgr->mkBoundVar(symbol, *sort.d_type); + Expr res = symbol.empty() + ? d_exprMgr->mkBoundVar(sort.d_type->toType()) + : d_exprMgr->mkBoundVar(symbol, sort.d_type->toType()); (void)res.getType(true); /* kick off type checking */ return Term(this, res); @@ -4776,7 +4748,7 @@ Sort Solver::declareDatatype( << "datatype constructor declaration associated to this solver object"; dtdecl.addConstructor(ctors[i]); } - return Sort(this, getNodeManager()->mkDatatypeType(*dtdecl.d_dtype).toType()); + return Sort(this, getNodeManager()->mkDatatypeType(*dtdecl.d_dtype)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4801,13 +4773,13 @@ Term Solver::declareFun(const std::string& symbol, << "first-class sort as function codomain sort"; CVC4_API_SOLVER_CHECK_SORT(sort); Assert(!sort.isFunction()); /* A function sort is not first-class. */ - Type type = *sort.d_type; + TypeNode type = *sort.d_type; if (!sorts.empty()) { - std::vector<Type> types = sortVectorToTypes(sorts); - type = d_exprMgr->mkFunctionType(types, type); + std::vector<TypeNode> types = sortVectorToTypeNodes(sorts); + type = getNodeManager()->mkFunctionType(types, type); } - return Term(this, d_exprMgr->mkVar(symbol, type)); + return Term(this, d_exprMgr->mkVar(symbol, type.toType())); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4817,8 +4789,11 @@ Term Solver::declareFun(const std::string& symbol, Sort Solver::declareSort(const std::string& symbol, uint32_t arity) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - if (arity == 0) return Sort(this, d_exprMgr->mkSort(symbol)); - return Sort(this, d_exprMgr->mkSortConstructor(symbol, arity)); + if (arity == 0) + { + return Sort(this, getNodeManager()->mkSort(symbol)); + } + return Sort(this, getNodeManager()->mkSortConstructor(symbol, arity)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4846,18 +4821,18 @@ Term Solver::defineFun(const std::string& symbol, bound_vars[i], i) << "a bound variable"; - CVC4::Type t = bound_vars[i].d_node->getType().toType(); + CVC4::TypeNode t = bound_vars[i].d_node->getType(); CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( t.isFirstClass(), "sort of parameter", bound_vars[i], i) << "first-class sort of parameter of defined function"; - domain_types.push_back(TypeNode::fromType(t)); + domain_types.push_back(t); } CVC4_API_SOLVER_CHECK_SORT(sort); CVC4_API_CHECK(sort == term.getSort()) << "Invalid sort of function body '" << term << "', expected '" << sort << "'"; NodeManager* nm = getNodeManager(); - TypeNode type = TypeNode::fromType(*sort.d_type); + TypeNode type = *sort.d_type; if (!domain_types.empty()) { type = nm->mkFunctionType(domain_types, type); @@ -4965,7 +4940,7 @@ Term Solver::defineFunRec(const std::string& symbol, << "'"; CVC4_API_SOLVER_CHECK_TERM(term); NodeManager* nm = getNodeManager(); - TypeNode type = TypeNode::fromType(*sort.d_type); + TypeNode type = *sort.d_type; if (!domain_types.empty()) { type = nm->mkFunctionType(domain_types, type); @@ -5354,8 +5329,8 @@ bool Solver::getInterpolant(Term conj, Grammar& g, Term& output) const CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4::ExprManagerScope exmgrs(*(d_exprMgr.get())); Node result; - bool success = d_smtEngine->getInterpol( - *conj.d_node, TypeNode::fromType(*g.resolve().d_type), result); + bool success = + d_smtEngine->getInterpol(*conj.d_node, *g.resolve().d_type, result); if (success) { output = Term(this, result); @@ -5383,8 +5358,8 @@ bool Solver::getAbduct(Term conj, Grammar& g, Term& output) const CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4::ExprManagerScope exmgrs(*(d_exprMgr.get())); Node result; - bool success = d_smtEngine->getAbduct( - *conj.d_node, TypeNode::fromType(*g.resolve().d_type), result); + bool success = + d_smtEngine->getAbduct(*conj.d_node, *g.resolve().d_type, result); if (success) { output = Term(this, result); @@ -5569,10 +5544,10 @@ Term Solver::mkSygusVar(Sort sort, const std::string& symbol) const CVC4_API_ARG_CHECK_NOT_NULL(sort); CVC4_API_SOLVER_CHECK_SORT(sort); - Expr res = d_exprMgr->mkBoundVar(symbol, *sort.d_type); + Node res = getNodeManager()->mkBoundVar(symbol, *sort.d_type); (void)res.getType(true); /* kick off type checking */ - d_smtEngine->declareSygusVar(symbol, res, TypeNode::fromType(*sort.d_type)); + d_smtEngine->declareSygusVar(symbol, res, *sort.d_type); return Term(this, res); @@ -5641,7 +5616,7 @@ Term Solver::synthInv(const std::string& symbol, const std::vector<Term>& boundVars) const { return synthFunHelper( - symbol, boundVars, Sort(this, d_exprMgr->booleanType()), true); + symbol, boundVars, Sort(this, getNodeManager()->booleanType()), true); } Term Solver::synthInv(const std::string& symbol, @@ -5649,7 +5624,7 @@ Term Solver::synthInv(const std::string& symbol, Grammar& g) const { return synthFunHelper( - symbol, boundVars, Sort(this, d_exprMgr->booleanType()), true, &g); + symbol, boundVars, Sort(this, getNodeManager()->booleanType()), true, &g); } Term Solver::synthFunHelper(const std::string& symbol, @@ -5661,7 +5636,7 @@ Term Solver::synthFunHelper(const std::string& symbol, CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_NOT_NULL(sort); - std::vector<Type> varTypes; + std::vector<TypeNode> varTypes; for (size_t i = 0, n = boundVars.size(); i < n; ++i) { CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( @@ -5676,36 +5651,28 @@ Term Solver::synthFunHelper(const std::string& symbol, boundVars[i], i) << "a bound variable"; - varTypes.push_back(boundVars[i].d_node->getType().toType()); + varTypes.push_back(boundVars[i].d_node->getType()); } CVC4_API_SOLVER_CHECK_SORT(sort); if (g != nullptr) { - CVC4_API_CHECK(g->d_ntSyms[0].d_node->getType().toType() == *sort.d_type) + CVC4_API_CHECK(g->d_ntSyms[0].d_node->getType() == *sort.d_type) << "Invalid Start symbol for Grammar g, Expected Start's sort to be " << *sort.d_type << " but found " << g->d_ntSyms[0].d_node->getType(); } - Type funType = varTypes.empty() - ? *sort.d_type - : d_exprMgr->mkFunctionType(varTypes, *sort.d_type); + TypeNode funType = varTypes.empty() ? *sort.d_type + : getNodeManager()->mkFunctionType( + varTypes, *sort.d_type); - Node fun = getNodeManager()->mkBoundVar(symbol, TypeNode::fromType(funType)); + Node fun = getNodeManager()->mkBoundVar(symbol, funType); (void)fun.getType(true); /* kick off type checking */ - std::vector<Node> bvns; - for (const Term& t : boundVars) - { - bvns.push_back(*t.d_node); - } + std::vector<Node> bvns = termVectorToNodes(boundVars); d_smtEngine->declareSynthFun( - symbol, - fun, - TypeNode::fromType(g == nullptr ? funType : *g->resolve().d_type), - isInv, - bvns); + symbol, fun, g == nullptr ? funType : *g->resolve().d_type, isInv, bvns); return Term(this, fun); @@ -5744,21 +5711,21 @@ void Solver::addSygusInvConstraint(Term inv, CVC4_API_ARG_CHECK_EXPECTED(inv.d_node->getType().isFunction(), inv) << "a function"; - FunctionType invType = inv.d_node->getType().toType(); + TypeNode invType = inv.d_node->getType(); CVC4_API_ARG_CHECK_EXPECTED(invType.getRangeType().isBoolean(), inv) << "boolean range"; - CVC4_API_CHECK(pre.d_node->getType().toType() == invType) + CVC4_API_CHECK(pre.d_node->getType() == invType) << "Expected inv and pre to have the same sort"; - CVC4_API_CHECK(post.d_node->getType().toType() == invType) + CVC4_API_CHECK(post.d_node->getType() == invType) << "Expected inv and post to have the same sort"; - const std::vector<Type>& invArgTypes = invType.getArgTypes(); + const std::vector<TypeNode>& invArgTypes = invType.getArgTypes(); - std::vector<Type> expectedTypes; - expectedTypes.reserve(2 * invType.getArity() + 1); + std::vector<TypeNode> expectedTypes; + expectedTypes.reserve(2 * invArgTypes.size() + 1); for (size_t i = 0, n = invArgTypes.size(); i < 2 * n; i += 2) { @@ -5767,15 +5734,13 @@ void Solver::addSygusInvConstraint(Term inv, } expectedTypes.push_back(invType.getRangeType()); - FunctionType expectedTransType = d_exprMgr->mkFunctionType(expectedTypes); + TypeNode expectedTransType = getNodeManager()->mkFunctionType(expectedTypes); - CVC4_API_CHECK(trans.d_node->toExpr().getType() == expectedTransType) + CVC4_API_CHECK(trans.d_node->getType() == expectedTransType) << "Expected trans's sort to be " << invType; - d_smtEngine->assertSygusInvConstraint(inv.d_node->toExpr(), - pre.d_node->toExpr(), - trans.d_node->toExpr(), - post.d_node->toExpr()); + d_smtEngine->assertSygusInvConstraint( + *inv.d_node, *pre.d_node, *trans.d_node, *post.d_node); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -5907,7 +5872,7 @@ std::vector<Type> sortVectorToTypes(const std::vector<Sort>& sorts) std::vector<Type> types; for (size_t i = 0, ssize = sorts.size(); i < ssize; i++) { - types.push_back(sorts[i].getType()); + types.push_back(sorts[i].getTypeNode().toType()); } return types; } @@ -5917,17 +5882,17 @@ std::vector<TypeNode> sortVectorToTypeNodes(const std::vector<Sort>& sorts) std::vector<TypeNode> typeNodes; for (const Sort& sort : sorts) { - typeNodes.push_back(TypeNode::fromType(sort.getType())); + typeNodes.push_back(sort.getTypeNode()); } return typeNodes; } -std::set<Type> sortSetToTypes(const std::set<Sort>& sorts) +std::set<TypeNode> sortSetToTypeNodes(const std::set<Sort>& sorts) { - std::set<Type> types; + std::set<TypeNode> types; for (const Sort& s : sorts) { - types.insert(s.getType()); + types.insert(s.getTypeNode()); } return types; } @@ -5949,6 +5914,16 @@ std::vector<Sort> typeVectorToSorts(const Solver* slv, std::vector<Sort> sorts; for (size_t i = 0, tsize = types.size(); i < tsize; i++) { + sorts.push_back(Sort(slv, TypeNode::fromType(types[i]))); + } + return sorts; +} +std::vector<Sort> typeNodeVectorToSorts(const Solver* slv, + const std::vector<TypeNode>& types) +{ + std::vector<Sort> sorts; + for (size_t i = 0, tsize = types.size(); i < tsize; i++) + { sorts.push_back(Sort(slv, types[i])); } return sorts; |