summaryrefslogtreecommitdiff
path: root/satispi.py
diff options
context:
space:
mode:
authorMatthew Sotoudeh <matthew@masot.net>2023-07-14 07:25:31 -0700
committerMatthew Sotoudeh <matthew@masot.net>2023-07-14 07:25:31 -0700
commit0b12ba0ca00f7cdfb50b614fb24b673fb7e4e322 (patch)
treec8dc0b54c6f9df2f10185fc93f4276f1a8535802 /satispi.py
initial code
Diffstat (limited to 'satispi.py')
-rw-r--r--satispi.py449
1 files changed, 449 insertions, 0 deletions
diff --git a/satispi.py b/satispi.py
new file mode 100644
index 0000000..9aedf8c
--- /dev/null
+++ b/satispi.py
@@ -0,0 +1,449 @@
+import sys
+import subprocess
+import os
+import shutil
+import io
+import select
+import pexpect
+import pickle
+import hashlib
+import shutil
+import time
+from types import SimpleNamespace
+
+class SMTTerm:
+ __slots__ = ("tuple_", "hash_")
+ def to_smtlib(self): raise NotImplementedError
+ def variables(self): raise NotImplementedError
+ def dfs(self): yield self
+ def replace(self, old, new):
+ if self == old: return new
+ return self._replace(old, new)
+ def _replace(self, old, new): return self
+ def tuplify(self):
+ if self.tuple_ is None:
+ self.tuple_ = (str(type(self)),) + tuple(getattr(self, k) for k in self.__slots__
+ if not k.endswith("_"))
+ return self.tuple_
+ def __lt__(self, rhs):
+ return self.tuplify() < rhs.tuplify()
+ def __eq__(self, rhs):
+ if id(self) == id(rhs): return True
+ if not isinstance(rhs, SMTTerm): return False
+ return self.tuplify() == rhs.tuplify()
+ def __hash__(self):
+ if self.hash_ is None:
+ self.hash_ = hash(self.tuplify())
+ return self.hash_
+ def __str__(self): return self.to_smtlib()
+ def __repr__(self): return str(self)
+
+class Variable(SMTTerm):
+ __slots__ = ("name", "smt_type")
+ def __init__(self, name, smt_type):
+ assert isinstance(name, str)
+ assert isinstance(smt_type, (type(None), str))
+ self.name = name
+ self.smt_type = smt_type
+ self.tuple_, self.hash_ = None, None
+ def to_smtlib(self): return self.name
+ def variables(self): return frozenset({self})
+
+class Literal(SMTTerm):
+ __slots__ = ("string",)
+ def __init__(self, string):
+ self.string = string
+ self.tuple_, self.hash_ = None, None
+ def to_smtlib(self): return self.string
+ def variables(self): return frozenset()
+
+class Term(SMTTerm):
+ __slots__ = ("op", "args", "vars_")
+ def __init__(self, op, *args):
+ assert all(isinstance(a, SMTTerm) for a in args)
+ if isinstance(op, str): op = Literal(op)
+ self.op = op
+ self.args = tuple(args)
+ self.vars_ = None
+ self.tuple_, self.hash_ = None, None
+
+ def to_smtlib(self):
+ return f"({self.op.to_smtlib()} {' '.join(a.to_smtlib() for a in self.args)})"
+
+ def variables(self):
+ if self.vars_ is not None:
+ return self.vars_
+ if not self.args: return self.op.variables()
+ self.vars_ = frozenset.union(*[a.variables() for a in self.args]) | self.op.variables()
+ return self.vars_
+
+ def _replace(self, *args):
+ return Term(self.op.replace(*args), *[a.replace(*args) for a in self.args])
+
+ def op_str(self):
+ assert isinstance(self.op, Literal)
+ return self.op.string
+
+ def dfs(self):
+ yield self
+ yield from self.op.dfs()
+ for arg in self.args:
+ yield from arg.dfs()
+
+SATISPI = SimpleNamespace()
+SATISPI.SOLVERS = dict({
+ "yices": {"command": "yices_smt2 --interactive --incremental --bvconst-in-decimal",
+ "continue_prompt": False, "log": "yices.log"},
+ "cvc5": {"command": "/home/matthew/apps/cvc5-Linux --produce-models --interactive",
+ "continue_prompt": True, "log": "cvc5.log"},
+})
+SATISPI.MAIN_SOLVER = "yices"
+SATISPI.SINGLETON = None
+SATISPI.PCACHE = "pcache"
+SATISPI.LOGIC = "QF_ABV"
+class Solver:
+ def __init__(self):
+ # these are CUMULATIVE! can be strs and terms
+ self.user_stack = [[]]
+ # these must be strs
+ self.internal_stack = [[]]
+
+ self.growing_timeout = 30
+ self.command_timeout = 2
+
+ self.solvers = dict()
+ self.bring_up_solver(SATISPI.MAIN_SOLVER)
+
+ self.n_calls = 0
+
+ self.pcache = dict()
+ if SATISPI.PCACHE and os.path.isfile(SATISPI.PCACHE):
+ # https://stackoverflow.com/questions/11218477
+ self.pcache = pickle.load(open(SATISPI.PCACHE, "rb"))
+
+ def insert_pcache(self, key, model):
+ self.pcache[key] = model
+ if len(self.pcache) % 5 == 0:
+ self.flush_pcache()
+
+ def flush_pcache(self):
+ if not SATISPI.PCACHE: return
+ with open("/tmp/pcache.new", "wb") as f:
+ pickle.dump(self.pcache, f)
+ shutil.move("/tmp/pcache.new", SATISPI.PCACHE)
+
+ def bring_up_solver(self, which):
+ assert which not in self.solvers
+
+ kwargs = {"encoding": "utf-8", "echo": False, "timeout": None}
+ solver = pexpect.spawn(SATISPI.SOLVERS[which]["command"], **kwargs)
+ solver.delaybeforesend = None
+ solver.delayafterread = None
+
+ if SATISPI.SOLVERS[which]["log"]:
+ logname = SATISPI.SOLVERS[which]["log"]
+ solver.logfile_read = open(logname, "w")
+ solver.logfile_send = open(f"w{logname}", "w")
+
+ self.solvers[which] = solver
+
+ if SATISPI.LOGIC:
+ self.repl_cmd(which, f"(set-logic {SATISPI.LOGIC})")
+ for i, lines in enumerate(self.internal_stack):
+ if i != 0: lines = lines[len(self.internal_stack[i-1]):]
+ self.repl_cmd(which, "(push 1)")
+ for l in lines: self.repl_cmd(which, l)
+
+ def repl_cmd(self, which, cmd):
+ try:
+ i = self.solvers[which].expect_exact(["(error", f"{which}> "],
+ timeout=self.command_timeout)
+ except pexpect.TIMEOUT:
+ self.command_timeout *= 8
+ print("Timeout waiting for command to process")
+ raise SMTInternalError()
+ if i != 1: raise SMTInternalError()
+ self.sendline_maybe_break(which, cmd)
+
+ def sendline_maybe_break(self, which, line, no_prompt_last=False):
+ while line:
+ if ' ' in line[2048:]:
+ split_at = 2048 + line[2048:].index(' ')
+ assert split_at < 2048 + 128
+ else:
+ split_at = len(line)
+
+ if line[split_at:]:
+ if SATISPI.SOLVERS[which]["continue_prompt"]:
+ self.solvers[which].sendline(line[:split_at] + "\\")
+ i = self.solvers[which].expect_exact(["(error", "... > "])
+ assert i == 1
+ else:
+ self.solvers[which].sendline(line[:split_at])
+ else:
+ self.solvers[which].sendline(line[:split_at])
+ line = line[split_at:]
+
+ @staticmethod
+ def singleton():
+ if SATISPI.SINGLETON is None: SATISPI.SINGLETON = Solver()
+ return SATISPI.SINGLETON
+
+ def hash(self, assertions, for_vars):
+ assertions = '\n'.join(a if isinstance(a, str) else a.to_smtlib() for a in assertions)
+ assertions = assertions.encode("utf-8")
+ for_vars = '\n'.join(v.to_smtlib() for v in for_vars)
+ for_vars = for_vars.encode("utf-8")
+ return hashlib.sha256(assertions).hexdigest() + hashlib.sha256(for_vars).hexdigest()
+
+ def model(self, assertions, for_vars=[], reset_timeout=True):
+ global SOLVERS
+ if reset_timeout:
+ self.growing_timeout = 60
+ self.command_timeout = 2
+ try:
+ return self.model_(assertions, for_vars)
+ except (pexpect.exceptions.TIMEOUT, SMTInternalError):
+ print("Timed out (or error) while waiting on a solver to ack, re-trying ...")
+ for solver in self.solvers:
+ self.solvers[solver].terminate(True)
+ self.solvers = dict()
+ self.bring_up_solver(SATISPI.MAIN_SOLVER)
+ self.growing_timeout *= 2
+ return self.model(assertions, for_vars, reset_timeout=False)
+
+ def model_(self, assertions, for_vars=[]):
+ """ assertions can be Terms or strs """
+ key = self.hash(assertions, for_vars)
+ if key in self.pcache:
+ return self.pcache[key]
+ # raise NotImplementedError
+
+ og_assertions = assertions.copy()
+ og_vars = for_vars.copy()
+ if for_vars == []:
+ for_vars = [Literal("true")]
+
+ # (0) First just record the assertions before we modify them.
+ new_user_stack = assertions
+
+ # (1) Set the solver to an assertion level that is strictly below the
+ # current one.
+ for i, user_assertions in enumerate(self.user_stack):
+ if self.user_stack[i] != assertions[:len(self.user_stack[i])]:
+ # pop from this assertion level forward
+ pop_n = len(self.user_stack) - i
+
+ for solver in self.solvers:
+ self.repl_cmd(solver, f"(pop {pop_n})")
+
+ self.user_stack = self.user_stack[:i]
+ self.internal_stack = self.internal_stack[:i]
+ break
+
+ # (2) Fully apply any quantifiers.
+ quant_asserts = set()
+ for t in assertions:
+ if hasattr(t, "qinst"):
+ quant_asserts.update(t.qinst(assertions))
+
+ # (3) Then strip to just the new assertions & update the stacks
+ assertions = assertions[len(self.user_stack[-1]):]
+ self.user_stack.append(assertions.copy())
+ self.internal_stack.append(self.internal_stack[-1].copy())
+
+ # (4) collect declarations
+ assertions = assertions + sorted(quant_asserts)
+ variables = set()
+ for t in assertions:
+ if isinstance(t, SMTTerm):
+ variables.update(t.variables())
+ new_lines = []
+ for v in sorted(variables):
+ if v.smt_type is None: continue
+ new_lines.append(f"(declare-fun {v.to_smtlib()} () {v.smt_type})")
+
+ # (5) Convert assertions to strs
+ new_lines.extend(
+ t if isinstance(t, str) else f"(assert {t.to_smtlib()})"
+ for t in assertions if not hasattr(t, "qinst"))
+
+ # (6) Deduplicate lines
+ already_added = set(self.internal_stack[-1])
+ really_new = []
+ for l in new_lines:
+ if l in already_added:
+ continue
+ self.internal_stack[-1].append(l)
+ already_added.add(l)
+ really_new.append(l)
+ for solver in self.solvers:
+ self.repl_cmd(solver, "(push 1)")
+ for l in really_new:
+ for solver in self.solvers:
+ self.repl_cmd(solver, l)
+
+ # (7) ask the existing solvers to check satisfiability
+ assert len(self.solvers) == 1
+ self.n_calls += 1
+ if self.n_calls % 5 == 0 and SATISPI.MAIN_SOLVER not in self.solvers:
+ self.bring_up_solver(SATISPI.MAIN_SOLVER)
+
+ for solver in self.solvers:
+ self.repl_cmd(solver, "(check-sat)")
+
+ # give the original one a chance
+ solver = next(iter(self.solvers))
+ result = None
+ try:
+ result = self.solvers[solver].read_nonblocking(timeout=0.3)
+ result += self.solvers[solver].readline()
+ except pexpect.exceptions.TIMEOUT:
+ # now try bringing up the other
+ for solver in SOLVERS:
+ if solver not in self.solvers:
+ try:
+ self.bring_up_solver(solver)
+ self.repl_cmd(solver, "(check-sat)")
+ except SMTInternalError:
+ print("Got internal error bringing up", solver, "... ignoring")
+ self.solvers[solver].terminate(True)
+ self.solvers.pop(solver)
+ # and then poll on both of them
+ start_time = time.time()
+ while result is None:
+ if (time.time() - start_time) > self.growing_timeout:
+ raise SMTInternalError()
+ for solver in self.solvers:
+ try:
+ result = self.solvers[solver].read_nonblocking(timeout=0.0001)
+ result += self.solvers[solver].readline()
+ break
+ except pexpect.exceptions.TIMEOUT:
+ continue
+ # print("Winner:", next(iter(self.solvers)))
+ result = result.split()[-1]
+ for other in self.solvers:
+ if other != solver:
+ self.solvers[other].terminate(True)
+ self.solvers = dict({solver: self.solvers[solver]})
+
+ if result not in ("sat", "unsat"):
+ print(result)
+ raise NotImplementedError
+
+ if result == "unsat":
+ self.insert_pcache(key, None)
+ return None
+
+ model_line = f"(get-value ({' '.join(v.to_smtlib() for v in for_vars)}))"
+ self.repl_cmd(solver, model_line)
+
+ model = self.solvers[solver].readline()
+ while not are_parens_balanced(model):
+ model += self.solvers[solver].readline()
+ if "(error" in model:
+ # kill all the solvers
+ print("Error in SMT solver; re-trying ...")
+ for solver in self.solvers:
+ self.solvers[solver].terminate(True)
+ self.solvers = dict()
+ self.bring_up_solver(SATISPI.MAIN_SOLVER)
+ return self.model(og_assertions, og_vars)
+ model = model_parser(model)
+ self.insert_pcache(key, model)
+ return model
+
+ @staticmethod
+ def to_file(assertions, filename):
+ quant_lines = set()
+ for t in assertions:
+ if hasattr(t, "qinst"):
+ quant_lines.update(t.qinst(assertions))
+
+ variables = set()
+ for t in assertions + sorted(quant_lines):
+ if isinstance(t, SMTTerm):
+ variables.update(t.variables())
+
+ f = open(filename, "w")
+ if SATISPI.LOGIC:
+ f.write(f"(set-logic {SATISPI.LOGIC})\n")
+
+ for v in sorted(variables):
+ if v.smt_type is None: continue
+ f.write(f"(declare-fun {v.to_smtlib()} () {v.smt_type})\n")
+
+ # (5) Convert assertions to strs
+ for t in assertions:
+ if isinstance(t, str):
+ f.write(f"{t}\n")
+ else:
+ f.write(f"(assert {t.to_smtlib()})\n")
+ f.close()
+ print(filename)
+
+def model_parser(string):
+ string = string.strip()
+ modelstr = string[string.index('('):]
+ assert modelstr[-1] == ')'
+ modelstr = modelstr[1:-1].strip()
+ # now (e1 v1) (e2 v2) ...
+ parts = split_parens(modelstr)
+
+ model = dict()
+ for part in parts:
+ try:
+ name, value = split_parens(part)
+ except ValueError:
+ print(part)
+ print(string)
+ raise ValueError
+ if value.startswith("#b"):
+ model[name] = int(value[2:], 2)
+ elif value.startswith("#x"):
+ model[name] = int(value[2:], 16)
+ elif value == "true":
+ model[name] = True
+ elif value == "false":
+ model[name] = False
+ elif value.startswith("_ bv"):
+ model[name] = int(value[len("_ bv"):].split()[0])
+ elif value == "???":
+ model[name] = None
+ else:
+ print(part, value)
+ assert value == "???"
+ raise ValueError
+ return model
+
+def split_parens(string):
+ parts = []
+ while string:
+ paren_count = 0
+ for i, c in enumerate(string):
+ paren_count += (c == '(')
+ paren_count -= (c == ')')
+ if paren_count != 0: continue
+ if c in (' ', ')'):
+ if string[0] == '(':
+ parts.append(string[1:i].strip())
+ else:
+ parts.append(string[:i].strip())
+ string = string[i+1:].strip()
+ break
+ else:
+ if string:
+ parts.append(string)
+ string = ""
+ return parts
+
+def are_parens_balanced(string):
+ count = 0
+ for c in string:
+ count += (c == '(')
+ count -= (c == ')')
+ return count == 0
+
+class SMTInternalError(Exception):
+ pass
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback