summaryrefslogtreecommitdiff
path: root/python/examples/refcounting/dietpass
blob: 9d6a77f0c56f90d8c5b508fd128730984eb2fc08 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/bin/python3
import os
import sys
from dietc import *

def eprint(*args):
    """Print ``args`` to standard error, space-separated and newline-terminated.

    Behaves exactly like ``print(*args)`` except that the output goes to
    ``sys.stderr`` (looked up at call time rather than bound at import).
    """
    print(
        *args,
        file=sys.stderr,
    )

def _redirect_allocators(function):
    """Rename calloc/free/malloc/realloc tokens in *function* to refcount_* names.

    Each name is rewritten across the whole body with
    ``Instruction.replace_token``; ``function.body`` is rebound to the result.
    """
    for name in ("calloc", "free", "malloc", "realloc"):
        function.body = [instr.replace_token(name, f"refcount_{name}")
                         for instr in function.body]


def _emit_tracked_write(new_body, instr, lvalue, ty):
    """Append *instr* to *new_body*, bracketed by refcount bookkeeping.

    *lvalue* is the C expression being assigned and *ty* its dietc type.
    Before *instr*, every pointer inside *lvalue* is copied into a fresh
    ``void *`` backup variable. After *instr*, one
    ``refcount_write(current, backup)`` call is emitted per pointer, passing
    the freshly stored value and the value it overwrote.
    """
    pairs = []
    for expr in ptr_exprs(lvalue, ty):
        # count() (from dietc) supplies a unique suffix for each backup name.
        backup = f"_refcount_ptr_backup_{count()}"
        new_body.append(Instruction(f"void * {backup} = {expr} ;"))
        pairs.append((backup, expr))
    new_body.append(instr)
    for old, new in pairs:
        new_body.append(Instruction(f"refcount_write({new},{old});"))


def _emit_tracked_return(new_body, instr, function, object_types):
    """Append a ``return`` instruction to *new_body*, preceded by refcount cleanup.

    For every local other than the returned one, each pointer it contains is
    passed to ``refcount_write(0, ...)``. If the return carries a value
    (i.e. it is not ``return ;``), each pointer inside that value is passed
    to ``refcount_returning(...)``.
    """
    returned = instr[1]
    for local in function.locals():
        if local == returned:
            continue
        for expr in ptr_exprs(local, object_types[local]):
            new_body.append(Instruction(f"refcount_write(0,({expr}));"))
    if returned != ";":  # `return ;` has no value to hand back
        for expr in ptr_exprs(returned, object_types[returned]):
            new_body.append(Instruction(f"refcount_returning({expr});"))
    new_body.append(instr)


def main():
    """Instrument the dietc program named by ``sys.argv[1]`` for refcounting.

    The program text is read (as bytes), ``#include <refcounting.h>`` is added
    to its preamble, and for every function:

    1.  calloc/free/malloc/realloc tokens become refcount_calloc/... ;
    2a. every pointer inside every local is set to 0 before the first
        instruction at ``function.code_start()``;
    2b. each assignment (``x = ... ;`` or ``* p = ... ;``) is wrapped so the
        overwritten pointers are reported with ``refcount_write``, and each
        ``return`` is preceded by refcount cleanup of the locals.

    The transformed program is emitted with ``prog.print()``.
    """
    # Close the input file promptly instead of leaking the handle.
    with open(sys.argv[1], "rb") as source:
        prog = Program(source.read())
    prog.preamble.append("#include <refcounting.h>")
    for function in prog.functions:
        # (1) rewrite calloc, free, ... calls to refcount_*
        _redirect_allocators(function)
        start = function.code_start()
        new_body = function.body[:start]
        # (2a) initialize all pointer expressions to 0
        # NOTE(review): if function.locals() includes parameters, this would
        # clobber incoming pointer arguments -- confirm against dietc.
        for local in function.locals():
            for p in ptr_exprs(local, prog.object_types[local]):
                new_body.append(Instruction(f"{p} = 0 ;"))
        # (2b) whenever a pointer is written to, call refcount_write
        for instr in function.body[start:]:
            if instr[0] == "*" and instr[2] == "=":
                # Store through a pointer: `* p = ... ;` writes an object of
                # p's pointee type.
                _emit_tracked_write(new_body, instr, f"(*{instr[1]})",
                                    prog.object_types[instr[1]].base)
            elif instr[1] == "=":
                # Direct assignment: `x = ... ;`
                _emit_tracked_write(new_body, instr, instr[0],
                                    prog.object_types[instr[0]])
            elif instr[0] == "return":
                _emit_tracked_return(new_body, instr, function,
                                     prog.object_types)
            else:
                new_body.append(instr)
        function.body = new_body
    prog.print()

def ptr_exprs(name, ty):
    """Return C l-value expressions for every pointer stored in an object.

    ``name`` is a C expression denoting the object and ``ty`` its dietc type.

    * ``PointerType``: the object itself, ``[name]``.
    * ``AggregateType``: recurse into each field as ``name.field``, in
      ``ty.fields`` order.
    * ``ArrayType``: recurse into each element as ``((name)[i])`` for
      ``i`` in ``range(int(ty.length))``.
    * ``BasicType`` / ``FunctionType``: no pointers, ``[]``.

    Raises:
        NotImplementedError: for any other type; its ``classify()`` result is
            also printed to stderr.
    """
    if isinstance(ty, PointerType):
        return [name]
    if isinstance(ty, AggregateType):
        return [expr
                for fname, ftype in ty.fields
                for expr in ptr_exprs(f"{name}.{fname}", ftype)]
    if isinstance(ty, (BasicType, FunctionType)):
        return []
    if isinstance(ty, ArrayType):
        length = int(ty.length)
        if length <= 0:
            return []
        result = ptr_exprs(f"(({name})[0])", ty.base)
        # The number of pointers found depends only on the element type, so if
        # element 0 has none, no element does. Skip walking the rest; this
        # keeps large scalar arrays (e.g. `char buf[N]`) from costing O(N).
        if not result:
            return []
        for i in range(1, length):
            result.extend(ptr_exprs(f"(({name})[{i}])", ty.base))
        return result
    kind = ty.classify()
    eprint(kind)
    raise NotImplementedError(f"ptr_exprs: unsupported type {kind}")

if __name__ == "__main__":
    # Run the pass only when executed as a script, not when imported.
    main()
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback