summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoryoni206 <yoni206@users.noreply.github.com>2021-06-01 14:24:43 -0700
committerGitHub <noreply@github.com>2021-06-01 21:24:43 +0000
commit172573dba45f7d231ec06a3a3992f41cf794b75e (patch)
treec2c5b893f0555c68f8051831d70efcf6b8e2928f
parent7eff8fb5145752b100a9d04c834973e794d9a860 (diff)
Some additions to the datatypes python API (#6640)
This commit makes the following additions, in order to sync the python API with the cpp API. 1. adding `getName` functions to datatypes related classes 2. allowing `mkDatatypeSorts` with 1 or 2 arguments (previously allowed only 2). 3. In case there is a second argument to `mkDatatypeSorts`, we make sure it is a set. 4. Corresponding changes to the tests.
-rw-r--r--src/api/python/cvc5.pxd2
-rw-r--r--src/api/python/cvc5.pxi23
-rw-r--r--test/python/unit/api/test_datatype_api.py10
-rw-r--r--test/python/unit/api/test_solver.py8
4 files changed, 33 insertions, 10 deletions
diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd
index 87a646666..fdcbfa997 100644
--- a/src/api/python/cvc5.pxd
+++ b/src/api/python/cvc5.pxd
@@ -49,6 +49,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
DatatypeConstructor getConstructor(const string& name) except +
Term getConstructorTerm(const string& name) except +
DatatypeSelector getSelector(const string& name) except +
+ string getName() except +
size_t getNumConstructors() except +
bint isParametric() except +
bint isCodatatype() except +
@@ -100,6 +101,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
size_t getNumConstructors() except +
bint isParametric() except +
string toString() except +
+ string getName() except +
cdef cppclass DatatypeSelector:
diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi
index cd9e91e51..25ded76bb 100644
--- a/src/api/python/cvc5.pxi
+++ b/src/api/python/cvc5.pxi
@@ -1,4 +1,4 @@
-from collections import defaultdict
+from collections import defaultdict, Set
from fractions import Fraction
import sys
@@ -6,7 +6,7 @@ from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t
from libc.stddef cimport wchar_t
from libcpp.pair cimport pair
-from libcpp.set cimport set
+from libcpp.set cimport set as c_set
from libcpp.string cimport string
from libcpp.vector cimport vector
@@ -123,6 +123,9 @@ cdef class Datatype:
ds.cds = self.cd.getSelector(name.encode())
return ds
+ def getName(self):
+ return self.cd.getName().decode()
+
def getNumConstructors(self):
""":return: number of constructors."""
return self.cd.getNumConstructors()
@@ -259,6 +262,9 @@ cdef class DatatypeDecl:
def isParametric(self):
return self.cdd.isParametric()
+ def getName(self):
+ return self.cdd.getName().decode()
+
def __str__(self):
return self.cdd.toString().decode()
@@ -502,19 +508,24 @@ cdef class Solver:
sort.csort = self.csolver.mkDatatypeSort(dtypedecl.cdd)
return sort
- def mkDatatypeSorts(self, list dtypedecls, unresolvedSorts):
- sorts = []
+ def mkDatatypeSorts(self, list dtypedecls, unresolvedSorts = None):
+ """:return: A list of datatype sorts that correspond to dtypedecls and unresolvedSorts"""
+ if unresolvedSorts == None:
+ unresolvedSorts = set([])
+ else:
+ assert isinstance(unresolvedSorts, Set)
+ sorts = []
cdef vector[c_DatatypeDecl] decls
for decl in dtypedecls:
decls.push_back((<DatatypeDecl?> decl).cdd)
- cdef set[c_Sort] usorts
+ cdef c_set[c_Sort] usorts
for usort in unresolvedSorts:
usorts.insert((<Sort?> usort).csort)
csorts = self.csolver.mkDatatypeSorts(
- <const vector[c_DatatypeDecl]&> decls, <const set[c_Sort]&> usorts)
+ <const vector[c_DatatypeDecl]&> decls, <const c_set[c_Sort]&> usorts)
for csort in csorts:
sort = Sort(self)
sort.csort = csort
diff --git a/test/python/unit/api/test_datatype_api.py b/test/python/unit/api/test_datatype_api.py
index f0c1c0ea9..708942e98 100644
--- a/test/python/unit/api/test_datatype_api.py
+++ b/test/python/unit/api/test_datatype_api.py
@@ -84,6 +84,7 @@ def test_mk_datatype_sorts(solver):
for i in range(0, len(dtdecls)):
assert dtsorts[i].isDatatype()
assert not dtsorts[i].getDatatype().isFinite()
+ assert dtsorts[i].getDatatype().getName() == dtdecls[i].getName()
# verify the resolution was correct
dtTree = dtsorts[0].getDatatype()
dtcTreeNode = dtTree[0]
@@ -98,6 +99,8 @@ def test_mk_datatype_sorts(solver):
dtdeclsBad = []
emptyD = solver.mkDatatypeDecl("emptyD")
dtdeclsBad.append(emptyD)
+ with pytest.raises(RuntimeError):
+ solver.mkDatatypeSorts(dtdeclsBad)
def test_datatype_structs(solver):
@@ -177,6 +180,8 @@ def test_datatype_names(solver):
# create datatype sort to test
dtypeSpec = solver.mkDatatypeDecl("list")
+ dtypeSpec.getName()
+ assert dtypeSpec.getName() == "list"
cons = solver.mkDatatypeConstructorDecl("cons")
cons.addSelector("head", intSort)
cons.addSelectorSelf("tail")
@@ -185,6 +190,7 @@ def test_datatype_names(solver):
dtypeSpec.addConstructor(nil)
dtypeSort = solver.mkDatatypeSort(dtypeSpec)
dt = dtypeSort.getDatatype()
+ assert dt.getName() == "list"
dt.getConstructor("nil")
dt["cons"]
with pytest.raises(RuntimeError):
@@ -209,6 +215,10 @@ def test_datatype_names(solver):
with pytest.raises(RuntimeError):
dt.getSelector("cons")
+ # possible to construct null datatype declarations if not using mkDatatypeDecl
+ with pytest.raises(RuntimeError):
+ DatatypeDecl(solver).getName()
+
def test_parametric_datatype(solver):
v = []
diff --git a/test/python/unit/api/test_solver.py b/test/python/unit/api/test_solver.py
index c7224022e..67174ad8e 100644
--- a/test/python/unit/api/test_solver.py
+++ b/test/python/unit/api/test_solver.py
@@ -143,19 +143,19 @@ def test_mk_datatype_sorts(solver):
dtypeSpec2.addConstructor(nil2)
decls = [dtypeSpec1, dtypeSpec2]
- solver.mkDatatypeSorts(decls, [])
+ solver.mkDatatypeSorts(decls, set([]))
with pytest.raises(RuntimeError):
- slv.mkDatatypeSorts(decls, [])
+ slv.mkDatatypeSorts(decls, set([]))
throwsDtypeSpec = solver.mkDatatypeDecl("list")
throwsDecls = [throwsDtypeSpec]
with pytest.raises(RuntimeError):
- solver.mkDatatypeSorts(throwsDecls, [])
+ solver.mkDatatypeSorts(throwsDecls, set([]))
# with unresolved sorts
unresList = solver.mkUninterpretedSort("ulist")
- unresSorts = [unresList]
+ unresSorts = set([unresList])
ulist = solver.mkDatatypeDecl("ulist")
ucons = solver.mkDatatypeConstructorDecl("ucons")
ucons.addSelector("car", unresList)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback