summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/api/python/cvc5.pxd1
-rw-r--r--src/api/python/cvc5.pxi5
-rw-r--r--test/python/unit/api/test_datatype_api.py234
3 files changed, 240 insertions, 0 deletions
diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd
index 2ad8cef5c..fdc1872e7 100644
--- a/src/api/python/cvc5.pxd
+++ b/src/api/python/cvc5.pxd
@@ -75,6 +75,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
DatatypeSelector operator[](const string& name) except +
string getName() except +
Term getConstructorTerm() except +
+ Term getSpecializedConstructorTerm(const Sort& retSort) except +
Term getTesterTerm() except +
size_t getNumSelectors() except +
DatatypeSelector getSelector(const string& name) except +
diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi
index 8599a1cd1..3339543f3 100644
--- a/src/api/python/cvc5.pxi
+++ b/src/api/python/cvc5.pxi
@@ -196,6 +196,11 @@ cdef class DatatypeConstructor:
term.cterm = self.cdc.getConstructorTerm()
return term
+ def getSpecializedConstructorTerm(self, Sort retSort):
+ cdef Term term = Term(self.solver)
+ term.cterm = self.cdc.getSpecializedConstructorTerm(retSort.csort)
+ return term
+
def getTesterTerm(self):
cdef Term term = Term(self.solver)
term.cterm = self.cdc.getTesterTerm()
diff --git a/test/python/unit/api/test_datatype_api.py b/test/python/unit/api/test_datatype_api.py
index 708942e98..24a47bd76 100644
--- a/test/python/unit/api/test_datatype_api.py
+++ b/test/python/unit/api/test_datatype_api.py
@@ -294,3 +294,237 @@ def test_parametric_datatype(solver):
assert not pairIntReal.isSubsortOf(pairIntInt)
assert not pairRealInt.isSubsortOf(pairIntInt)
assert pairIntInt.isSubsortOf(pairIntInt)
+
+
+def test_datatype_simply_rec(solver):
+ # Create mutual datatypes corresponding to this definition block:
+ #
+ # DATATYPE
+ # wlist = leaf(data: list),
+ # list = cons(car: wlist, cdr: list) | nil,
+ # ns = elem(ndata: set(wlist)) | elemArray(ndata2: array(list, list))
+ # END
+
+ # Make unresolved types as placeholders
+ unresTypes = set([])
+ unresWList = solver.mkUninterpretedSort("wlist")
+ unresList = solver.mkUninterpretedSort("list")
+ unresNs = solver.mkUninterpretedSort("ns")
+ unresTypes.add(unresWList)
+ unresTypes.add(unresList)
+ unresTypes.add(unresNs)
+
+ wlist = solver.mkDatatypeDecl("wlist")
+ leaf = solver.mkDatatypeConstructorDecl("leaf")
+ leaf.addSelector("data", unresList)
+ wlist.addConstructor(leaf)
+
+ llist = solver.mkDatatypeDecl("list")
+ cons = solver.mkDatatypeConstructorDecl("cons")
+ cons.addSelector("car", unresWList)
+ cons.addSelector("cdr", unresList)
+ llist.addConstructor(cons)
+ nil = solver.mkDatatypeConstructorDecl("nil")
+ llist.addConstructor(nil)
+
+ ns = solver.mkDatatypeDecl("ns")
+ elem = solver.mkDatatypeConstructorDecl("elem")
+ elem.addSelector("ndata", solver.mkSetSort(unresWList))
+ ns.addConstructor(elem)
+ elemArray = solver.mkDatatypeConstructorDecl("elemArray")
+ elemArray.addSelector("ndata", solver.mkArraySort(unresList, unresList))
+ ns.addConstructor(elemArray)
+
+ dtdecls = [wlist, llist, ns]
+ # this is well-founded and has no nested recursion
+ dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+ assert len(dtsorts) == 3
+ assert dtsorts[0].getDatatype().isWellFounded()
+ assert dtsorts[1].getDatatype().isWellFounded()
+ assert dtsorts[2].getDatatype().isWellFounded()
+ assert not dtsorts[0].getDatatype().hasNestedRecursion()
+ assert not dtsorts[1].getDatatype().hasNestedRecursion()
+ assert not dtsorts[2].getDatatype().hasNestedRecursion()
+
+ # Create mutual datatypes corresponding to this definition block:
+ # DATATYPE
+ # ns2 = elem2(ndata: array(int,ns2)) | nil2
+ # END
+
+ unresTypes.clear()
+ unresNs2 = solver.mkUninterpretedSort("ns2")
+ unresTypes.add(unresNs2)
+
+ ns2 = solver.mkDatatypeDecl("ns2")
+ elem2 = solver.mkDatatypeConstructorDecl("elem2")
+ elem2.addSelector("ndata",
+ solver.mkArraySort(solver.getIntegerSort(), unresNs2))
+ ns2.addConstructor(elem2)
+ nil2 = solver.mkDatatypeConstructorDecl("nil2")
+ ns2.addConstructor(nil2)
+
+ dtdecls.clear()
+ dtdecls.append(ns2)
+
+ # this is not well-founded due to non-simple recursion
+ dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+ assert len(dtsorts) == 1
+ assert dtsorts[0].getDatatype()[0][0].getRangeSort().isArray()
+ assert dtsorts[0].getDatatype()[0][0].getRangeSort().getArrayElementSort() \
+ == dtsorts[0]
+ assert dtsorts[0].getDatatype().isWellFounded()
+ assert dtsorts[0].getDatatype().hasNestedRecursion()
+
+ # Create mutual datatypes corresponding to this definition block:
+ # DATATYPE
+ # list3 = cons3(car: ns3, cdr: list3) | nil3,
+ # ns3 = elem3(ndata: set(list3))
+ # END
+
+ unresTypes.clear()
+ unresNs3 = solver.mkUninterpretedSort("ns3")
+ unresTypes.add(unresNs3)
+ unresList3 = solver.mkUninterpretedSort("list3")
+ unresTypes.add(unresList3)
+
+ list3 = solver.mkDatatypeDecl("list3")
+ cons3 = solver.mkDatatypeConstructorDecl("cons3")
+ cons3.addSelector("car", unresNs3)
+ cons3.addSelector("cdr", unresList3)
+ list3.addConstructor(cons3)
+ nil3 = solver.mkDatatypeConstructorDecl("nil3")
+ list3.addConstructor(nil3)
+
+ ns3 = solver.mkDatatypeDecl("ns3")
+ elem3 = solver.mkDatatypeConstructorDecl("elem3")
+ elem3.addSelector("ndata", solver.mkSetSort(unresList3))
+ ns3.addConstructor(elem3)
+
+ dtdecls.clear()
+ dtdecls.append(list3)
+ dtdecls.append(ns3)
+
+ # both are well-founded and have nested recursion
+ dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+ assert len(dtsorts) == 2
+ assert dtsorts[0].getDatatype().isWellFounded()
+ assert dtsorts[1].getDatatype().isWellFounded()
+ assert dtsorts[0].getDatatype().hasNestedRecursion()
+ assert dtsorts[1].getDatatype().hasNestedRecursion()
+
+ # Create mutual datatypes corresponding to this definition block:
+ # DATATYPE
+ # list4 = cons(car: set(ns4), cdr: list4) | nil,
+ # ns4 = elem(ndata: list4)
+ # END
+ unresTypes.clear()
+ unresNs4 = solver.mkUninterpretedSort("ns4")
+ unresTypes.add(unresNs4)
+ unresList4 = solver.mkUninterpretedSort("list4")
+ unresTypes.add(unresList4)
+
+ list4 = solver.mkDatatypeDecl("list4")
+ cons4 = solver.mkDatatypeConstructorDecl("cons4")
+ cons4.addSelector("car", solver.mkSetSort(unresNs4))
+ cons4.addSelector("cdr", unresList4)
+ list4.addConstructor(cons4)
+ nil4 = solver.mkDatatypeConstructorDecl("nil4")
+ list4.addConstructor(nil4)
+
+ ns4 = solver.mkDatatypeDecl("ns4")
+ elem4 = solver.mkDatatypeConstructorDecl("elem3")
+ elem4.addSelector("ndata", unresList4)
+ ns4.addConstructor(elem4)
+
+ dtdecls.clear()
+ dtdecls.append(list4)
+ dtdecls.append(ns4)
+
+ # both are well-founded and have nested recursion
+ dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+ assert len(dtsorts) == 2
+ assert dtsorts[0].getDatatype().isWellFounded()
+ assert dtsorts[1].getDatatype().isWellFounded()
+ assert dtsorts[0].getDatatype().hasNestedRecursion()
+ assert dtsorts[1].getDatatype().hasNestedRecursion()
+
+ # Create mutual datatypes corresponding to this definition block:
+ # DATATYPE
+ # list5[X] = cons(car: X, cdr: list5[list5[X]]) | nil
+ # END
+ unresTypes.clear()
+ unresList5 = solver.mkSortConstructorSort("list5", 1)
+ unresTypes.add(unresList5)
+
+ v = []
+ x = solver.mkParamSort("X")
+ v.append(x)
+ list5 = solver.mkDatatypeDecl("list5", v)
+
+ args = [x]
+ urListX = unresList5.instantiate(args)
+ args[0] = urListX
+ urListListX = unresList5.instantiate(args)
+
+ cons5 = solver.mkDatatypeConstructorDecl("cons5")
+ cons5.addSelector("car", x)
+ cons5.addSelector("cdr", urListListX)
+ list5.addConstructor(cons5)
+ nil5 = solver.mkDatatypeConstructorDecl("nil5")
+ list5.addConstructor(nil5)
+
+ dtdecls.clear()
+ dtdecls.append(list5)
+
+ # well-founded and has nested recursion
+ dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+ assert len(dtsorts) == 1
+ assert dtsorts[0].getDatatype().isWellFounded()
+ assert dtsorts[0].getDatatype().hasNestedRecursion()
+
+
+def test_datatype_specialized_cons(solver):
+ # Create mutual datatypes corresponding to this definition block:
+ # DATATYPE
+ # plist[X] = pcons(car: X, cdr: plist[X]) | pnil
+ # END
+
+ # Make unresolved types as placeholders
+ unresTypes = set([])
+ unresList = solver.mkSortConstructorSort("plist", 1)
+ unresTypes.add(unresList)
+
+ v = []
+ x = solver.mkParamSort("X")
+ v.append(x)
+ plist = solver.mkDatatypeDecl("plist", v)
+
+ args = [x]
+ urListX = unresList.instantiate(args)
+
+ pcons = solver.mkDatatypeConstructorDecl("pcons")
+ pcons.addSelector("car", x)
+ pcons.addSelector("cdr", urListX)
+ plist.addConstructor(pcons)
+ nil5 = solver.mkDatatypeConstructorDecl("pnil")
+ plist.addConstructor(nil5)
+
+ dtdecls = [plist]
+
+ # make the datatype sorts
+ dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+ assert len(dtsorts) == 1
+ d = dtsorts[0].getDatatype()
+ nilc = d[0]
+
+ isort = solver.getIntegerSort()
+ iargs = [isort]
+ listInt = dtsorts[0].instantiate(iargs)
+
+ testConsTerm = Term(solver)
+ # get the specialized constructor term for list[Int]
+ testConsTerm = nilc.getSpecializedConstructorTerm(listInt)
+ assert testConsTerm != nilc.getConstructorTerm()
+ # error to get the specialized constructor term for Int
+ with pytest.raises(RuntimeError):
+ nilc.getSpecializedConstructorTerm(isort)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback