from satispi import *

# SMT-LIB sort of a register value / heap word / address.
_BV32_SORT = "(_ BitVec 32)"
# SMT-LIB sort of the heap: an array from 32-bit addresses to 32-bit words.
_HEAP_SORT = "(Array (_ BitVec 32) (_ BitVec 32))"


def _bv32(value):
    """Return an SMT-LIB 32-bit bit-vector literal for the Python int *value*.

    The value is reduced modulo 2**32 so that negative immediates (e.g. ``$-1``)
    become their two's-complement encoding instead of the invalid literal
    ``(_ bv-1 32)``.
    """
    return Literal(f"(_ bv{value % (1 << 32)} 32)")


class Program:
    """A parsed program: a list of instruction tuples plus the registers it uses.

    Each instruction is a tuple ``(opcode, *operands)``. Operands parsed from
    ``$N`` tokens are ints (immediates); every other operand is a str (a
    register name, or a label name for ``label``/``ite``).
    """

    def __init__(self, instructions):
        self.instructions = instructions
        registers = set()
        for opcode, *operands in self.instructions:
            if opcode == "label":
                continue
            if opcode == "ite":
                # ite cond tlabel flabel: only the condition may be a register;
                # the two remaining operands are label names.
                operands = operands[:1]
            registers.update(a for a in operands if isinstance(a, str))
        # Sorted so constraint generation iterates registers in a fixed order.
        self.registers = sorted(registers)

    @staticmethod
    def from_file(path):
        """Parse the program stored at *path* and return a ``Program``.

        One instruction per line, whitespace-separated. Blank lines and any
        text from ``;`` to the end of a line are ignored. Operands of the form
        ``$N`` become int immediates, except on ``label`` lines, whose operand
        is kept verbatim.
        """
        instructions = []
        with open(path, "r") as f:
            for raw in f:
                line = raw.split(";", 1)[0].strip()
                if not line:
                    continue
                opcode, *args = line.split()
                if opcode != "label":
                    args = [int(a[1:]) if a.startswith("$") else a for a in args]
                instructions.append((opcode, *args))
        return Program(instructions)

    def get_op(self, pcidx):
        """Return the instruction tuple at index *pcidx*."""
        return self.instructions[pcidx]

    def label_to_pcidx(self, label):
        """Return the index of the ``("label", label)`` instruction.

        Raises:
            NotImplementedError: if no such label exists (exception type kept
                for compatibility with existing callers).
        """
        try:
            return self.instructions.index(("label", label))
        except ValueError:
            raise NotImplementedError(f"unknown label {label!r}") from None


class State:
    """One node of the symbolic-execution tree.

    ``smtlib`` is the list of constraints accumulated along the path from the
    root, and ``idx`` is this state's depth. ``idx`` versions the SMT variables:
    ``reg_<r>_<idx>`` / ``heap_<idx>`` hold the values before the instruction
    at ``pcidx`` executes, and the ``idx + 1`` versions hold the values after it.
    """

    # Binary register-update opcodes and their SMT-LIB operators.
    _BINOPS = {
        "add": "bvadd",
        "sub": "bvsub",
        "eq": "=",
        "xor": "bvxor",
        "mul": "bvmul",
    }

    def __init__(self, program, pcidx, smtlib, parent):
        self.program = program
        self.pcidx = pcidx
        # Copy so this state does not alias the caller's list.
        self.smtlib = smtlib.copy()
        self.parent = parent
        self.idx = (parent.idx + 1) if parent is not None else 0

    def step(self):
        """Symbolically execute the instruction at ``pcidx``.

        Returns:
            A list of successor states. An ``ite`` whose two labels differ
            yields two states (one per branch, each with its path condition);
            every other instruction yields one.

        Raises:
            NotImplementedError: for an unknown opcode or an undefined label.
        """
        opcode, *args = self.program.get_op(self.pcidx)

        if opcode == "ite":
            return self._step_ite(args)

        new_pcidx = self.pcidx + 1
        if opcode == "label":
            delta = self.nops()
        elif opcode == "end":
            # Stay on the same instruction with all state unchanged.
            delta = self.nops()
            new_pcidx = self.pcidx
        elif opcode == "load":
            # load dstreg addr
            value = Term("select", self.heap(), self.to_term(args[1]))
            delta = Term(
                "and",
                self.nops({args[0]}),
                Term("=", self.nto_term(args[0]), value),
            )
        elif opcode == "store":
            # store addr val
            new_heap = Term(
                "store", self.heap(), self.to_term(args[0]), self.to_term(args[1])
            )
            delta = Term(
                "and", self.nops({"@heap"}), Term("=", self.nheap(), new_heap)
            )
        elif opcode in self._BINOPS:
            # op reg delta: reg := reg <op> delta
            op = self._BINOPS[opcode]
            new = Term(op, self.to_term(args[0]), self.to_term(args[1]))
            if op == "=":
                # Comparison yields a Bool; encode it as a 0/1 bit-vector.
                new = Term("ite", new, _bv32(1), _bv32(0))
            delta = Term(
                "and", self.nops({args[0]}), Term("=", self.nreg(args[0]), new)
            )
        elif opcode == "mov":
            # mov dst src
            delta = Term(
                "and",
                self.nops({args[0]}),
                Term("=", self.nreg(args[0]), self.to_term(args[1])),
            )
        else:
            raise NotImplementedError(f"unknown opcode {opcode!r}")
        return [State(self.program, new_pcidx, self.smtlib + [delta], self)]

    def _step_ite(self, args):
        """Branch on ``ite cond tlabel flabel`` (cond != 0 takes tlabel)."""
        cond = self.to_term(args[0])
        true_idx = self.program.label_to_pcidx(args[1])
        false_idx = self.program.label_to_pcidx(args[2])
        delta = self.nops()
        if args[1] == args[2]:
            # Both branches lead to the same place: no path condition needed.
            return [State(self.program, true_idx, self.smtlib + [delta], self)]
        zero = _bv32(0)
        true_delta = Term("and", delta, Term("distinct", cond, zero))
        false_delta = Term("and", delta, Term("=", cond, zero))
        return [
            State(self.program, true_idx, self.smtlib + [true_delta], self),
            State(self.program, false_idx, self.smtlib + [false_delta], self),
        ]

    def has_model(self):
        """Return whether ``Solver.singleton().model`` finds a model for the path."""
        return Solver.singleton().model(self.smtlib) is not None

    def heap(self):
        """Return the heap variable before this state's instruction."""
        return Variable(f"heap_{self.idx}", _HEAP_SORT)

    def reg(self, reg):
        """Return register *reg*'s variable before this state's instruction."""
        return Variable(f"reg_{reg}_{self.idx}", _BV32_SORT)

    def to_term(self, arg):
        """Return *arg* as a term: an int immediate or the current register."""
        if isinstance(arg, int):
            return _bv32(arg)
        return self.reg(arg)

    def nheap(self):
        """Return the heap variable after this state's instruction."""
        return Variable(f"heap_{self.idx + 1}", _HEAP_SORT)

    def nreg(self, reg):
        """Return register *reg*'s variable after this state's instruction."""
        return Variable(f"reg_{reg}_{self.idx + 1}", _BV32_SORT)

    def nto_term(self, arg):
        """Return *arg* as a term: an int immediate or the next-version register."""
        if isinstance(arg, int):
            return _bv32(arg)
        return self.nreg(arg)

    def nops(self, except_for=frozenset()):
        """Return a constraint stating that state not in *except_for* is unchanged.

        Each register of ``program.registers`` not in *except_for* gets
        ``reg_<r>_<idx+1> = reg_<r>_<idx>``. The heap gets the same unless
        ``"@heap"`` is in *except_for*. The result is a left-nested ``and``
        chain starting from ``true``.
        """
        term = Literal("true")
        for reg in self.program.registers:
            if reg not in except_for:
                term = Term("and", term, Term("=", self.nreg(reg), self.reg(reg)))
        if "@heap" not in except_for:
            term = Term("and", term, Term("=", self.nheap(), self.heap()))
        return term