1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
|
#!/bin/python3
import os
import sys
from dietc import *
def eprint(*args):
    """Write *args* to standard error, formatted exactly as print() would."""
    stream = sys.stderr
    print(*args, file=stream)
def _instrument_write(new_body, instr, target, ty):
    """Append a write instruction wrapped in refcount bookkeeping.

    Every pointer l-value inside `target` (of type `ty`) is first copied into
    a fresh ``void *`` backup. `instr` is then emitted unchanged, followed by
    one ``refcount_write(new, old)`` per pointer, pairing its post-write value
    with its backup.

    Args:
        new_body: instruction list under construction; appended to in place.
        instr: the original writing instruction.
        target: C l-value expression string that `instr` assigns to.
        ty: dietc type object describing `target`.
    """
    pairs = []
    for expr in ptr_exprs(target, ty):
        backup = f"_refcount_ptr_backup_{count()}"
        new_body.append(Instruction(f"void * {backup} = {expr} ;"))
        pairs.append((backup, expr))
    new_body.append(instr)
    for old, new in pairs:
        new_body.append(Instruction(f"refcount_write({new},{old});"))


def main():
    """Instrument the dietc program in sys.argv[1] with reference counting.

    The transformed program is emitted with ``prog.print()``. For each
    function:
      (1) calloc/free/malloc/realloc tokens are renamed to refcount_*;
      (2a) after the instructions preceding ``code_start()`` (presumably the
           declarations), every pointer inside every local is set to 0;
      (2b) writes of the form ``* p = ...`` and ``x = ...`` are wrapped with
           backup + refcount_write bookkeeping;
      (2c) before each ``return``, pointers held by all other locals get
           ``refcount_write(0, ...)``, and pointers inside the returned
           local get ``refcount_returning(...)``.
    """
    if len(sys.argv) < 2:
        eprint("usage:", sys.argv[0] if sys.argv else "<script>", "INPUT_FILE")
        sys.exit(2)
    # Read inside a context manager so the handle is closed deterministically.
    with open(sys.argv[1], "rb") as src:
        prog = Program(src.read())
    prog.preamble.append("#include <refcounting.h>")
    for function in prog.functions:
        # (1) rewrite calloc, free, ... calls to refcount_*
        for name in ("calloc", "free", "malloc", "realloc"):
            function.body = [instr.replace_token(name, f"refcount_{name}")
                             for instr in function.body]
        # function.body is not reassigned until the end of this iteration,
        # so these can be computed once instead of per use / per `return`.
        code_start = function.code_start()
        local_names = list(function.locals())
        new_body = function.body[:code_start]
        # (2a) initialize all pointer expressions to 0
        for local in local_names:
            for p in ptr_exprs(local, prog.object_types[local]):
                new_body.append(Instruction(f"{p} = 0 ;"))
        # (2b) whenever a pointer is written to, call refcount_write
        for instr in function.body[code_start:]:
            if instr[0] == "*" and instr[2] == "=":
                # `* p = ...` writes the pointee of p, so use p's base type.
                _instrument_write(new_body, instr, f"(*{instr[1]})",
                                  prog.object_types[instr[1]].base)
            elif instr[1] == "=":
                _instrument_write(new_body, instr, instr[0],
                                  prog.object_types[instr[0]])
            elif instr[0] == "return":
                # instr[1] is ";" when nothing is returned (checked below).
                returned = instr[1]
                for local in local_names:
                    if local == returned:
                        continue
                    for expr in ptr_exprs(local, prog.object_types[local]):
                        new_body.append(
                            Instruction(f"refcount_write(0,({expr}));"))
                if returned != ";":
                    for expr in ptr_exprs(returned, prog.object_types[returned]):
                        new_body.append(
                            Instruction(f"refcount_returning({expr});"))
                new_body.append(instr)
            else:
                new_body.append(instr)
        function.body = new_body
    prog.print()
# given an object and its type, returns a list of l-value expression strings,
# representing all of the pointers in that object
def ptr_exprs(name, ty):
    """Return C l-value strings for every pointer stored in object `name`.

    Args:
        name: C expression string naming the object (a local, ``(*p)``, ...).
        ty: the object's dietc type.

    Returns:
        A fresh list, in field/element order: ``[name]`` for a pointer,
        recursive ``name.field`` paths for aggregates, recursive
        ``((name)[i])`` paths for arrays, and ``[]`` for basic and function
        types.

    Raises:
        NotImplementedError: for any other type; its ``classify()`` result is
        also written to stderr.
    """
    if isinstance(ty, PointerType):
        return [name]
    if isinstance(ty, AggregateType):
        exprs = []
        for fname, ftype in ty.fields:
            exprs.extend(ptr_exprs(f"{name}.{fname}", ftype))
        return exprs
    if isinstance(ty, (BasicType, FunctionType)):
        return []
    if isinstance(ty, ArrayType):
        length = int(ty.length)
        if length <= 0:
            return []
        # The set of pointers in an element depends only on the element type,
        # not on its name. So if element 0 holds none, no element does. This
        # keeps pointer-free arrays (e.g. `char buf[4096]`) O(1) rather than
        # O(length), which matters because main() calls this for every local
        # at every `return`.
        exprs = ptr_exprs(f"(({name})[0])", ty.base)
        if not exprs:
            return []
        for i in range(1, length):
            exprs.extend(ptr_exprs(f"(({name})[{i}])", ty.base))
        return exprs
    kind = ty.classify()
    eprint(kind)
    raise NotImplementedError(
        f"ptr_exprs: unsupported type {kind!r} for object {name}")
# Run the pass only when executed as a script, so importing this module
# (e.g. to reuse ptr_exprs) does not read sys.argv[1] or print a program.
if __name__ == "__main__":
    main()
|