From 0b12ba0ca00f7cdfb50b614fb24b673fb7e4e322 Mon Sep 17 00:00:00 2001 From: Matthew Sotoudeh Date: Fri, 14 Jul 2023 07:25:31 -0700 Subject: initial code --- satispi.py | 449 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 449 insertions(+) create mode 100644 satispi.py (limited to 'satispi.py') diff --git a/satispi.py b/satispi.py new file mode 100644 index 0000000..9aedf8c --- /dev/null +++ b/satispi.py @@ -0,0 +1,449 @@ +import sys +import subprocess +import os +import shutil +import io +import select +import pexpect +import pickle +import hashlib +import shutil +import time +from types import SimpleNamespace + +class SMTTerm: + __slots__ = ("tuple_", "hash_") + def to_smtlib(self): raise NotImplementedError + def variables(self): raise NotImplementedError + def dfs(self): yield self + def replace(self, old, new): + if self == old: return new + return self._replace(old, new) + def _replace(self, old, new): return self + def tuplify(self): + if self.tuple_ is None: + self.tuple_ = (str(type(self)),) + tuple(getattr(self, k) for k in self.__slots__ + if not k.endswith("_")) + return self.tuple_ + def __lt__(self, rhs): + return self.tuplify() < rhs.tuplify() + def __eq__(self, rhs): + if id(self) == id(rhs): return True + if not isinstance(rhs, SMTTerm): return False + return self.tuplify() == rhs.tuplify() + def __hash__(self): + if self.hash_ is None: + self.hash_ = hash(self.tuplify()) + return self.hash_ + def __str__(self): return self.to_smtlib() + def __repr__(self): return str(self) + +class Variable(SMTTerm): + __slots__ = ("name", "smt_type") + def __init__(self, name, smt_type): + assert isinstance(name, str) + assert isinstance(smt_type, (type(None), str)) + self.name = name + self.smt_type = smt_type + self.tuple_, self.hash_ = None, None + def to_smtlib(self): return self.name + def variables(self): return frozenset({self}) + +class Literal(SMTTerm): + __slots__ = ("string",) + def __init__(self, string): + self.string = string + self.tuple_, self.hash_ = None, None + def to_smtlib(self): return self.string + def variables(self): return frozenset() + +class Term(SMTTerm): + __slots__ = ("op", "args", "vars_") + def __init__(self, op, *args): + assert all(isinstance(a, SMTTerm) for a in args) + if isinstance(op, str): op = Literal(op) + self.op = op + self.args = tuple(args) + self.vars_ = None + self.tuple_, self.hash_ = None, None + + def to_smtlib(self): + return f"({self.op.to_smtlib()} {' '.join(a.to_smtlib() for a in self.args)})" + + def variables(self): + if self.vars_ is not None: + return self.vars_ + if not self.args: return self.op.variables() + self.vars_ = frozenset.union(*[a.variables() for a in self.args]) | self.op.variables() + return self.vars_ + + def _replace(self, *args): + return Term(self.op.replace(*args), *[a.replace(*args) for a in self.args]) + + def op_str(self): + assert isinstance(self.op, Literal) + return self.op.string + + def dfs(self): + yield self + yield from self.op.dfs() + for arg in self.args: + yield from arg.dfs() + +SATISPI = SimpleNamespace() +SATISPI.SOLVERS = dict({ + "yices": {"command": "yices_smt2 --interactive --incremental --bvconst-in-decimal", + "continue_prompt": False, "log": "yices.log"}, + "cvc5": {"command": "/home/matthew/apps/cvc5-Linux --produce-models --interactive", + "continue_prompt": True, "log": "cvc5.log"}, +}) +SATISPI.MAIN_SOLVER = "yices" +SATISPI.SINGLETON = None +SATISPI.PCACHE = "pcache" +SATISPI.LOGIC = "QF_ABV" +class Solver: + def __init__(self): + # these are CUMULATIVE! can be strs and terms + self.user_stack = [[]] + # these must be strs + self.internal_stack = [[]] + + self.growing_timeout = 30 + self.command_timeout = 2 + + self.solvers = dict() + self.bring_up_solver(SATISPI.MAIN_SOLVER) + + self.n_calls = 0 + + self.pcache = dict() + if SATISPI.PCACHE and os.path.isfile(SATISPI.PCACHE): + # https://stackoverflow.com/questions/11218477 + self.pcache = pickle.load(open(SATISPI.PCACHE, "rb")) + + def insert_pcache(self, key, model): + self.pcache[key] = model + if len(self.pcache) % 5 == 0: + self.flush_pcache() + + def flush_pcache(self): + if not SATISPI.PCACHE: return + with open("/tmp/pcache.new", "wb") as f: + pickle.dump(self.pcache, f) + shutil.move("/tmp/pcache.new", SATISPI.PCACHE) + + def bring_up_solver(self, which): + assert which not in self.solvers + + kwargs = {"encoding": "utf-8", "echo": False, "timeout": None} + solver = pexpect.spawn(SATISPI.SOLVERS[which]["command"], **kwargs) + solver.delaybeforesend = None + solver.delayafterread = None + + if SATISPI.SOLVERS[which]["log"]: + logname = SATISPI.SOLVERS[which]["log"] + solver.logfile_read = open(logname, "w") + solver.logfile_send = open(f"w{logname}", "w") + + self.solvers[which] = solver + + if SATISPI.LOGIC: + self.repl_cmd(which, f"(set-logic {SATISPI.LOGIC})") + for i, lines in enumerate(self.internal_stack): + if i != 0: lines = lines[len(self.internal_stack[i-1]):] + self.repl_cmd(which, "(push 1)") + for l in lines: self.repl_cmd(which, l) + + def repl_cmd(self, which, cmd): + try: + i = self.solvers[which].expect_exact(["(error", f"{which}> "], + timeout=self.command_timeout) + except pexpect.TIMEOUT: + self.command_timeout *= 8 + print("Timeout waiting for command to process") + raise SMTInternalError() + if i != 1: raise SMTInternalError() + self.sendline_maybe_break(which, cmd) + + def sendline_maybe_break(self, which, line, no_prompt_last=False): + while line: + if ' ' in line[2048:]: + split_at = 2048 + line[2048:].index(' ') + assert split_at < 2048 + 128 + else: + split_at = len(line) + + if line[split_at:]: + if SATISPI.SOLVERS[which]["continue_prompt"]: + self.solvers[which].sendline(line[:split_at] + "\\") + i = self.solvers[which].expect_exact(["(error", "... > "]) + assert i == 1 + else: + self.solvers[which].sendline(line[:split_at]) + else: + self.solvers[which].sendline(line[:split_at]) + line = line[split_at:] + + @staticmethod + def singleton(): + if SATISPI.SINGLETON is None: SATISPI.SINGLETON = Solver() + return SATISPI.SINGLETON + + def hash(self, assertions, for_vars): + assertions = '\n'.join(a if isinstance(a, str) else a.to_smtlib() for a in assertions) + assertions = assertions.encode("utf-8") + for_vars = '\n'.join(v.to_smtlib() for v in for_vars) + for_vars = for_vars.encode("utf-8") + return hashlib.sha256(assertions).hexdigest() + hashlib.sha256(for_vars).hexdigest() + + def model(self, assertions, for_vars=[], reset_timeout=True): + global SOLVERS + if reset_timeout: + self.growing_timeout = 60 + self.command_timeout = 2 + try: + return self.model_(assertions, for_vars) + except (pexpect.exceptions.TIMEOUT, SMTInternalError): + print("Timed out (or error) while waiting on a solver to ack, re-trying ...") + for solver in self.solvers: + self.solvers[solver].terminate(True) + self.solvers = dict() + self.bring_up_solver(SATISPI.MAIN_SOLVER) + self.growing_timeout *= 2 + return self.model(assertions, for_vars, reset_timeout=False) + + def model_(self, assertions, for_vars=[]): + """ assertions can be Terms or strs """ + key = self.hash(assertions, for_vars) + if key in self.pcache: + return self.pcache[key] + # raise NotImplementedError + + og_assertions = assertions.copy() + og_vars = for_vars.copy() + if for_vars == []: + for_vars = [Literal("true")] + + # (0) First just record the assertions before we modify them. + new_user_stack = assertions + + # (1) Set the solver to an assertion level that is strictly below the + # current one. + for i, user_assertions in enumerate(self.user_stack): + if self.user_stack[i] != assertions[:len(self.user_stack[i])]: + # pop from this assertion level forward + pop_n = len(self.user_stack) - i + + for solver in self.solvers: + self.repl_cmd(solver, f"(pop {pop_n})") + + self.user_stack = self.user_stack[:i] + self.internal_stack = self.internal_stack[:i] + break + + # (2) Fully apply any quantifiers. + quant_asserts = set() + for t in assertions: + if hasattr(t, "qinst"): + quant_asserts.update(t.qinst(assertions)) + + # (3) Then strip to just the new assertions & update the stacks + assertions = assertions[len(self.user_stack[-1]):] + self.user_stack.append(assertions.copy()) + self.internal_stack.append(self.internal_stack[-1].copy()) + + # (4) collect declarations + assertions = assertions + sorted(quant_asserts) + variables = set() + for t in assertions: + if isinstance(t, SMTTerm): + variables.update(t.variables()) + new_lines = [] + for v in sorted(variables): + if v.smt_type is None: continue + new_lines.append(f"(declare-fun {v.to_smtlib()} () {v.smt_type})") + + # (5) Convert assertions to strs + new_lines.extend( + t if isinstance(t, str) else f"(assert {t.to_smtlib()})" + for t in assertions if not hasattr(t, "qinst")) + + # (6) Deduplicate lines + already_added = set(self.internal_stack[-1]) + really_new = [] + for l in new_lines: + if l in already_added: + continue + self.internal_stack[-1].append(l) + already_added.add(l) + really_new.append(l) + for solver in self.solvers: + self.repl_cmd(solver, "(push 1)") + for l in really_new: + for solver in self.solvers: + self.repl_cmd(solver, l) + + # (7) ask the existing solvers to check satisfiability + assert len(self.solvers) == 1 + self.n_calls += 1 + if self.n_calls % 5 == 0 and SATISPI.MAIN_SOLVER not in self.solvers: + self.bring_up_solver(SATISPI.MAIN_SOLVER) + + for solver in self.solvers: + self.repl_cmd(solver, "(check-sat)") + + # give the original one a chance + solver = next(iter(self.solvers)) + result = None + try: + result = self.solvers[solver].read_nonblocking(timeout=0.3) + result += self.solvers[solver].readline() + except pexpect.exceptions.TIMEOUT: + # now try bringing up the other + for solver in SOLVERS: + if solver not in self.solvers: + try: + self.bring_up_solver(solver) + self.repl_cmd(solver, "(check-sat)") + except SMTInternalError: + print("Got internal error bringing up", solver, "... ignoring") + self.solvers[solver].terminate(True) + self.solvers.pop(solver) + # and then poll on both of them + start_time = time.time() + while result is None: + if (time.time() - start_time) > self.growing_timeout: + raise SMTInternalError() + for solver in self.solvers: + try: + result = self.solvers[solver].read_nonblocking(timeout=0.0001) + result += self.solvers[solver].readline() + break + except pexpect.exceptions.TIMEOUT: + continue + # print("Winner:", next(iter(self.solvers))) + result = result.split()[-1] + for other in self.solvers: + if other != solver: + self.solvers[other].terminate(True) + self.solvers = dict({solver: self.solvers[solver]}) + + if result not in ("sat", "unsat"): + print(result) + raise NotImplementedError + + if result == "unsat": + self.insert_pcache(key, None) + return None + + model_line = f"(get-value ({' '.join(v.to_smtlib() for v in for_vars)}))" + self.repl_cmd(solver, model_line) + + model = self.solvers[solver].readline() + while not are_parens_balanced(model): + model += self.solvers[solver].readline() + if "(error" in model: + # kill all the solvers + print("Error in SMT solver; re-trying ...") + for solver in self.solvers: + self.solvers[solver].terminate(True) + self.solvers = dict() + self.bring_up_solver(SATISPI.MAIN_SOLVER) + return self.model(og_assertions, og_vars) + model = model_parser(model) + self.insert_pcache(key, model) + return model + + @staticmethod + def to_file(assertions, filename): + quant_lines = set() + for t in assertions: + if hasattr(t, "qinst"): + quant_lines.update(t.qinst(assertions)) + + variables = set() + for t in assertions + sorted(quant_lines): + if isinstance(t, SMTTerm): + variables.update(t.variables()) + + f = open(filename, "w") + if SATISPI.LOGIC: + f.write(f"(set-logic {SATISPI.LOGIC})\n") + + for v in sorted(variables): + if v.smt_type is None: continue + f.write(f"(declare-fun {v.to_smtlib()} () {v.smt_type})\n") + + # (5) Convert assertions to strs + for t in assertions: + if isinstance(t, str): + f.write(f"{t}\n") + else: + f.write(f"(assert {t.to_smtlib()})\n") + f.close() + print(filename) + +def model_parser(string): + string = string.strip() + modelstr = string[string.index('('):] + assert modelstr[-1] == ')' + modelstr = modelstr[1:-1].strip() + # now (e1 v1) (e2 v2) ... + parts = split_parens(modelstr) + + model = dict() + for part in parts: + try: + name, value = split_parens(part) + except ValueError: + print(part) + print(string) + raise ValueError + if value.startswith("#b"): + model[name] = int(value[2:], 2) + elif value.startswith("#x"): + model[name] = int(value[2:], 16) + elif value == "true": + model[name] = True + elif value == "false": + model[name] = False + elif value.startswith("_ bv"): + model[name] = int(value[len("_ bv"):].split()[0]) + elif value == "???": + model[name] = None + else: + print(part, value) + assert value == "???" + raise ValueError + return model + +def split_parens(string): + parts = [] + while string: + paren_count = 0 + for i, c in enumerate(string): + paren_count += (c == '(') + paren_count -= (c == ')') + if paren_count != 0: continue + if c in (' ', ')'): + if string[0] == '(': + parts.append(string[1:i].strip()) + else: + parts.append(string[:i].strip()) + string = string[i+1:].strip() + break + else: + if string: + parts.append(string) + string = "" + return parts + +def are_parens_balanced(string): + count = 0 + for c in string: + count += (c == '(') + count -= (c == ')') + return count == 0 + +class SMTInternalError(Exception): + pass -- cgit v1.2.3