diff options
Diffstat (limited to 'src/expr')
-rw-r--r-- | src/expr/type_node.cpp | 86 |
1 files changed, 67 insertions, 19 deletions
diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index 0bf86240b..1ef5030ce 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -314,6 +314,15 @@ bool TypeNode::isSubtypeOf(TypeNode t) const { if(isSet() && t.isSet()) { return getSetElementType().isSubtypeOf(t.getSetElementType()); } + if (isFunction() && t.isFunction()) + { + if (!isComparableTo(t)) + { + // incomparable, not subtype + return false; + } + return getRangeType().isSubtypeOf(t.getRangeType()); + } // this should only return true for types T1, T2 where we handle equalities between T1 and T2 // (more cases go here, if we want to support such cases) return false; @@ -329,6 +338,11 @@ bool TypeNode::isComparableTo(TypeNode t) const { if(isSet() && t.isSet()) { return getSetElementType().isComparableTo(t.getSetElementType()); } + if (isFunction() && t.isFunction()) + { + // comparable if they have a common type node + return !leastCommonTypeNode(*this, t).isNull(); + } return false; } @@ -514,26 +528,60 @@ TypeNode TypeNode::commonTypeNode(TypeNode t0, TypeNode t1, bool isLeast) { // t0.getKind() == kind::TYPE_CONSTANT && // t1.getKind() == kind::TYPE_CONSTANT switch(t0.getKind()) { - case kind::BITVECTOR_TYPE: - case kind::FLOATINGPOINT_TYPE: - case kind::SORT_TYPE: - case kind::CONSTRUCTOR_TYPE: - case kind::SELECTOR_TYPE: - case kind::TESTER_TYPE: - case kind::FUNCTION_TYPE: - case kind::ARRAY_TYPE: - case kind::DATATYPE_TYPE: - case kind::PARAMETRIC_DATATYPE: - return TypeNode(); - case kind::SET_TYPE: { - // take the least common subtype of element types - TypeNode elementType; - if(t1.isSet() && !(elementType = commonTypeNode(t0[0], t1[0], isLeast)).isNull() ) { - return NodeManager::currentNM()->mkSetType(elementType); - } else { - return TypeNode(); + case kind::FUNCTION_TYPE: + { + if (t1.getKind() != kind::FUNCTION_TYPE) + { + return TypeNode(); + } + // must have equal arguments + std::vector<TypeNode> t0a = t0.getArgTypes(); + std::vector<TypeNode> t1a = t1.getArgTypes(); + if (t0a.size() != t1a.size()) + { + // different arities + return TypeNode(); + } + for (unsigned i = 0, nargs = t0a.size(); i < nargs; i++) + { + if (t0a[i] != t1a[i]) + { + // an argument is different + return TypeNode(); + } + } + TypeNode t0r = t0.getRangeType(); + TypeNode t1r = t1.getRangeType(); + TypeNode tr = commonTypeNode(t0r, t1r, isLeast); + std::vector<TypeNode> ftypes; + ftypes.insert(ftypes.end(), t0a.begin(), t0a.end()); + ftypes.push_back(tr); + return NodeManager::currentNM()->mkFunctionType(ftypes); + } + break; + case kind::BITVECTOR_TYPE: + case kind::FLOATINGPOINT_TYPE: + case kind::SORT_TYPE: + case kind::CONSTRUCTOR_TYPE: + case kind::SELECTOR_TYPE: + case kind::TESTER_TYPE: + case kind::ARRAY_TYPE: + case kind::DATATYPE_TYPE: + case kind::PARAMETRIC_DATATYPE: return TypeNode(); + case kind::SET_TYPE: + { + // take the least common subtype of element types + TypeNode elementType; + if (t1.isSet() + && !(elementType = commonTypeNode(t0[0], t1[0], isLeast)).isNull()) + { + return NodeManager::currentNM()->mkSetType(elementType); + } + else + { + return TypeNode(); + } } - } case kind::SEXPR_TYPE: Unimplemented() << "haven't implemented leastCommonType for symbolic expressions yet"; |