summaryrefslogtreecommitdiff
path: root/example/program.py
blob: b3e6afbb69dc7e47ddfd7e750de1de678ea9042b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from satispi import *

class Program:
    """A parsed program: a list of instruction tuples plus its register set.

    Each instruction is a tuple ``(opcode, *args)``. ``int`` arguments are
    immediates; ``str`` arguments are register names, except for the label
    operands of ``label`` and ``ite`` instructions.
    """

    def __init__(self, instructions):
        """Store ``instructions`` and collect every register they mention.

        :param instructions: sequence of ``(opcode, *args)`` tuples.
        """
        self.instructions = instructions
        registers = set()
        for opcode, *args in self.instructions:
            if opcode == "label":
                continue
            if opcode == "ite":
                # ite cond tlabel flabel: only the condition may be a register;
                # the two targets are label names, not registers.
                if isinstance(args[0], str):
                    registers.add(args[0])
            else:
                registers.update(a for a in args if isinstance(a, str))
        # Sorted list so iteration order (and the generated SMT) is deterministic.
        self.registers = sorted(registers)

    @staticmethod
    def from_file(path):
        """Parse the text program at ``path`` into a ``Program``.

        One instruction per line, whitespace-separated: ``opcode arg ...``.
        Blank lines and lines starting with ``;`` are skipped. Operands written
        ``$N`` become the int immediate ``N``; all other operands stay strings.
        Operands of ``label`` lines are never converted.
        """
        instructions = []
        # `with` guarantees the file handle is closed, even if parsing raises.
        with open(path, "r") as f:
            for raw in f:
                line = raw.strip()
                if not line or line.startswith(";"):
                    continue
                opcode, *args = line.split()
                if opcode != "label":
                    args = [int(a[1:]) if a[0] == "$" else a for a in args]
                instructions.append((opcode, *args))
        return Program(instructions)

    def get_op(self, pcidx):
        """Return the instruction tuple at index ``pcidx``."""
        return self.instructions[pcidx]

    def label_to_pcidx(self, label):
        """Return the index of the first ``("label", label)`` instruction.

        :raises NotImplementedError: if the label is not defined (exception
            type kept for compatibility with existing callers).
        """
        target = ("label", label)
        for i, pc in enumerate(self.instructions):
            if pc == target:
                return i
        raise NotImplementedError(f"undefined label: {label!r}")

class State:
    """A node in the symbolic-execution tree of a ``Program``.

    A state holds the index of the instruction it will execute next
    (``pcidx``) and the list of SMT-LIB constraints accumulated on the path
    from the root (``smtlib``). SMT variables are versioned by ``idx``, the
    state's depth in the tree: this state reads ``reg_<r>_<idx>`` /
    ``heap_<idx>``, and ``step`` constrains ``reg_<r>_<idx+1>`` /
    ``heap_<idx+1>``.
    """

    # Binary register ops: `op dst src` constrains dst' == dst <smt-op> src.
    _BINOPS = {"add": "bvadd", "sub": "bvsub", "eq": "=",
               "xor": "bvxor", "mul": "bvmul"}

    def __init__(self, program, pcidx, smtlib, parent):
        """Create a state.

        :param program: the ``Program`` being executed.
        :param pcidx: index of the next instruction to execute.
        :param smtlib: path constraints so far; shallow-copied so the
            caller's list is not aliased.
        :param parent: predecessor ``State``, or ``None`` for the root.
        """
        self.program = program
        self.pcidx = pcidx
        self.smtlib = smtlib.copy()
        self.parent = parent
        # Depth in the execution tree; used to version SMT variable names.
        self.idx = (parent.idx + 1) if parent is not None else 0

    def step(self):
        """Execute the instruction at ``pcidx`` and return successor states.

        Each successor's constraints are this state's ``smtlib`` plus one
        ``delta`` term linking the ``idx`` variables to the ``idx + 1``
        variables. ``ite`` forks into a true and a false successor, or yields a
        single successor when both targets name the same label. ``end`` yields
        one successor at the same ``pcidx``. Every other opcode yields one
        successor at ``pcidx + 1``.

        :raises NotImplementedError: on an unknown opcode, or (from
            ``Program.label_to_pcidx``) an undefined ``ite`` target label.
        """
        smtlib = self.smtlib
        opcode, *args = self.program.get_op(self.pcidx)
        if opcode == "ite":  # ite cond tlabel flabel
            cond = self.to_term(args[0])
            true_idx = self.program.label_to_pcidx(args[1])
            false_idx = self.program.label_to_pcidx(args[2])
            delta = self.nops()
            # A non-zero condition takes the true branch; zero takes the false one.
            true_delta = Term("and", delta, Term("distinct", cond, self._bv(0)))
            false_delta = Term("and", delta, Term("=", cond, self._bv(0)))
            if args[1] == args[2]:
                # Both targets are the same label: no fork, no branch constraint.
                return [State(self.program, true_idx, smtlib + [delta], self)]
            return [State(self.program, true_idx, smtlib + [true_delta], self),
                    State(self.program, false_idx, smtlib + [false_delta], self)]

        new_pcidx = self.pcidx + 1
        if opcode == "label":
            delta = self.nops()
        elif opcode == "end":
            # Self-loop: stay on `end` with every register and the heap unchanged.
            delta = self.nops()
            new_pcidx = self.pcidx
        elif opcode == "load":  # load dstreg addr
            to_reg = self.nto_term(args[0])
            value = Term("select", self.heap(), self.to_term(args[1]))
            delta = self.nops({args[0]})
            delta = Term("and", delta, Term("=", to_reg, value))
        elif opcode == "store":  # store addr val
            addr = self.to_term(args[0])
            value = self.to_term(args[1])
            new_heap = Term("store", self.heap(), addr, value)
            delta = self.nops({"@heap"})
            delta = Term("and", delta, Term("=", self.nheap(), new_heap))
        elif opcode in self._BINOPS:  # op dst src
            op = self._BINOPS[opcode]
            new = Term(op, self.to_term(args[0]), self.to_term(args[1]))
            if op == "=":
                # `eq` writes a 0/1 bitvector, not an SMT Boolean.
                new = Term("ite", new, self._bv(1), self._bv(0))
            delta = self.nops({args[0]})
            delta = Term("and", delta, Term("=", self.nreg(args[0]), new))
        elif opcode == "mov":  # mov dst src
            val = self.to_term(args[1])
            delta = self.nops({args[0]})
            delta = Term("and", delta, Term("=", self.nreg(args[0]), val))
        else:
            raise NotImplementedError(f"unknown opcode: {opcode!r}")
        return [State(self.program, new_pcidx, self.smtlib + [delta], self)]

    def has_model(self):
        """Return True if the solver returns a model for ``self.smtlib``."""
        return Solver.singleton().model(self.smtlib) is not None

    @staticmethod
    def _bv(value):
        """Return a 32-bit SMT-LIB bitvector literal for the int ``value``.

        ``value`` is reduced modulo 2**32. Negative immediates such as ``$-1``
        therefore become their two's-complement numeral instead of the
        malformed ``(_ bv-1 32)``.
        """
        return Literal(f"(_ bv{value % (1 << 32)} 32)")

    def heap(self):
        """Return this state's heap variable (32-bit address -> 32-bit word)."""
        return Variable(f"heap_{self.idx}",
                        "(Array (_ BitVec 32) (_ BitVec 32))")

    def reg(self, reg):
        """Return this state's 32-bit bitvector variable for register ``reg``."""
        return Variable(f"reg_{reg}_{self.idx}", "(_ BitVec 32)")

    def to_term(self, arg):
        """Return an operand as a term: int -> literal, str -> current register."""
        if isinstance(arg, int):
            return self._bv(arg)
        return self.reg(arg)

    def nheap(self):
        """Return the successor-state (``idx + 1``) heap variable."""
        return Variable(f"heap_{self.idx + 1}",
                        "(Array (_ BitVec 32) (_ BitVec 32))")

    def nreg(self, reg):
        """Return the successor-state (``idx + 1``) variable for register ``reg``."""
        return Variable(f"reg_{reg}_{self.idx + 1}", "(_ BitVec 32)")

    def nto_term(self, arg):
        """Return an operand as a term: int -> literal, str -> next-state register."""
        if isinstance(arg, int):
            return self._bv(arg)
        return self.nreg(arg)

    def nops(self, except_for=frozenset()):
        """Return the frame condition for one step.

        Builds a conjunction, starting from ``true``, asserting that every
        register not in ``except_for`` keeps its value from ``idx`` to
        ``idx + 1``. The heap gets the same constraint unless ``"@heap"`` is in
        ``except_for``.

        :param except_for: registers (and optionally ``"@heap"``) written by
            the current instruction, left unconstrained here. Any container
            supporting ``in`` works; the default is immutable on purpose.
        """
        term = Literal("true")
        for reg in self.program.registers:
            if reg not in except_for:
                term = Term("and", term,
                            Term("=", self.nreg(reg), self.reg(reg)))
        if "@heap" not in except_for:
            term = Term("and", term,
                        Term("=", self.nheap(), self.heap()))
        return term
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback