summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoryoni206 <yoni206@users.noreply.github.com>2021-05-30 17:56:14 -0700
committerGitHub <noreply@github.com>2021-05-31 00:56:14 +0000
commit6edc06b3fb6367e8366cab13340228e2bebfca1e (patch)
tree4b71a6f70f0efb963668ee3dc74c9f4714f84697
parent21511862f74c74a9c75da1de01e6b0e0a8120613 (diff)
Update `toPythonObj` to use new getters -- part 1 (#6623)
Following #6496 , this PR adds new getters to the python API, as well as tests for them. This makes toPythonObj simpler. A future PR will add more getters to the python API. Co-authored-by: Gereon Kremer nafur42@gmail.com
-rw-r--r--src/api/python/cvc5.pxd12
-rw-r--r--src/api/python/cvc5.pxi167
-rw-r--r--test/api/python/test_to_python_obj.py2
-rw-r--r--test/python/unit/api/test_term.py242
4 files changed, 272 insertions, 151 deletions
diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd
index 205b82918..ef65c9070 100644
--- a/src/api/python/cvc5.pxd
+++ b/src/api/python/cvc5.pxd
@@ -22,7 +22,10 @@ cdef extern from "<functional>" namespace "std" nogil:
cdef extern from "<string>" namespace "std":
cdef cppclass wstring:
+ wstring() except +
wstring(const wchar_t*, size_t) except +
+ const wchar_t* data() except +
+ size_t size() except +
cdef extern from "api/cpp/cvc5.h" namespace "cvc5":
cdef cppclass Options:
@@ -374,7 +377,16 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
Term operator*() except +
const_iterator begin() except +
const_iterator end() except +
+ bint isBooleanValue() except +
+ bint getBooleanValue() except +
+ bint isStringValue() except +
+ wstring getStringValue() except +
bint isIntegerValue() except +
+ string getIntegerValue() except +
+ bint isRealValue() except +
+ string getRealValue() except +
+ bint isBitVectorValue() except +
+ string getBitVectorValue(uint32_t base) except +
vector[Term] getSequenceValue() except +
cdef cppclass TermHashFunction:
diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi
index b2942e0b3..7731f4e71 100644
--- a/src/api/python/cvc5.pxi
+++ b/src/api/python/cvc5.pxi
@@ -33,6 +33,7 @@ from cvc5kinds cimport Kind as c_Kind
cdef extern from "Python.h":
wchar_t* PyUnicode_AsWideCharString(object, Py_ssize_t *)
+ object PyUnicode_FromWideChar(const wchar_t*, Py_ssize_t)
void PyMem_Free(void*)
################################## DECORATORS #################################
@@ -759,6 +760,47 @@ cdef class Solver:
term.cterm = self.csolver.mkUniverseSet(sort.csort)
return term
+ @expand_list_arg(num_req_args=0)
+ def mkBitVector(self, *args):
+ '''
+ Supports the following arguments:
+ Term mkBitVector(int size, int val=0)
+ Term mkBitVector(string val, int base = 2)
+ Term mkBitVector(int size, string val, int base)
+ '''
+ cdef Term term = Term(self)
+ if len(args) == 1:
+ size_or_val = args[0]
+ if isinstance(args[0], int):
+ size = args[0]
+ term.cterm = self.csolver.mkBitVector(<uint32_t> size)
+ else:
+ assert isinstance(args[0], str)
+ val = args[0]
+ term.cterm = self.csolver.mkBitVector(<const string&> str(val).encode())
+ elif len(args) == 2:
+ if isinstance(args[0], int):
+ size = args[0]
+ assert isinstance(args[1], int)
+ val = args[1]
+ term.cterm = self.csolver.mkBitVector(<uint32_t> size, <uint32_t> val)
+ else:
+ assert isinstance(args[0], str)
+ assert isinstance(args[1], int)
+ val = args[0]
+ base = args[1]
+ term.cterm = self.csolver.mkBitVector(<const string&> str(val).encode(), <uint32_t> base)
+ elif len(args) == 3:
+ assert isinstance(args[0], int)
+ assert isinstance(args[1], str)
+ assert isinstance(args[2], int)
+ size = args[0]
+ val = args[1]
+ base = args[2]
+ term.cterm = self.csolver.mkBitVector(<uint32_t> size, <const string&> str(val).encode(), <uint32_t> base)
+ return term
+
+
def mkBitVector(self, size_or_str, val = None):
cdef Term term = Term(self)
if isinstance(size_or_str, int):
@@ -1603,9 +1645,38 @@ cdef class Term:
term.cterm = self.cterm.iteTerm(then_t.cterm, else_t.cterm)
return term
- def isInteger(self):
+ def isBooleanValue(self):
+ return self.cterm.isBooleanValue()
+
+ def getBooleanValue(self):
+ return self.cterm.getBooleanValue()
+
+ def isStringValue(self):
+ return self.cterm.isStringValue()
+
+ def getStringValue(self):
+ cdef Py_ssize_t size
+ cdef c_wstring s = self.cterm.getStringValue()
+ return PyUnicode_FromWideChar(s.data(), s.size())
+
+ def isIntegerValue(self):
return self.cterm.isIntegerValue()
-
+
+ def getIntegerValue(self):
+ return int(self.cterm.getIntegerValue().decode())
+
+ def isRealValue(self):
+ return self.cterm.isRealValue()
+
+ def getRealValue(self):
+ return float(Fraction(self.cterm.getRealValue().decode()))
+
+ def isBitVectorValue(self):
+ return self.cterm.isBitVectorValue()
+
+ def getBitVectorValue(self, base = 2):
+ return self.cterm.getBitVectorValue(base).decode()
+
def toPythonObj(self):
'''
Converts a constant value Term to a Python object.
@@ -1615,61 +1686,23 @@ cdef class Term:
Int -- returns a Python int
Real -- returns a Python Fraction
BV -- returns a Python int (treats BV as unsigned)
+ String -- returns a Python Unicode string
Array -- returns a Python dict mapping indices to values
-- the constant base is returned as the default value
- String -- returns a Python Unicode string
'''
- string_repr = self.cterm.toString().decode()
- assert string_repr
- sort = self.getSort()
- res = None
- if sort.isBoolean():
- if string_repr == "true":
- res = True
- else:
- assert string_repr == "false"
- res = False
-
- elif sort.isInteger():
- updated_string_repr = string_repr.strip('()').replace(' ', '')
- try:
- res = int(updated_string_repr)
- except:
- raise ValueError("Failed to convert"
- " {} to an int".format(string_repr))
-
- elif sort.isReal():
- updated_string_repr = string_repr
- try:
- # rational format (/ a b) most likely
- # note: a or b could be negated: (- a)
- splits = [s.strip('()/')
- for s in updated_string_repr.strip('()/') \
- .replace('(- ', '(-').split()]
- assert len(splits) == 2
- num = int(splits[0])
- den = int(splits[1])
- res = Fraction(num, den)
- except:
- try:
- # could be exact: e.g., 1.0
- res = Fraction(updated_string_repr)
- except:
- raise ValueError("Failed to convert "
- "{} to a Fraction".format(string_repr))
-
- elif sort.isBitVector():
- # expecting format #b<bits>
- assert string_repr[:2] == "#b"
- python_bin_repr = "0" + string_repr[1:]
- try:
- res = int(python_bin_repr, 2)
- except:
- raise ValueError("Failed to convert bitvector "
- "{} to an int".format(string_repr))
-
- elif sort.isArray():
+ if self.isBooleanValue():
+ return self.getBooleanValue()
+ elif self.isIntegerValue():
+ return self.getIntegerValue()
+ elif self.isRealValue():
+ return self.getRealValue()
+ elif self.isBitVectorValue():
+ return int(self.getBitVectorValue(), 2)
+ elif self.isStringValue():
+ return self.getStringValue()
+ elif self.getSort().isArray():
+ res = None
keys = []
values = []
base_value = None
@@ -1696,33 +1729,7 @@ cdef class Term:
for k, v in zip(keys, values):
res[k] = v
- elif sort.isString():
- # Strip leading and trailing double quotes and replace double
- # double quotes by single quotes
- string_repr = string_repr[1:-1].replace('""', '"')
-
- # Convert escape sequences
- res = ''
- escape_prefix = '\\u{'
- i = 0
- while True:
- prev_i = i
- i = string_repr.find(escape_prefix, i)
- if i == -1:
- res += string_repr[prev_i:]
- break
-
- res += string_repr[prev_i:i]
- val = string_repr[i + len(escape_prefix):string_repr.find('}', i)]
- res += chr(int(val, 16))
- i += len(escape_prefix) + len(val) + 1
- else:
- raise ValueError("Cannot convert term {}"
- " of sort {} to Python object".format(string_repr,
- sort))
-
- assert res is not None
- return res
+ return res
# Generate rounding modes
diff --git a/test/api/python/test_to_python_obj.py b/test/api/python/test_to_python_obj.py
index 572453670..2ba685d50 100644
--- a/test/api/python/test_to_python_obj.py
+++ b/test/api/python/test_to_python_obj.py
@@ -115,4 +115,4 @@ def testGetValueReal():
xval = solver.getValue(x)
yval = solver.getValue(y)
assert xval.toPythonObj() == Fraction("6")
- assert yval.toPythonObj() == Fraction("8.33")
+ assert yval.toPythonObj() == float(Fraction("8.33"))
diff --git a/test/python/unit/api/test_term.py b/test/python/unit/api/test_term.py
index 936ff3e1c..2b6fd8fd6 100644
--- a/test/python/unit/api/test_term.py
+++ b/test/python/unit/api/test_term.py
@@ -898,80 +898,182 @@ def test_substitute(solver):
def test_term_compare(solver):
- t1 = solver.mkInteger(1)
- t2 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2))
- t3 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2))
- assert t2 >= t3
- assert t2 <= t3
- assert (t1 > t2) != (t1 < t2)
- assert (t1 > t2 or t1 == t2) == (t1 >= t2)
+ t1 = solver.mkInteger(1)
+ t2 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2))
+ t3 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2))
+ assert t2 >= t3
+ assert t2 <= t3
+ assert (t1 > t2) != (t1 < t2)
+ assert (t1 > t2 or t1 == t2) == (t1 >= t2)
+
def test_term_children(solver):
- # simple term 2+3
- two = solver.mkInteger(2)
- t1 = solver.mkTerm(kinds.Plus, two, solver.mkInteger(3))
- assert t1[0] == two
- assert t1.getNumChildren() == 2
- tnull = Term(solver)
- with pytest.raises(RuntimeError):
- tnull.getNumChildren()
-
- # apply term f(2)
- intSort = solver.getIntegerSort()
- fsort = solver.mkFunctionSort(intSort, intSort)
- f = solver.mkConst(fsort, "f")
- t2 = solver.mkTerm(kinds.ApplyUf, f, two)
- # due to our higher-order view of terms, we treat f as a child of kinds.ApplyUf
- assert t2.getNumChildren() == 2
- assert t2[0] == f
- assert t2[1] == two
- with pytest.raises(RuntimeError):
- tnull[0]
+ # simple term 2+3
+ two = solver.mkInteger(2)
+ t1 = solver.mkTerm(kinds.Plus, two, solver.mkInteger(3))
+ assert t1[0] == two
+ assert t1.getNumChildren() == 2
+ tnull = Term(solver)
+ with pytest.raises(RuntimeError):
+ tnull.getNumChildren()
+
+ # apply term f(2)
+ intSort = solver.getIntegerSort()
+ fsort = solver.mkFunctionSort(intSort, intSort)
+ f = solver.mkConst(fsort, "f")
+ t2 = solver.mkTerm(kinds.ApplyUf, f, two)
+ # due to our higher-order view of terms, we treat f as a child of kinds.ApplyUf
+ assert t2.getNumChildren() == 2
+ assert t2[0] == f
+ assert t2[1] == two
+ with pytest.raises(RuntimeError):
+ tnull[0]
+
def test_is_integer(solver):
- int1 = solver.mkInteger("-18446744073709551616")
- int2 = solver.mkInteger("-18446744073709551615")
- int3 = solver.mkInteger("-4294967296")
- int4 = solver.mkInteger("-4294967295")
- int5 = solver.mkInteger("-10")
- int6 = solver.mkInteger("0")
- int7 = solver.mkInteger("10")
- int8 = solver.mkInteger("4294967295")
- int9 = solver.mkInteger("4294967296")
- int10 = solver.mkInteger("18446744073709551615")
- int11 = solver.mkInteger("18446744073709551616")
- int12 = solver.mkInteger("-0")
-
- with pytest.raises(RuntimeError):
- solver.mkInteger("")
- with pytest.raises(RuntimeError):
- solver.mkInteger("-")
- with pytest.raises(RuntimeError):
- solver.mkInteger("-1-")
- with pytest.raises(RuntimeError):
- solver.mkInteger("0.0")
- with pytest.raises(RuntimeError):
- solver.mkInteger("-0.1")
- with pytest.raises(RuntimeError):
- solver.mkInteger("012")
- with pytest.raises(RuntimeError):
- solver.mkInteger("0000")
- with pytest.raises(RuntimeError):
- solver.mkInteger("-01")
- with pytest.raises(RuntimeError):
- solver.mkInteger("-00")
-
- assert int1.isInteger()
- assert int2.isInteger()
- assert int3.isInteger()
- assert int4.isInteger()
- assert int5.isInteger()
- assert int6.isInteger()
- assert int7.isInteger()
- assert int8.isInteger()
- assert int9.isInteger()
- assert int10.isInteger()
- assert int11.isInteger()
+ int1 = solver.mkInteger("-18446744073709551616")
+ int2 = solver.mkInteger("-18446744073709551615")
+ int3 = solver.mkInteger("-4294967296")
+ int4 = solver.mkInteger("-4294967295")
+ int5 = solver.mkInteger("-10")
+ int6 = solver.mkInteger("0")
+ int7 = solver.mkInteger("10")
+ int8 = solver.mkInteger("4294967295")
+ int9 = solver.mkInteger("4294967296")
+ int10 = solver.mkInteger("18446744073709551615")
+ int11 = solver.mkInteger("18446744073709551616")
+ int12 = solver.mkInteger("-0")
+
+ with pytest.raises(RuntimeError):
+ solver.mkInteger("")
+ with pytest.raises(RuntimeError):
+ solver.mkInteger("-")
+ with pytest.raises(RuntimeError):
+ solver.mkInteger("-1-")
+ with pytest.raises(RuntimeError):
+ solver.mkInteger("0.0")
+ with pytest.raises(RuntimeError):
+ solver.mkInteger("-0.1")
+ with pytest.raises(RuntimeError):
+ solver.mkInteger("012")
+ with pytest.raises(RuntimeError):
+ solver.mkInteger("0000")
+ with pytest.raises(RuntimeError):
+ solver.mkInteger("-01")
+ with pytest.raises(RuntimeError):
+ solver.mkInteger("-00")
+
+ assert int1.isIntegerValue()
+ assert int2.isIntegerValue()
+ assert int3.isIntegerValue()
+ assert int4.isIntegerValue()
+ assert int5.isIntegerValue()
+ assert int6.isIntegerValue()
+ assert int7.isIntegerValue()
+ assert int8.isIntegerValue()
+ assert int9.isIntegerValue()
+ assert int10.isIntegerValue()
+ assert int11.isIntegerValue()
+
+ assert int1.getIntegerValue() == -18446744073709551616
+ assert int2.getIntegerValue() == -18446744073709551615
+ assert int3.getIntegerValue() == -4294967296
+ assert int4.getIntegerValue() == -4294967295
+ assert int5.getIntegerValue() == -10
+ assert int6.getIntegerValue() == 0
+ assert int7.getIntegerValue() == 10
+ assert int8.getIntegerValue() == 4294967295
+ assert int9.getIntegerValue() == 4294967296
+ assert int10.getIntegerValue() == 18446744073709551615
+ assert int11.getIntegerValue() == 18446744073709551616
+
+
+def test_get_string(solver):
+ s1 = solver.mkString("abcde")
+ assert s1.isStringValue()
+ assert s1.getStringValue() == str("abcde")
+
+
+def test_get_real(solver):
+ real1 = solver.mkReal("0")
+ real2 = solver.mkReal(".0")
+ real3 = solver.mkReal("-17")
+ real4 = solver.mkReal("-3/5")
+ real5 = solver.mkReal("12.7")
+ real6 = solver.mkReal("1/4294967297")
+ real7 = solver.mkReal("4294967297")
+ real8 = solver.mkReal("1/18446744073709551617")
+ real9 = solver.mkReal("18446744073709551617")
+
+ assert real1.isRealValue()
+ assert real2.isRealValue()
+ assert real3.isRealValue()
+ assert real4.isRealValue()
+ assert real5.isRealValue()
+ assert real6.isRealValue()
+ assert real7.isRealValue()
+ assert real8.isRealValue()
+ assert real9.isRealValue()
+
+ assert 0 == real1.getRealValue()
+ assert 0 == real2.getRealValue()
+ assert -17 == real3.getRealValue()
+ assert -3/5 == real4.getRealValue()
+ assert 127/10 == real5.getRealValue()
+ assert 1/4294967297 == real6.getRealValue()
+ assert 4294967297 == real7.getRealValue()
+ assert 1/18446744073709551617 == real8.getRealValue()
+ assert float(18446744073709551617) == real9.getRealValue()
+
+
+def test_get_boolean(solver):
+ b1 = solver.mkBoolean(True)
+ b2 = solver.mkBoolean(False)
+
+ assert b1.isBooleanValue()
+ assert b2.isBooleanValue()
+ assert b1.getBooleanValue()
+ assert not b2.getBooleanValue()
+
+
+def test_get_bit_vector(solver):
+ b1 = solver.mkBitVector(8, 15)
+ b2 = solver.mkBitVector("00001111", 2)
+ b3 = solver.mkBitVector("15", 10)
+ b4 = solver.mkBitVector("0f", 16)
+ b5 = solver.mkBitVector(8, "00001111", 2)
+ b6 = solver.mkBitVector(8, "15", 10)
+ b7 = solver.mkBitVector(8, "0f", 16)
+
+ assert b1.isBitVectorValue()
+ assert b2.isBitVectorValue()
+ assert b3.isBitVectorValue()
+ assert b4.isBitVectorValue()
+ assert b5.isBitVectorValue()
+ assert b6.isBitVectorValue()
+ assert b7.isBitVectorValue()
+
+ assert "00001111" == b1.getBitVectorValue(2)
+ assert "15" == b1.getBitVectorValue(10)
+ assert "f" == b1.getBitVectorValue(16)
+ assert "00001111" == b2.getBitVectorValue(2)
+ assert "15" == b2.getBitVectorValue(10)
+ assert "f" == b2.getBitVectorValue(16)
+ assert "1111" == b3.getBitVectorValue(2)
+ assert "15" == b3.getBitVectorValue(10)
+ assert "f" == b3.getBitVectorValue(16)
+ assert "00001111" == b4.getBitVectorValue(2)
+ assert "15" == b4.getBitVectorValue(10)
+ assert "f" == b4.getBitVectorValue(16)
+ assert "00001111" == b5.getBitVectorValue(2)
+ assert "15" == b5.getBitVectorValue(10)
+ assert "f" == b5.getBitVectorValue(16)
+ assert "00001111" == b6.getBitVectorValue(2)
+ assert "15" == b6.getBitVectorValue(10)
+ assert "f" == b6.getBitVectorValue(16)
+ assert "00001111" == b7.getBitVectorValue(2)
+ assert "15" == b7.getBitVectorValue(10)
+ assert "f" == b7.getBitVectorValue(16)
def test_const_array(solver):
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback