#!/bin/python3
import os
import sys
from dietc import *


def eprint(*args):
    """Print *args* to standard error, formatted the way ``print`` would."""
    print(*args, file=sys.stderr)


def _append_tracked_write(new_body, instr, exprs):
    """Append *instr* to *new_body*, bracketed by refcount bookkeeping.

    For every pointer l-value expression in *exprs*, a backup of its current
    value is emitted before *instr*; after *instr*, one
    ``refcount_write(new, old)`` call per expression is emitted, in the same
    order as the backups.
    """
    pairs = []
    for expr in exprs:
        # count() is not defined in this file; presumably it comes from the
        # dietc star-import and yields a fresh id per call -- TODO confirm.
        backup = f"_refcount_ptr_backup_{count()}"
        new_body.append(Instruction(f"void * {backup} = {expr} ;"))
        pairs.append((backup, expr))
    new_body.append(instr)
    for old, new in pairs:
        new_body.append(Instruction(f"refcount_write({new},{old});"))


def main():
    """Rewrite the program named by ``sys.argv[1]`` and print the result.

    The file is read as bytes and handed to ``Program``. For every function:

    1. the tokens ``calloc``, ``free``, ``malloc`` and ``realloc`` are
       replaced by their ``refcount_``-prefixed counterparts;
    2. (a) every pointer expression reachable from a local is initialised
       to 0 right after the instructions that precede ``code_start()``;
       (b) every assignment (``* p = ...`` or ``x = ...``) is bracketed by
       a backup of the affected pointers and ``refcount_write`` calls;
    3. before each ``return``, the pointers of every local other than the
       returned one get ``refcount_write(0, ...)``, and the pointers of the
       returned value get ``refcount_returning(...)``.

    Finally ``prog.print()`` is called. With no input path on the command
    line, a usage message is written to stderr and the process exits with
    status 1.
    """
    if len(sys.argv) < 2:
        script = sys.argv[0] if sys.argv else "refcount"
        eprint(f"usage: {script} <input-file>")
        sys.exit(1)

    # Use a context manager so the input handle is closed deterministically.
    with open(sys.argv[1], "rb") as source:
        prog = Program(source.read())

    # NOTE(review): the include target appears to be missing from this string
    # (possibly lost in extraction) -- confirm the intended header name.
    prog.preamble.append("#include ")

    for function in prog.functions:
        # (1) rewrite calloc, free, ... calls to refcount_*
        for name in ("calloc", "free", "malloc", "realloc"):
            function.body = [instr.replace_token(name, f"refcount_{name}")
                             for instr in function.body]

        start = function.code_start()
        new_body = function.body[:start]

        # (2a) initialize all pointer expressions to 0
        for local in function.locals():
            for p in ptr_exprs(local, prog.object_types[local]):
                new_body.append(Instruction(f"{p} = 0 ;"))

        # (2b) whenever a pointer is written to, call refcount_write
        for instr in function.body[start:]:
            if instr[0] == "*" and instr[2] == "=":
                # Store through a pointer: track pointers inside the pointee.
                ty = prog.object_types[instr[1]].base
                _append_tracked_write(
                    new_body, instr, ptr_exprs(f"(*{instr[1]})", ty))
            elif instr[1] == "=":
                # Plain assignment: track pointers inside the destination.
                ty = prog.object_types[instr[0]]
                _append_tracked_write(
                    new_body, instr, ptr_exprs(instr[0], ty))
            elif instr[0] == "return":
                # Drop the references held by every local that is not the
                # one being returned.
                for local in function.locals():
                    if local == instr[1]:
                        continue
                    for expr in ptr_exprs(local, prog.object_types[local]):
                        new_body.append(
                            Instruction(f"refcount_write(0,({expr}));"))
                # A bare "return ;" has no value whose pointers need marking.
                if instr[1] != ";":
                    returned_ty = prog.object_types[instr[1]]
                    for expr in ptr_exprs(instr[1], returned_ty):
                        new_body.append(
                            Instruction(f"refcount_returning({expr});"))
                new_body.append(instr)
            else:
                new_body.append(instr)

        function.body = new_body

    prog.print()


# given an object and its
# type, returns a list of l-value expression strings,
# representing all of the pointers in that object
def ptr_exprs(name, ty):
    """Return l-value expression strings for every pointer inside an object.

    Args:
        name: expression string denoting the object.
        ty: the object's type; expected to be one of the PointerType,
            AggregateType, BasicType, FunctionType or ArrayType classes
            referenced below (not defined in this file).

    Returns:
        A list of expression strings: ``[name]`` for a pointer, the
        concatenated results over ``ty.fields`` for an aggregate (each as
        ``name.field``), the concatenated results over every index in
        ``range(int(ty.length))`` for an array (each as ``((name)[i])``),
        and an empty list for basic and function types.

    Raises:
        NotImplementedError: for any other type; ``ty.classify()`` is
            printed to stderr first.
    """
    if isinstance(ty, PointerType):
        return [name]

    if isinstance(ty, AggregateType):
        exprs = []
        for fname, ftype in ty.fields:
            exprs.extend(ptr_exprs(f"{name}.{fname}", ftype))
        return exprs

    if isinstance(ty, (BasicType, FunctionType)):
        return []

    if isinstance(ty, ArrayType):
        result = []
        # Arrays are expanded element by element, so the length has to be
        # convertible with int().
        for i in range(int(ty.length)):
            result.extend(ptr_exprs(f"(({name})[{i}])", ty.base))
        return result

    kind = ty.classify()
    print(kind, file=sys.stderr)
    raise NotImplementedError(f"unsupported type: {kind}")


# Only run the pass when executed as a script, so importing this module
# (e.g. to reuse ptr_exprs) has no side effects.
if __name__ == "__main__":
    main()