summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoryoni206 <yoni206@users.noreply.github.com>2021-06-02 12:30:46 -0700
committerGitHub <noreply@github.com>2021-06-02 14:30:46 -0500
commit85a300898d7815973c064fe2c7b5b33473a71a5c (patch)
tree28d47a8e75881fae159197374b050de359ba5f6f
parentdde15bdbf752246fe7cb504df22261e0ad3835db (diff)
Adding getters to the python API and testing them (#6652)
This PR adds missing API functions from the cpp Term API to the python API. Corresponding tests are translated from term_black.cpp.
-rw-r--r--src/api/python/cvc5.pxd13
-rw-r--r--src/api/python/cvc5.pxi84
-rw-r--r--test/python/unit/api/test_term.py108
3 files changed, 191 insertions, 14 deletions
diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd
index fdcbfa997..2ad8cef5c 100644
--- a/src/api/python/cvc5.pxd
+++ b/src/api/python/cvc5.pxd
@@ -176,6 +176,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
Sort mkTupleSort(const vector[Sort]& sorts) except +
Term mkTerm(Op op) except +
Term mkTerm(Op op, const vector[Term]& children) except +
+ Term mkTuple(const vector[Sort]& sorts, const vector[Term]& terms) except +
Op mkOp(Kind kind) except +
Op mkOp(Kind kind, Kind k) except +
Op mkOp(Kind kind, const string& arg) except +
@@ -388,6 +389,8 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
Term operator*() except +
const_iterator begin() except +
const_iterator end() except +
+
+ bint isConstArray() except +
bint isBooleanValue() except +
bint getBooleanValue() except +
bint isStringValue() except +
@@ -398,6 +401,8 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
string getRealValue() except +
bint isBitVectorValue() except +
string getBitVectorValue(uint32_t base) except +
+ bint isAbstractValue() except +
+ string getAbstractValue() except +
bint isFloatingPointPosZero() except +
bint isFloatingPointNegZero() except +
bint isFloatingPointPosInf() except +
@@ -406,7 +411,15 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
bint isFloatingPointValue() except +
tuple[uint32_t, uint32_t, Term] getFloatingPointValue() except +
+ bint isSetValue() except +
+ set[Term] getSetValue() except +
+ bint isSequenceValue() except +
vector[Term] getSequenceValue() except +
+ bint isUninterpretedValue() except +
+ pair[Sort, int32_t] getUninterpretedValue() except +
+ bint isTupleValue() except +
+ vector[Term] getTupleValue() except +
+
cdef cppclass TermHashFunction:
TermHashFunction() except +
diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi
index 25ded76bb..8599a1cd1 100644
--- a/src/api/python/cvc5.pxi
+++ b/src/api/python/cvc5.pxi
@@ -658,6 +658,19 @@ cdef class Solver:
term.cterm = self.csolver.mkTerm((<Op?> op).cop, v)
return term
+ def mkTuple(self, sorts, terms):
+ cdef vector[c_Sort] csorts
+ cdef vector[c_Term] cterms
+
+ for s in sorts:
+ csorts.push_back((<Sort?> s).csort)
+ for s in terms:
+ cterms.push_back((<Term?> s).cterm)
+ cdef Term result = Term(self)
+ result.cterm = self.csolver.mkTuple(csorts, cterms)
+ return result
+
+
def mkOp(self, kind k, arg0=None, arg1 = None):
'''
Supports the following uses:
@@ -1609,19 +1622,6 @@ cdef class Term:
def isNull(self):
return self.cterm.isNull()
- def getConstArrayBase(self):
- cdef Term term = Term(self.solver)
- term.cterm = self.cterm.getConstArrayBase()
- return term
-
- def getSequenceValue(self):
- elems = []
- for e in self.cterm.getSequenceValue():
- term = Term(self.solver)
- term.cterm = e
- elems.append(term)
- return elems
-
def notTerm(self):
cdef Term term = Term(self.solver)
term.cterm = self.cterm.notTerm()
@@ -1657,6 +1657,14 @@ cdef class Term:
term.cterm = self.cterm.iteTerm(then_t.cterm, else_t.cterm)
return term
+ def isConstArray(self):
+ return self.cterm.isConstArray()
+
+ def getConstArrayBase(self):
+ cdef Term term = Term(self.solver)
+ term.cterm = self.cterm.getConstArrayBase()
+ return term
+
def isBooleanValue(self):
return self.cterm.isBooleanValue()
@@ -1673,7 +1681,12 @@ cdef class Term:
def isIntegerValue(self):
return self.cterm.isIntegerValue()
-
+ def isAbstractValue(self):
+ return self.cterm.isAbstractValue()
+
+ def getAbstractValue(self):
+ return self.cterm.getAbstractValue().decode()
+
def isFloatingPointPosZero(self):
return self.cterm.isFloatingPointPosZero()
@@ -1698,6 +1711,49 @@ cdef class Term:
term.cterm = get2(t)
return (get0(t), get1(t), term)
+ def isSetValue(self):
+ return self.cterm.isSetValue()
+
+ def getSetValue(self):
+ elems = set()
+ for e in self.cterm.getSetValue():
+ term = Term(self.solver)
+ term.cterm = e
+ elems.add(term)
+ return elems
+
+ def isSequenceValue(self):
+ return self.cterm.isSequenceValue()
+
+ def getSequenceValue(self):
+ elems = []
+ for e in self.cterm.getSequenceValue():
+ term = Term(self.solver)
+ term.cterm = e
+ elems.append(term)
+ return elems
+
+ def isUninterpretedValue(self):
+ return self.cterm.isUninterpretedValue()
+
+ def getUninterpretedValue(self):
+ cdef pair[c_Sort, int32_t] p = self.cterm.getUninterpretedValue()
+ cdef Sort sort = Sort(self.solver)
+ sort.csort = p.first
+ i = p.second
+ return (sort, i)
+
+ def isTupleValue(self):
+ return self.cterm.isTupleValue()
+
+ def getTupleValue(self):
+ elems = []
+ for e in self.cterm.getTupleValue():
+ term = Term(self.solver)
+ term.cterm = e
+ elems.append(term)
+ return elems
+
def getIntegerValue(self):
return int(self.cterm.getIntegerValue().decode())
diff --git a/test/python/unit/api/test_term.py b/test/python/unit/api/test_term.py
index 5603655c6..32813e17f 100644
--- a/test/python/unit/api/test_term.py
+++ b/test/python/unit/api/test_term.py
@@ -930,6 +930,114 @@ def test_term_children(solver):
tnull[0]
+def test_get_const_array_base(solver):
+ intsort = solver.getIntegerSort()
+ arrsort = solver.mkArraySort(intsort, intsort)
+ one = solver.mkInteger(1)
+ constarr = solver.mkConstArray(arrsort, one)
+
+ assert constarr.isConstArray()
+ assert one == constarr.getConstArrayBase()
+
+
+def test_get_abstract_value(solver):
+ v1 = solver.mkAbstractValue(1)
+ v2 = solver.mkAbstractValue("15")
+ v3 = solver.mkAbstractValue("18446744073709551617")
+
+ assert v1.isAbstractValue()
+ assert v2.isAbstractValue()
+ assert v3.isAbstractValue()
+ assert "1" == v1.getAbstractValue()
+ assert "15" == v2.getAbstractValue()
+ assert "18446744073709551617" == v3.getAbstractValue()
+
+
+def test_get_tuple(solver):
+ s1 = solver.getIntegerSort()
+ s2 = solver.getRealSort()
+ s3 = solver.getStringSort()
+
+ t1 = solver.mkInteger(15)
+ t2 = solver.mkReal(17, 25)
+ t3 = solver.mkString("abc")
+
+ tup = solver.mkTuple([s1, s2, s3], [t1, t2, t3])
+
+ assert tup.isTupleValue()
+ assert [t1, t2, t3] == tup.getTupleValue()
+
+
+def test_get_set(solver):
+ s = solver.mkSetSort(solver.getIntegerSort())
+
+ i1 = solver.mkInteger(5)
+ i2 = solver.mkInteger(7)
+
+ s1 = solver.mkEmptySet(s)
+ s2 = solver.mkTerm(kinds.Singleton, i1)
+ s3 = solver.mkTerm(kinds.Singleton, i1)
+ s4 = solver.mkTerm(kinds.Singleton, i2)
+ s5 = solver.mkTerm(kinds.Union, s2, solver.mkTerm(kinds.Union, s3, s4))
+
+ assert s1.isSetValue()
+ assert s2.isSetValue()
+ assert s3.isSetValue()
+ assert s4.isSetValue()
+ assert not s5.isSetValue()
+ s5 = solver.simplify(s5)
+ assert s5.isSetValue()
+
+ assert set([]) == s1.getSetValue()
+ assert set([i1]) == s2.getSetValue()
+ assert set([i1]) == s3.getSetValue()
+ assert set([i2]) == s4.getSetValue()
+ assert set([i1, i2]) == s5.getSetValue()
+
+
+def test_get_sequence(solver):
+ s = solver.mkSequenceSort(solver.getIntegerSort())
+
+ i1 = solver.mkInteger(5)
+ i2 = solver.mkInteger(7)
+
+ s1 = solver.mkEmptySequence(s)
+ s2 = solver.mkTerm(kinds.SeqUnit, i1)
+ s3 = solver.mkTerm(kinds.SeqUnit, i1)
+ s4 = solver.mkTerm(kinds.SeqUnit, i2)
+ s5 = solver.mkTerm(kinds.SeqConcat, s2,
+ solver.mkTerm(kinds.SeqConcat, s3, s4))
+
+ assert s1.isSequenceValue()
+ assert not s2.isSequenceValue()
+ assert not s3.isSequenceValue()
+ assert not s4.isSequenceValue()
+ assert not s5.isSequenceValue()
+
+ s2 = solver.simplify(s2)
+ s3 = solver.simplify(s3)
+ s4 = solver.simplify(s4)
+ s5 = solver.simplify(s5)
+
+ assert [] == s1.getSequenceValue()
+ assert [i1] == s2.getSequenceValue()
+ assert [i1] == s3.getSequenceValue()
+ assert [i2] == s4.getSequenceValue()
+ assert [i1, i1, i2] == s5.getSequenceValue()
+
+
+def test_get_uninterpreted_const(solver):
+ s = solver.mkUninterpretedSort("test")
+ t1 = solver.mkUninterpretedConst(s, 3)
+ t2 = solver.mkUninterpretedConst(s, 5)
+
+ assert t1.isUninterpretedValue()
+ assert t2.isUninterpretedValue()
+
+ assert (s, 3) == t1.getUninterpretedValue()
+ assert (s, 5) == t2.getUninterpretedValue()
+
+
def test_get_floating_point(solver):
bvval = solver.mkBitVector("0000110000000011")
fp = solver.mkFloatingPoint(5, 11, bvval)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback