import sys
import subprocess
import os
import shutil
import io
import select
import pickle
import hashlib
import time
from types import SimpleNamespace

try:
    import pexpect
except ImportError:
    # Only Solver needs pexpect; the term classes and model parser are pure
    # Python, so keep them importable without it.
    pexpect = None


class SMTTerm:
    """Base class for SMT-LIB terms.

    Equality, ordering and hashing are structural: they are derived from
    ``tuplify()``, which collects the type name plus the subclass's slots
    whose names do not end in ``_``. Subclasses must set ``tuple_`` and
    ``hash_`` to None in ``__init__``; both are lazily filled caches.
    """
    __slots__ = ("tuple_", "hash_")

    def to_smtlib(self):
        """Return this term rendered as SMT-LIB text."""
        raise NotImplementedError

    def variables(self):
        """Return the frozenset of Variables occurring in this term."""
        raise NotImplementedError

    def dfs(self):
        """Yield this term (subclasses also yield their subterms, pre-order)."""
        yield self

    def replace(self, old, new):
        """Return this term with every subterm equal to `old` replaced by `new`."""
        if self == old:
            return new
        return self._replace(old, new)

    def _replace(self, old, new):
        # Leaf terms have no children to rewrite.
        return self

    def tuplify(self):
        """Return the cached structural key used for ==, < and hash()."""
        if self.tuple_ is None:
            self.tuple_ = (str(type(self)),) + tuple(
                getattr(self, k) for k in self.__slots__ if not k.endswith("_"))
        return self.tuple_

    def __lt__(self, rhs):
        return self.tuplify() < rhs.tuplify()

    def __eq__(self, rhs):
        if self is rhs:
            return True
        if not isinstance(rhs, SMTTerm):
            return False
        return self.tuplify() == rhs.tuplify()

    def __hash__(self):
        if self.hash_ is None:
            self.hash_ = hash(self.tuplify())
        return self.hash_

    def __str__(self):
        return self.to_smtlib()

    def __repr__(self):
        return str(self)


class Variable(SMTTerm):
    """A named symbol; `smt_type` is its SMT-LIB sort, or None if it must not be declared."""
    __slots__ = ("name", "smt_type")

    def __init__(self, name, smt_type):
        assert isinstance(name, str)
        assert isinstance(smt_type, (type(None), str))
        self.name = name
        self.smt_type = smt_type
        self.tuple_, self.hash_ = None, None

    def to_smtlib(self):
        return self.name

    def variables(self):
        return frozenset({self})


class Literal(SMTTerm):
    """Raw SMT-LIB text (a constant or an operator name) containing no variables."""
    __slots__ = ("string",)

    def __init__(self, string):
        self.string = string
        self.tuple_, self.hash_ = None, None

    def to_smtlib(self):
        return self.string

    def variables(self):
        return frozenset()


class Term(SMTTerm):
    """An application ``(op arg1 arg2 ...)``; a str `op` is wrapped in a Literal."""
    __slots__ = ("op", "args", "vars_")

    def __init__(self, op, *args):
        assert all(isinstance(a, SMTTerm) for a in args)
        if isinstance(op, str):
            op = Literal(op)
        self.op = op
        self.args = tuple(args)
        self.vars_ = None  # cache for variables()
        self.tuple_, self.hash_ = None, None

    def to_smtlib(self):
        return f"({self.op.to_smtlib()} {' '.join(a.to_smtlib() for a in self.args)})"

    def variables(self):
        """Return (and cache) the variables of the operator and all arguments."""
        if self.vars_ is None:
            self.vars_ = frozenset(self.op.variables()).union(
                *(a.variables() for a in self.args))
        return self.vars_

    def _replace(self, *args):
        return Term(self.op.replace(*args), *[a.replace(*args) for a in self.args])

    def op_str(self):
        """Return the operator's text; the operator must be a Literal."""
        assert isinstance(self.op, Literal)
        return self.op.string

    def dfs(self):
        yield self
        yield from self.op.dfs()
        for arg in self.args:
            yield from arg.dfs()


SATISPI = SimpleNamespace()
SATISPI.SOLVERS = dict({
    "yices": {"command": "yices_smt2 --interactive --incremental --bvconst-in-decimal",
              "continue_prompt": False,
              "log": "yices.log"},
    # NOTE(review): machine-specific path.
    "cvc5": {"command": "/home/matthew/apps/cvc5-Linux --produce-models --interactive",
             "continue_prompt": True,
             "log": "cvc5.log"},
})
SATISPI.MAIN_SOLVER = "yices"
SATISPI.SINGLETON = None
SATISPI.PCACHE = "pcache"
SATISPI.LOGIC = "QF_ABV"


class Solver:
    """Incremental front-end to interactive SMT solver processes (via pexpect).

    Assertions are kept on the solvers' push/pop stacks between calls, so a
    query that extends the previous one only sends the new material. Results
    are memoised in ``self.pcache`` and periodically pickled to SATISPI.PCACHE.
    """

    def __init__(self):
        # Both stacks have one entry per solver "(push 1)" level, and every
        # entry is CUMULATIVE (it contains everything in the entries below it).
        # user_stack entries hold the caller's assertions (strs or SMTTerms).
        self.user_stack = [[]]
        # internal_stack entries hold the exact SMT-LIB lines sent (strs only);
        # bring_up_solver replays them into a freshly started process.
        self.internal_stack = [[]]
        self.growing_timeout = 30
        self.command_timeout = 2
        self.solvers = dict()
        self.bring_up_solver(SATISPI.MAIN_SOLVER)
        self.n_calls = 0
        self.pcache = dict()
        if SATISPI.PCACHE and os.path.isfile(SATISPI.PCACHE):
            # NOTE: unpickling can execute arbitrary code; only load a cache
            # file this program wrote itself (see flush_pcache).
            with open(SATISPI.PCACHE, "rb") as f:
                self.pcache = pickle.load(f)

    def insert_pcache(self, key, model):
        """Memoise `model` under `key`; flush to disk whenever the size is a multiple of 5."""
        self.pcache[key] = model
        if len(self.pcache) % 5 == 0:
            self.flush_pcache()

    def flush_pcache(self):
        """Pickle the cache to SATISPI.PCACHE (no-op if that is falsy)."""
        if not SATISPI.PCACHE:
            return
        # Write next to the target so os.replace is an atomic same-filesystem rename.
        tmp_path = f"{SATISPI.PCACHE}.new"
        with open(tmp_path, "wb") as f:
            pickle.dump(self.pcache, f)
        os.replace(tmp_path, SATISPI.PCACHE)

    def bring_up_solver(self, which):
        """Start solver `which`, set the logic, and replay internal_stack into it."""
        assert which not in self.solvers
        if pexpect is None:
            raise ImportError("pexpect is required to run an SMT solver")
        config = SATISPI.SOLVERS[which]
        kwargs = {"encoding": "utf-8", "echo": False, "timeout": None}
        solver = pexpect.spawn(config["command"], **kwargs)
        solver.delaybeforesend = None
        solver.delayafterread = None
        if config["log"]:
            logname = config["log"]
            solver.logfile_read = open(logname, "w")
            solver.logfile_send = open(f"w{logname}", "w")
        self.solvers[which] = solver
        if SATISPI.LOGIC:
            self.repl_cmd(which, f"(set-logic {SATISPI.LOGIC})")
        for i, lines in enumerate(self.internal_stack):
            if i != 0:
                # Entries are cumulative: only send what this level added.
                lines = lines[len(self.internal_stack[i - 1]):]
            self.repl_cmd(which, "(push 1)")
            for l in lines:
                self.repl_cmd(which, l)

    def _terminate_solver(self, which):
        """Kill solver `which`, drop it from self.solvers and close its log files."""
        proc = self.solvers.pop(which)
        proc.terminate(True)
        for log in (proc.logfile_read, proc.logfile_send):
            if log is not None:
                log.close()

    def _restart_main_solver(self):
        """Kill every running solver and start a fresh MAIN_SOLVER."""
        for which in list(self.solvers):
            self._terminate_solver(which)
        self.bring_up_solver(SATISPI.MAIN_SOLVER)

    def repl_cmd(self, which, cmd):
        """Wait for the "<which>> " prompt, then send `cmd`.

        Raises SMTInternalError if "(error" is seen instead of the prompt, or
        on timeout (which also multiplies command_timeout by 8).
        """
        try:
            i = self.solvers[which].expect_exact(["(error", f"{which}> "],
                                                 timeout=self.command_timeout)
        except pexpect.TIMEOUT:
            self.command_timeout *= 8
            print("Timeout waiting for command to process")
            raise SMTInternalError()
        if i != 1:
            raise SMTInternalError()
        self.sendline_maybe_break(which, cmd)

    def sendline_maybe_break(self, which, line, no_prompt_last=False):
        """Send `line`, split at a space after ~2048 chars into several physical lines.

        Solvers with "continue_prompt" get a trailing backslash on every chunk
        except the last, and their "... > " continuation prompt is awaited.
        `no_prompt_last` is accepted for compatibility and unused.
        """
        proc = self.solvers[which]
        continue_prompt = SATISPI.SOLVERS[which]["continue_prompt"]
        while line:
            if ' ' in line[2048:]:
                split_at = 2048 + line[2048:].index(' ')
                assert split_at < 2048 + 128
            else:
                split_at = len(line)
            chunk, line = line[:split_at], line[split_at:]
            if line and continue_prompt:
                proc.sendline(chunk + "\\")
                i = proc.expect_exact(["(error", "... > "])
                assert i == 1
            else:
                proc.sendline(chunk)

    @staticmethod
    def singleton():
        """Return the process-wide Solver, creating it on first use."""
        if SATISPI.SINGLETON is None:
            SATISPI.SINGLETON = Solver()
        return SATISPI.SINGLETON

    def hash(self, assertions, for_vars):
        """Return the pcache key: sha256(assertion text) + sha256(variable text), hex."""
        assertions = '\n'.join(a if isinstance(a, str) else a.to_smtlib() for a in assertions)
        assertions = assertions.encode("utf-8")
        for_vars = '\n'.join(v.to_smtlib() for v in for_vars)
        for_vars = for_vars.encode("utf-8")
        return hashlib.sha256(assertions).hexdigest() + hashlib.sha256(for_vars).hexdigest()

    def model(self, assertions, for_vars=None, reset_timeout=True):
        """Check `assertions` and return a model over `for_vars`.

        Returns None if unsat, otherwise the dict produced by model_parser
        (name -> int/bool/None). On a solver timeout or internal error, all
        solvers are restarted and the query is retried with a doubled
        growing_timeout; there is no retry limit.
        """
        if for_vars is None:
            for_vars = []
        if reset_timeout:
            self.growing_timeout = 60
            self.command_timeout = 2
        try:
            return self.model_(assertions, for_vars)
        except (pexpect.exceptions.TIMEOUT, SMTInternalError):
            print("Timed out (or error) while waiting on a solver to ack, re-trying ...")
            self._restart_main_solver()
            self.growing_timeout *= 2
            return self.model(assertions, for_vars, reset_timeout=False)

    def model_(self, assertions, for_vars=None):
        """One attempt of model(), without the restart/retry wrapper.

        assertions can be Terms or strs.
        """
        if for_vars is None:
            for_vars = []
        key = self.hash(assertions, for_vars)
        if key in self.pcache:
            return self.pcache[key]

        og_assertions = list(assertions)
        og_vars = list(for_vars)
        if not for_vars:
            # get-value needs at least one term; ask for a trivial one.
            for_vars = [Literal("true")]

        # (1) Pop back to the deepest level that is a prefix of this query.
        self._pop_to_common_prefix(og_assertions)
        # (2) Fully apply any quantifiers (over the whole query).
        quant_asserts = self._instantiate_quantifiers(assertions)
        # (3) Open a new level; only the not-yet-asserted suffix needs sending.
        new_assertions = self._push_user_level(og_assertions)
        # (4)+(5) Declarations and assert lines for the new material.
        new_lines = self._smtlib_lines(new_assertions, quant_asserts)
        # (6) Drop lines already sent at a lower level, then send the rest.
        already_added = set(self.internal_stack[-1])
        really_new = []
        for line in new_lines:
            if line not in already_added:
                already_added.add(line)
                really_new.append(line)
        self.internal_stack[-1].extend(really_new)
        for solver in self.solvers:
            self.repl_cmd(solver, "(push 1)")
        for line in really_new:
            for solver in self.solvers:
                self.repl_cmd(solver, line)

        # (7) Ask the running solver(s) to check satisfiability.
        assert len(self.solvers) == 1
        self.n_calls += 1
        if self.n_calls % 5 == 0 and SATISPI.MAIN_SOLVER not in self.solvers:
            # Periodically try to get back to the preferred solver.
            self.bring_up_solver(SATISPI.MAIN_SOLVER)
        for solver in self.solvers:
            self.repl_cmd(solver, "(check-sat)")
        solver, result = self._race_check_sat()

        if result not in ("sat", "unsat"):
            print(result)
            raise NotImplementedError
        if result == "unsat":
            self.insert_pcache(key, None)
            return None

        model_line = f"(get-value ({' '.join(v.to_smtlib() for v in for_vars)}))"
        self.repl_cmd(solver, model_line)
        proc = self.solvers[solver]
        model = proc.readline()
        while not are_parens_balanced(model):
            model += proc.readline()
        if "(error" in model:
            print("Error in SMT solver; re-trying ...")
            self._restart_main_solver()
            return self.model(og_assertions, og_vars)
        model = model_parser(model)
        self.insert_pcache(key, model)
        return model

    def _pop_to_common_prefix(self, assertions):
        """Pop levels until every remaining user_stack entry is a prefix of `assertions`."""
        for i, level in enumerate(self.user_stack):
            if level != assertions[:len(level)]:
                pop_n = len(self.user_stack) - i
                for solver in self.solvers:
                    self.repl_cmd(solver, f"(pop {pop_n})")
                self.user_stack = self.user_stack[:i]
                self.internal_stack = self.internal_stack[:i]
                break

    def _push_user_level(self, assertions):
        """Record a new cumulative level for `assertions`; return the part not yet on the stack."""
        new = assertions[len(self.user_stack[-1]):]
        self.user_stack.append(list(assertions))
        self.internal_stack.append(self.internal_stack[-1].copy())
        return new

    @staticmethod
    def _instantiate_quantifiers(assertions):
        """Collect the instances returned by each assertion that has a `qinst` method."""
        instances = set()
        for t in assertions:
            if hasattr(t, "qinst"):
                instances.update(t.qinst(assertions))
        return instances

    @staticmethod
    def _smtlib_lines(assertions, quant_asserts):
        """Return declare-fun lines, then assert lines, for assertions + sorted(quant_asserts).

        Variables with smt_type None are not declared. strs are emitted
        verbatim. Objects with `qinst` are skipped; their instances (in
        `quant_asserts`) are asserted instead.
        """
        terms = list(assertions) + sorted(quant_asserts)
        variables = set()
        for t in terms:
            if isinstance(t, SMTTerm):
                variables.update(t.variables())
        lines = [f"(declare-fun {v.to_smtlib()} () {v.smt_type})"
                 for v in sorted(variables) if v.smt_type is not None]
        lines.extend(t if isinstance(t, str) else f"(assert {t.to_smtlib()})"
                     for t in terms if not hasattr(t, "qinst"))
        return lines

    def _race_check_sat(self):
        """Wait for a check-sat answer; return (winning solver name, last word of its reply).

        The first solver gets 0.3 s alone; after that every other configured
        solver is started (failures are ignored) and all are polled until one
        answers or growing_timeout elapses (SMTInternalError). Losers are killed.
        """
        solver = next(iter(self.solvers))
        result = None
        try:
            result = self.solvers[solver].read_nonblocking(timeout=0.3)
            result += self.solvers[solver].readline()
        except pexpect.exceptions.TIMEOUT:
            # now try bringing up the other solvers
            for other in SATISPI.SOLVERS:
                if other in self.solvers:
                    continue
                try:
                    self.bring_up_solver(other)
                    self.repl_cmd(other, "(check-sat)")
                except (SMTInternalError, pexpect.exceptions.ExceptionPexpect, OSError):
                    print("Got internal error bringing up", other, "... ignoring")
                    if other in self.solvers:
                        self._terminate_solver(other)
        # and then poll on all of them
        start_time = time.time()
        while result is None:
            if (time.time() - start_time) > self.growing_timeout:
                raise SMTInternalError()
            for solver in self.solvers:
                try:
                    result = self.solvers[solver].read_nonblocking(timeout=0.0001)
                    result += self.solvers[solver].readline()
                    break
                except pexpect.exceptions.TIMEOUT:
                    continue
        result = result.split()[-1]
        for other in list(self.solvers):
            if other != solver:
                self._terminate_solver(other)
        return solver, result

    @staticmethod
    def to_file(assertions, filename):
        """Write the SMT-LIB script the solver would receive for `assertions` to `filename`."""
        quant_asserts = Solver._instantiate_quantifiers(assertions)
        with open(filename, "w") as f:
            if SATISPI.LOGIC:
                f.write(f"(set-logic {SATISPI.LOGIC})\n")
            for line in Solver._smtlib_lines(assertions, quant_asserts):
                f.write(f"{line}\n")
        print(filename)


def model_parser(string):
    """Parse a (get-value ...) reply into {name: value}.

    #b.../#x... and (_ bvN w) become ints, true/false become bools, ??? becomes
    None; any other value raises ValueError.
    """
    string = string.strip()
    modelstr = string[string.index('('):]
    assert modelstr[-1] == ')'
    modelstr = modelstr[1:-1].strip()
    # now (e1 v1) (e2 v2) ...
    model = dict()
    for part in split_parens(modelstr):
        try:
            name, value = split_parens(part)
        except ValueError as err:
            print(part)
            print(string)
            raise ValueError(f"cannot parse model entry: {part!r}") from err
        if value.startswith("#b"):
            model[name] = int(value[2:], 2)
        elif value.startswith("#x"):
            model[name] = int(value[2:], 16)
        elif value == "true":
            model[name] = True
        elif value == "false":
            model[name] = False
        elif value.startswith("_ bv"):
            model[name] = int(value[len("_ bv"):].split()[0])
        elif value == "???":
            model[name] = None
        else:
            print(part, value)
            raise ValueError(f"unsupported model value: {value!r}")
    return model


def split_parens(string):
    """Split `string` into its top-level space-separated items.

    A parenthesised item is returned without its outer parentheses, e.g.
    ``split_parens("a (b c) d") == ["a", "b c", "d"]``.
    """
    parts = []
    while string:
        paren_count = 0
        for i, c in enumerate(string):
            paren_count += (c == '(')
            paren_count -= (c == ')')
            if paren_count != 0:
                continue
            if c in (' ', ')'):
                if string[0] == '(':
                    parts.append(string[1:i].strip())
                else:
                    parts.append(string[:i].strip())
                string = string[i + 1:].strip()
                break
        else:
            # No delimiter left: the remainder is the final item.
            if string:
                parts.append(string)
            string = ""
    return parts


def are_parens_balanced(string):
    """Return True if `string` has as many '(' as ')'."""
    return string.count('(') == string.count(')')


class SMTInternalError(Exception):
    """Raised when a solver reports an error or fails to respond in time."""
    pass