From 2b8985608d33abaae7b201a008e292cbbe2167ef Mon Sep 17 00:00:00 2001 From: Matthew Sotoudeh Date: Thu, 27 Jul 2023 14:26:33 -0700 Subject: add automated refcounting pass --- python/examples/refcounting/dietpass | 76 ++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100755 python/examples/refcounting/dietpass (limited to 'python/examples/refcounting/dietpass') diff --git a/python/examples/refcounting/dietpass b/python/examples/refcounting/dietpass new file mode 100755 index 0000000..9a7b0e8 --- /dev/null +++ b/python/examples/refcounting/dietpass @@ -0,0 +1,76 @@ +#!/bin/python3 +import os +import sys +from dietc import * + +def eprint(*args): print(*args, file=sys.stderr) + +def main(): + prog = Program(open(sys.argv[1], "rb").read()) + prog.preamble.append("#include ") + for function in prog.functions: + # (1) rewrite calloc, free, ... calls to refcount_* + for name in ("calloc", "free", "malloc", "realloc"): + function.body = [instr.replace_token(name, f"refcount_{name}") + for instr in function.body] + new_body = function.body[:function.code_start()] + # (2a) initialize all pointer expressions to 0 + for local in function.locals(): + for p in ptr_exprs(local, prog.object_types[local]): + new_body.append(Instruction(f"{p} = 0 ;")) + # (2b) whenever a pointer is written to, call refcount_write + for instr in function.body[function.code_start():]: + if instr[0] == "*" and instr[2] == "=": + refcount_pairs = [] + ty = prog.object_types[instr[1]].base + for expr in ptr_exprs(f"(*{instr[1]})", ty): + backup = f"_refcount_ptr_backup_{count()}" + new_body.append(Instruction(f"void * {backup} = {expr} ;")) + refcount_pairs.append((backup, expr)) + new_body.append(instr) + for old, new in refcount_pairs: + new_body.append(Instruction(f"refcount_write({new},{old});")) + elif instr[1] == "=": + refcount_pairs = [] + ty = prog.object_types[instr[0]] + for expr in ptr_exprs(instr[0], ty): + backup = f"_refcount_ptr_backup_{count()}" + new_body.append(Instruction(f"void * {backup} = {expr} ;")) + refcount_pairs.append((backup, expr)) + new_body.append(instr) + for old, new in refcount_pairs: + new_body.append(Instruction(f"refcount_write({new},{old});")) + elif instr[0] == "return": + for local in function.locals(): + if local == instr[1]: continue + for expr in ptr_exprs(local, prog.object_types[local]): + new_body.append(Instruction(f"refcount_write(0,({expr}));")) + if instr[1] != ";": + for expr in ptr_exprs(instr[1], prog.object_types[instr[1]]): + new_body.append(Instruction(f"refcount_returning({expr});")) + new_body.append(instr) + else: + new_body.append(instr) + function.body = new_body + prog.print() + +# given an object and its type, returns a list of l-value expression strings, +# representing all of the pointers in that object +def ptr_exprs(name, ty): + if isinstance(ty, PointerType): + return [name] + elif isinstance(ty, AggregateType): + exprs = [] + for fname, ftype in ty.fields: + exprs.extend(ptr_exprs(f"{name}.{fname}", ftype)) + return exprs + elif isinstance(ty, (BasicType, FunctionType)): + return [] + elif isinstance(ty, ArrayType): + # TODO + return [] + else: + print(ty.classify(), file=sys.stderr) + raise NotImplementedError + +main() -- cgit v1.2.3