summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Sotoudeh <matthew@masot.net>2023-07-14 07:25:31 -0700
committerMatthew Sotoudeh <matthew@masot.net>2023-07-14 07:25:31 -0700
commit0b12ba0ca00f7cdfb50b614fb24b673fb7e4e322 (patch)
treec8dc0b54c6f9df2f10185fc93f4276f1a8535802
initial code
-rw-r--r--.gitignore3
-rw-r--r--README49
-rw-r--r--example/README5
-rw-r--r--example/program.py130
-rw-r--r--example/symabs.py62
-rw-r--r--example/symex.py23
-rw-r--r--example/test_files/cycle17
-rw-r--r--example/test_files/fib15
-rw-r--r--example/test_files/hasheq31
-rw-r--r--satispi.py449
10 files changed, 784 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2791acf
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+__pycache__
+*.log
+pcache
diff --git a/README b/README
new file mode 100644
index 0000000..3463ad6
--- /dev/null
+++ b/README
@@ -0,0 +1,49 @@
+######## satispipy ########
+Satisfiability modulo theory with Portfolio and Incremental solving, in Python
+
+Dependencies:
+ - reasonably modern Python
+ - pexpect
+ - SMT-LIB solver binaries
+
+Features:
+ - automatic incremental solving (no need to push/pop)
+ - persistent caching
+ - automatic portfolio solving (use multiple solvers and take the fastest)
+ - support for heuristic quantifier instantiation
+ - simple, easy to hack
+
+Usage example:
+ see example/ for a simple symbolic execution engine using satispipy
+
+More details on automatic incremental solving:
+ to use:
+ call model([a, b, c])
+ call model([a, b, c, d, e])
+ call model([a, b, c, f, g])
+ call model([a, b, c])
+ satispipy will automatically insert push() and pop() commands as needed
+
+More details on persistent caching:
+ to use:
+ call model([a, b, c])
+ call model([e, f, g])
+ call model([a, b, c])
+ satispipy will return immediately for the third call
+ if you run the program again, it will also return immediately for all calls
+
+More details on portfolio solving:
+ portfolio is hard to make work with incremental solving.
+ suppose you're portfolio solving with solvers SMTA and SMTB,
+ and make incremental queries X, Y, Z
+ what if SMTA is fast for X and Y, but slow for Z? and SMTB is the opposite
+ we might have to wait for SMTA to solve X and Y before it gets to Z
+
+ instead, after each solve, we kill all but the single solver that solved it
+ so SMTB starts fresh on the calls for Y and Z
+ no need to wait
+
+Quantifier instantiation:
+ top-level assertions can be universally quantified
+ the quantifier can define its own instantiation heuristics
+ still works with incremental, persistent, and cached solving
diff --git a/example/README b/example/README
new file mode 100644
index 0000000..07ed675
--- /dev/null
+++ b/example/README
@@ -0,0 +1,5 @@
+From the parent folder:
+ $ python3 example/symex.py example/test_files/fib
+ $ python3 example/symex.py example/test_files/cycle
+ $ python3 example/symex.py example/test_files/hasheq
+ etc.
diff --git a/example/program.py b/example/program.py
new file mode 100644
index 0000000..b3e6afb
--- /dev/null
+++ b/example/program.py
@@ -0,0 +1,130 @@
+from satispi import *
+
+class Program:
+ def __init__(self, instructions):
+ self.instructions = instructions
+ self.registers = set()
+ for pc in self.instructions:
+ if pc[0] == "label": continue
+ if pc[0] == "ite":
+ if isinstance(pc[1], str):
+ self.registers.add(pc[1])
+ else:
+ self.registers.update(a for a in pc[1:] if isinstance(a, str))
+ self.registers = sorted(self.registers)
+
+ @staticmethod
+ def from_file(path):
+ lines = [l.strip() for l in open(path, "r").readlines()]
+ lines = [l for l in lines if l and not l.startswith(";")]
+ instructions = []
+ for l in lines:
+ opcode, *args = l.split()
+ if opcode != "label":
+ args = [int(a[1:]) if a[0] == "$" else a for a in args]
+ instructions.append((opcode, *args))
+ return Program(instructions)
+
+ def get_op(self, pcidx):
+ return self.instructions[pcidx]
+
+ def label_to_pcidx(self, label):
+ for i, pc in enumerate(self.instructions):
+ if pc == ("label", label):
+ return i
+ raise NotImplementedError
+
+class State:
+ def __init__(self, program, pcidx, smtlib, parent):
+ self.program = program
+ self.pcidx = pcidx
+ self.smtlib = smtlib.copy()
+ self.parent = parent
+ self.idx = (parent.idx + 1) if parent is not None else 0
+
+ def step(self):
+ smtlib = self.smtlib
+ opcode, *args = self.program.get_op(self.pcidx)
+ if opcode == "ite": # ite cond tlabel flabel
+ cond = self.to_term(args[0])
+ true_idx = self.program.label_to_pcidx(args[1])
+ false_idx = self.program.label_to_pcidx(args[2])
+ delta = self.nops({})
+ true_delta = Term("and", delta, Term("distinct", cond, Literal("(_ bv0 32)")))
+ false_delta = Term("and", delta, Term("=", cond, Literal("(_ bv0 32)")))
+ if args[1] == args[2]:
+ return [State(self.program, true_idx, smtlib + [delta], self)]
+ return [State(self.program, true_idx, smtlib + [true_delta], self),
+ State(self.program, false_idx, smtlib + [false_delta], self)]
+
+ new_pcidx = self.pcidx + 1
+ if opcode == "label":
+ delta = self.nops({})
+ elif opcode == "end":
+ delta = self.nops({})
+ new_pcidx = self.pcidx
+ elif opcode == "load": # load dstreg addr
+ to_reg = self.nto_term(args[0])
+ value = Term("select", self.heap(), self.to_term(args[1]))
+ delta = self.nops({args[0]})
+ delta = Term("and", delta, Term("=", to_reg, value))
+ elif opcode == "store": # store addr val
+ addr = self.to_term(args[0])
+ value = self.to_term(args[1])
+ new_heap = Term("store", self.heap(), addr, value)
+ delta = self.nops({"@heap"})
+ delta = Term("and", delta, Term("=", self.nheap(), new_heap))
+ elif opcode in ("add", "sub", "eq", "xor", "mul"): # op reg delta
+ op = {"add": "bvadd", "sub": "bvsub", "eq": "=",
+ "xor": "bvxor", "mul": "bvmul"}[opcode]
+ old = self.to_term(args[0])
+ val = self.to_term(args[1])
+ new = Term(op, old, val)
+ if op == "=":
+ new = Term("ite", new, Literal("(_ bv1 32)"), Literal("(_ bv0 32)"))
+ delta = self.nops({args[0]})
+ delta = Term("and", delta, Term("=", self.nreg(args[0]), new))
+ elif opcode == "mov": # mov dst src
+ val = self.to_term(args[1])
+ delta = self.nops({args[0]})
+ delta = Term("and", delta, Term("=", self.nreg(args[0]), val))
+ else: print(opcode); raise NotImplementedError
+ return [State(self.program, new_pcidx, self.smtlib + [delta], self)]
+
+ def has_model(self):
+ return Solver.singleton().model(self.smtlib) is not None
+
+ def heap(self):
+ return Variable(f"heap_{self.idx}",
+ "(Array (_ BitVec 32) (_ BitVec 32))")
+
+ def reg(self, reg):
+ return Variable(f"reg_{reg}_{self.idx}", "(_ BitVec 32)")
+
+ def to_term(self, arg):
+ if isinstance(arg, int):
+ return Literal(f"(_ bv{arg} 32)")
+ return self.reg(arg)
+
+ def nheap(self):
+ return Variable(f"heap_{self.idx + 1}",
+ "(Array (_ BitVec 32) (_ BitVec 32))")
+
+ def nreg(self, reg):
+ return Variable(f"reg_{reg}_{self.idx + 1}", "(_ BitVec 32)")
+
+ def nto_term(self, arg):
+ if isinstance(arg, int):
+ return Literal(f"(_ bv{arg} 32)")
+ return self.nreg(arg)
+
+ def nops(self, except_for=set()):
+ term = Literal("true")
+ for reg in self.program.registers:
+ if reg in except_for: continue
+ term = Term("and", term,
+ Term("=", self.nreg(reg), self.reg(reg)))
+ if "@heap" not in except_for:
+ term = Term("and", term,
+ Term("=", self.nheap(), self.heap()))
+ return term
diff --git a/example/symabs.py b/example/symabs.py
new file mode 100644
index 0000000..a74db27
--- /dev/null
+++ b/example/symabs.py
@@ -0,0 +1,62 @@
+import sys
+import itertools
+from satispi import *
+
+def symabs(smtlib, exprs, types, cutoffs):
+ if len(exprs) == 0: return []
+
+ if Solver.singleton().model(smtlib) is None:
+ return [frozenset() for _ in exprs]
+
+ smtlib = smtlib.copy()
+
+ vars_ = [Variable(f"____symabs{i}", types[i]) for i in range(len(exprs))]
+ smtlib += [Term("=", var, expr) for var, expr in zip(vars_, exprs)]
+
+ possible = [frozenset() for _ in exprs]
+ aim_for = max(len(possible) // 8, 1)
+ for i in itertools.count():
+ distinction = [bvize(Term("distinct", v, *map(wordlit, sorted(p))))
+ for v, p in zip(vars_, possible)
+ if p and p is not True]
+
+ if i > 0 and not distinction:
+ # everything is Top
+ return possible
+
+ if len(distinction) == 1:
+ n_deltas = Term("bvadd", *distinction, wordlit(0))
+ else:
+ n_deltas = Term("bvadd", *distinction)
+
+ # aim_for <= n distincts.
+ # if aim_for = 1 and it's unsat, then impossible to get any distincts
+ goal = []
+ if distinction:
+ goal = [Term("bvult", wordlit(aim_for - 1), n_deltas)]
+
+ use_vars = [v for v, p in zip(vars_, possible) if p is not True]
+ model = Solver.singleton().model(smtlib + goal, use_vars)
+
+ if model is None:
+ assert aim_for >= 0
+ if aim_for > 1:
+ aim_for = max(aim_for // 2, 1)
+ continue
+ return possible
+
+ prior = possible.copy()
+ for i, (v, c) in enumerate(zip(vars_, cutoffs)):
+ if possible[i] is True: continue
+ if model[v.name] is None: continue
+ possible[i] = possible[i] | frozenset({model[v.name]})
+ if c is not None and len(possible[i]) > c:
+ possible[i] = True
+ assert possible != prior, \
+ "Solver did not return a complete enough model to make progress"
+
+def wordlit(val):
+ return Literal(f"(_ bv{val} 32)")
+
+def bvize(term):
+ return Term("ite", term, wordlit(1), wordlit(0))
diff --git a/example/symex.py b/example/symex.py
new file mode 100644
index 0000000..c5d13be
--- /dev/null
+++ b/example/symex.py
@@ -0,0 +1,23 @@
+import sys
+from satispi import *
+from example.program import *
+from example.symabs import *
+
+program = Program.from_file(sys.argv[1])
+worklist = [State(program, 0, [], None)]
+while worklist:
+ state = worklist.pop(0)
+
+ # Process state
+ print("Processing state ...")
+ exprs = [state.reg(reg) for reg in program.registers]
+ types = ["(_ BitVec 32)" for _ in program.registers]
+ cutoffs = [5 for _ in exprs]
+ possible = symabs(state.smtlib, exprs, types, cutoffs)
+ for reg, values in zip(program.registers, possible):
+ print("\t", reg, "\t->\t", values)
+
+ # Execute
+ for next_state in state.step():
+ if next_state.has_model():
+ worklist.append(next_state)
diff --git a/example/test_files/cycle b/example/test_files/cycle
new file mode 100644
index 0000000..e5fc1d3
--- /dev/null
+++ b/example/test_files/cycle
@@ -0,0 +1,17 @@
+mov tortoise $0
+mov hare $0
+mov cycle_found $0
+label loop
+ ; tortise = tortoise->next
+ load tortoise tortoise
+ ; hare = hare->next->next
+ load hare hare
+ load hare hare
+ ; are_eq = (tortoise == hare)
+ mov are_eq tortoise
+ eq are_eq hare
+ ; cycle found?
+ ite are_eq cycle loop
+label cycle
+ mov cycle_found $1
+end
diff --git a/example/test_files/fib b/example/test_files/fib
new file mode 100644
index 0000000..f197e07
--- /dev/null
+++ b/example/test_files/fib
@@ -0,0 +1,15 @@
+store $0 $0
+store $1 $1
+mov idx $1
+label loop
+ ; fib = f(idx) + f(idx - 1)
+ load fib idx
+ sub idx $1
+ load fib2 idx
+ add idx $1
+ add fib fib2
+ ; memory[idx + 1] = fib
+ add idx $1
+ store idx fib
+ite $1 loop loop
+end
diff --git a/example/test_files/hasheq b/example/test_files/hasheq
new file mode 100644
index 0000000..702c17d
--- /dev/null
+++ b/example/test_files/hasheq
@@ -0,0 +1,31 @@
+; find inputs that hash to the same thing under
+; https://doc.riot-os.org/group__sys__hashes__djb2.html
+; and
+; https://doc.riot-os.org/group__sys__hashes__sdbm.html
+
+mov hash1 $123456
+mov hash2 $789012
+mov same_found $0
+
+mov i $0
+label loop
+ load next_word i
+
+ ; djb2 iteration
+ mul hash1 $33
+ xor hash1 next_word
+
+ ; sdbm iteration
+ mul hash2 $65599
+ add hash2 next_word
+
+ ; compare
+ mov are_eq hash1
+ eq are_eq hash2
+
+ add i $1
+ ite are_eq hashes_same loop
+
+label hashes_same
+ mov same_found $1
+end
diff --git a/satispi.py b/satispi.py
new file mode 100644
index 0000000..9aedf8c
--- /dev/null
+++ b/satispi.py
@@ -0,0 +1,449 @@
+import sys
+import subprocess
+import os
+import shutil
+import io
+import select
+import pexpect
+import pickle
+import hashlib
+import shutil
+import time
+from types import SimpleNamespace
+
+class SMTTerm:
+ __slots__ = ("tuple_", "hash_")
+ def to_smtlib(self): raise NotImplementedError
+ def variables(self): raise NotImplementedError
+ def dfs(self): yield self
+ def replace(self, old, new):
+ if self == old: return new
+ return self._replace(old, new)
+ def _replace(self, old, new): return self
+ def tuplify(self):
+ if self.tuple_ is None:
+ self.tuple_ = (str(type(self)),) + tuple(getattr(self, k) for k in self.__slots__
+ if not k.endswith("_"))
+ return self.tuple_
+ def __lt__(self, rhs):
+ return self.tuplify() < rhs.tuplify()
+ def __eq__(self, rhs):
+ if id(self) == id(rhs): return True
+ if not isinstance(rhs, SMTTerm): return False
+ return self.tuplify() == rhs.tuplify()
+ def __hash__(self):
+ if self.hash_ is None:
+ self.hash_ = hash(self.tuplify())
+ return self.hash_
+ def __str__(self): return self.to_smtlib()
+ def __repr__(self): return str(self)
+
+class Variable(SMTTerm):
+ __slots__ = ("name", "smt_type")
+ def __init__(self, name, smt_type):
+ assert isinstance(name, str)
+ assert isinstance(smt_type, (type(None), str))
+ self.name = name
+ self.smt_type = smt_type
+ self.tuple_, self.hash_ = None, None
+ def to_smtlib(self): return self.name
+ def variables(self): return frozenset({self})
+
+class Literal(SMTTerm):
+ __slots__ = ("string",)
+ def __init__(self, string):
+ self.string = string
+ self.tuple_, self.hash_ = None, None
+ def to_smtlib(self): return self.string
+ def variables(self): return frozenset()
+
+class Term(SMTTerm):
+ __slots__ = ("op", "args", "vars_")
+ def __init__(self, op, *args):
+ assert all(isinstance(a, SMTTerm) for a in args)
+ if isinstance(op, str): op = Literal(op)
+ self.op = op
+ self.args = tuple(args)
+ self.vars_ = None
+ self.tuple_, self.hash_ = None, None
+
+ def to_smtlib(self):
+ return f"({self.op.to_smtlib()} {' '.join(a.to_smtlib() for a in self.args)})"
+
+ def variables(self):
+ if self.vars_ is not None:
+ return self.vars_
+ if not self.args: return self.op.variables()
+ self.vars_ = frozenset.union(*[a.variables() for a in self.args]) | self.op.variables()
+ return self.vars_
+
+ def _replace(self, *args):
+ return Term(self.op.replace(*args), *[a.replace(*args) for a in self.args])
+
+ def op_str(self):
+ assert isinstance(self.op, Literal)
+ return self.op.string
+
+ def dfs(self):
+ yield self
+ yield from self.op.dfs()
+ for arg in self.args:
+ yield from arg.dfs()
+
+SATISPI = SimpleNamespace()
+SATISPI.SOLVERS = dict({
+ "yices": {"command": "yices_smt2 --interactive --incremental --bvconst-in-decimal",
+ "continue_prompt": False, "log": "yices.log"},
+ "cvc5": {"command": "/home/matthew/apps/cvc5-Linux --produce-models --interactive",
+ "continue_prompt": True, "log": "cvc5.log"},
+})
+SATISPI.MAIN_SOLVER = "yices"
+SATISPI.SINGLETON = None
+SATISPI.PCACHE = "pcache"
+SATISPI.LOGIC = "QF_ABV"
+class Solver:
+ def __init__(self):
+ # these are CUMULATIVE! can be strs and terms
+ self.user_stack = [[]]
+ # these must be strs
+ self.internal_stack = [[]]
+
+ self.growing_timeout = 30
+ self.command_timeout = 2
+
+ self.solvers = dict()
+ self.bring_up_solver(SATISPI.MAIN_SOLVER)
+
+ self.n_calls = 0
+
+ self.pcache = dict()
+ if SATISPI.PCACHE and os.path.isfile(SATISPI.PCACHE):
+ # https://stackoverflow.com/questions/11218477
+ self.pcache = pickle.load(open(SATISPI.PCACHE, "rb"))
+
+ def insert_pcache(self, key, model):
+ self.pcache[key] = model
+ if len(self.pcache) % 5 == 0:
+ self.flush_pcache()
+
+ def flush_pcache(self):
+ if not SATISPI.PCACHE: return
+ with open("/tmp/pcache.new", "wb") as f:
+ pickle.dump(self.pcache, f)
+ shutil.move("/tmp/pcache.new", SATISPI.PCACHE)
+
+ def bring_up_solver(self, which):
+ assert which not in self.solvers
+
+ kwargs = {"encoding": "utf-8", "echo": False, "timeout": None}
+ solver = pexpect.spawn(SATISPI.SOLVERS[which]["command"], **kwargs)
+ solver.delaybeforesend = None
+ solver.delayafterread = None
+
+ if SATISPI.SOLVERS[which]["log"]:
+ logname = SATISPI.SOLVERS[which]["log"]
+ solver.logfile_read = open(logname, "w")
+ solver.logfile_send = open(f"w{logname}", "w")
+
+ self.solvers[which] = solver
+
+ if SATISPI.LOGIC:
+ self.repl_cmd(which, f"(set-logic {SATISPI.LOGIC})")
+ for i, lines in enumerate(self.internal_stack):
+ if i != 0: lines = lines[len(self.internal_stack[i-1]):]
+ self.repl_cmd(which, "(push 1)")
+ for l in lines: self.repl_cmd(which, l)
+
+ def repl_cmd(self, which, cmd):
+ try:
+ i = self.solvers[which].expect_exact(["(error", f"{which}> "],
+ timeout=self.command_timeout)
+ except pexpect.TIMEOUT:
+ self.command_timeout *= 8
+ print("Timeout waiting for command to process")
+ raise SMTInternalError()
+ if i != 1: raise SMTInternalError()
+ self.sendline_maybe_break(which, cmd)
+
+ def sendline_maybe_break(self, which, line, no_prompt_last=False):
+ while line:
+ if ' ' in line[2048:]:
+ split_at = 2048 + line[2048:].index(' ')
+ assert split_at < 2048 + 128
+ else:
+ split_at = len(line)
+
+ if line[split_at:]:
+ if SATISPI.SOLVERS[which]["continue_prompt"]:
+ self.solvers[which].sendline(line[:split_at] + "\\")
+ i = self.solvers[which].expect_exact(["(error", "... > "])
+ assert i == 1
+ else:
+ self.solvers[which].sendline(line[:split_at])
+ else:
+ self.solvers[which].sendline(line[:split_at])
+ line = line[split_at:]
+
+ @staticmethod
+ def singleton():
+ if SATISPI.SINGLETON is None: SATISPI.SINGLETON = Solver()
+ return SATISPI.SINGLETON
+
+ def hash(self, assertions, for_vars):
+ assertions = '\n'.join(a if isinstance(a, str) else a.to_smtlib() for a in assertions)
+ assertions = assertions.encode("utf-8")
+ for_vars = '\n'.join(v.to_smtlib() for v in for_vars)
+ for_vars = for_vars.encode("utf-8")
+ return hashlib.sha256(assertions).hexdigest() + hashlib.sha256(for_vars).hexdigest()
+
+ def model(self, assertions, for_vars=[], reset_timeout=True):
+ global SOLVERS
+ if reset_timeout:
+ self.growing_timeout = 60
+ self.command_timeout = 2
+ try:
+ return self.model_(assertions, for_vars)
+ except (pexpect.exceptions.TIMEOUT, SMTInternalError):
+ print("Timed out (or error) while waiting on a solver to ack, re-trying ...")
+ for solver in self.solvers:
+ self.solvers[solver].terminate(True)
+ self.solvers = dict()
+ self.bring_up_solver(SATISPI.MAIN_SOLVER)
+ self.growing_timeout *= 2
+ return self.model(assertions, for_vars, reset_timeout=False)
+
+ def model_(self, assertions, for_vars=[]):
+ """ assertions can be Terms or strs """
+ key = self.hash(assertions, for_vars)
+ if key in self.pcache:
+ return self.pcache[key]
+ # raise NotImplementedError
+
+ og_assertions = assertions.copy()
+ og_vars = for_vars.copy()
+ if for_vars == []:
+ for_vars = [Literal("true")]
+
+ # (0) First just record the assertions before we modify them.
+ new_user_stack = assertions
+
+ # (1) Set the solver to an assertion level that is strictly below the
+ # current one.
+ for i, user_assertions in enumerate(self.user_stack):
+ if self.user_stack[i] != assertions[:len(self.user_stack[i])]:
+ # pop from this assertion level forward
+ pop_n = len(self.user_stack) - i
+
+ for solver in self.solvers:
+ self.repl_cmd(solver, f"(pop {pop_n})")
+
+ self.user_stack = self.user_stack[:i]
+ self.internal_stack = self.internal_stack[:i]
+ break
+
+ # (2) Fully apply any quantifiers.
+ quant_asserts = set()
+ for t in assertions:
+ if hasattr(t, "qinst"):
+ quant_asserts.update(t.qinst(assertions))
+
+ # (3) Then strip to just the new assertions & update the stacks
+ assertions = assertions[len(self.user_stack[-1]):]
+ self.user_stack.append(assertions.copy())
+ self.internal_stack.append(self.internal_stack[-1].copy())
+
+ # (4) collect declarations
+ assertions = assertions + sorted(quant_asserts)
+ variables = set()
+ for t in assertions:
+ if isinstance(t, SMTTerm):
+ variables.update(t.variables())
+ new_lines = []
+ for v in sorted(variables):
+ if v.smt_type is None: continue
+ new_lines.append(f"(declare-fun {v.to_smtlib()} () {v.smt_type})")
+
+ # (5) Convert assertions to strs
+ new_lines.extend(
+ t if isinstance(t, str) else f"(assert {t.to_smtlib()})"
+ for t in assertions if not hasattr(t, "qinst"))
+
+ # (6) Deduplicate lines
+ already_added = set(self.internal_stack[-1])
+ really_new = []
+ for l in new_lines:
+ if l in already_added:
+ continue
+ self.internal_stack[-1].append(l)
+ already_added.add(l)
+ really_new.append(l)
+ for solver in self.solvers:
+ self.repl_cmd(solver, "(push 1)")
+ for l in really_new:
+ for solver in self.solvers:
+ self.repl_cmd(solver, l)
+
+ # (7) ask the existing solvers to check satisfiability
+ assert len(self.solvers) == 1
+ self.n_calls += 1
+ if self.n_calls % 5 == 0 and SATISPI.MAIN_SOLVER not in self.solvers:
+ self.bring_up_solver(SATISPI.MAIN_SOLVER)
+
+ for solver in self.solvers:
+ self.repl_cmd(solver, "(check-sat)")
+
+ # give the original one a chance
+ solver = next(iter(self.solvers))
+ result = None
+ try:
+ result = self.solvers[solver].read_nonblocking(timeout=0.3)
+ result += self.solvers[solver].readline()
+ except pexpect.exceptions.TIMEOUT:
+ # now try bringing up the other
+ for solver in SOLVERS:
+ if solver not in self.solvers:
+ try:
+ self.bring_up_solver(solver)
+ self.repl_cmd(solver, "(check-sat)")
+ except SMTInternalError:
+ print("Got internal error bringing up", solver, "... ignoring")
+ self.solvers[solver].terminate(True)
+ self.solvers.pop(solver)
+ # and then poll on both of them
+ start_time = time.time()
+ while result is None:
+ if (time.time() - start_time) > self.growing_timeout:
+ raise SMTInternalError()
+ for solver in self.solvers:
+ try:
+ result = self.solvers[solver].read_nonblocking(timeout=0.0001)
+ result += self.solvers[solver].readline()
+ break
+ except pexpect.exceptions.TIMEOUT:
+ continue
+ # print("Winner:", next(iter(self.solvers)))
+ result = result.split()[-1]
+ for other in self.solvers:
+ if other != solver:
+ self.solvers[other].terminate(True)
+ self.solvers = dict({solver: self.solvers[solver]})
+
+ if result not in ("sat", "unsat"):
+ print(result)
+ raise NotImplementedError
+
+ if result == "unsat":
+ self.insert_pcache(key, None)
+ return None
+
+ model_line = f"(get-value ({' '.join(v.to_smtlib() for v in for_vars)}))"
+ self.repl_cmd(solver, model_line)
+
+ model = self.solvers[solver].readline()
+ while not are_parens_balanced(model):
+ model += self.solvers[solver].readline()
+ if "(error" in model:
+ # kill all the solvers
+ print("Error in SMT solver; re-trying ...")
+ for solver in self.solvers:
+ self.solvers[solver].terminate(True)
+ self.solvers = dict()
+ self.bring_up_solver(SATISPI.MAIN_SOLVER)
+ return self.model(og_assertions, og_vars)
+ model = model_parser(model)
+ self.insert_pcache(key, model)
+ return model
+
+ @staticmethod
+ def to_file(assertions, filename):
+ quant_lines = set()
+ for t in assertions:
+ if hasattr(t, "qinst"):
+ quant_lines.update(t.qinst(assertions))
+
+ variables = set()
+ for t in assertions + sorted(quant_lines):
+ if isinstance(t, SMTTerm):
+ variables.update(t.variables())
+
+ f = open(filename, "w")
+ if SATISPI.LOGIC:
+ f.write(f"(set-logic {SATISPI.LOGIC})\n")
+
+ for v in sorted(variables):
+ if v.smt_type is None: continue
+ f.write(f"(declare-fun {v.to_smtlib()} () {v.smt_type})\n")
+
+ # (5) Convert assertions to strs
+ for t in assertions:
+ if isinstance(t, str):
+ f.write(f"{t}\n")
+ else:
+ f.write(f"(assert {t.to_smtlib()})\n")
+ f.close()
+ print(filename)
+
+def model_parser(string):
+ string = string.strip()
+ modelstr = string[string.index('('):]
+ assert modelstr[-1] == ')'
+ modelstr = modelstr[1:-1].strip()
+ # now (e1 v1) (e2 v2) ...
+ parts = split_parens(modelstr)
+
+ model = dict()
+ for part in parts:
+ try:
+ name, value = split_parens(part)
+ except ValueError:
+ print(part)
+ print(string)
+ raise ValueError
+ if value.startswith("#b"):
+ model[name] = int(value[2:], 2)
+ elif value.startswith("#x"):
+ model[name] = int(value[2:], 16)
+ elif value == "true":
+ model[name] = True
+ elif value == "false":
+ model[name] = False
+ elif value.startswith("_ bv"):
+ model[name] = int(value[len("_ bv"):].split()[0])
+ elif value == "???":
+ model[name] = None
+ else:
+ print(part, value)
+ assert value == "???"
+ raise ValueError
+ return model
+
+def split_parens(string):
+ parts = []
+ while string:
+ paren_count = 0
+ for i, c in enumerate(string):
+ paren_count += (c == '(')
+ paren_count -= (c == ')')
+ if paren_count != 0: continue
+ if c in (' ', ')'):
+ if string[0] == '(':
+ parts.append(string[1:i].strip())
+ else:
+ parts.append(string[:i].strip())
+ string = string[i+1:].strip()
+ break
+ else:
+ if string:
+ parts.append(string)
+ string = ""
+ return parts
+
+def are_parens_balanced(string):
+ count = 0
+ for c in string:
+ count += (c == '(')
+ count -= (c == ')')
+ return count == 0
+
+class SMTInternalError(Exception):
+ pass
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback