diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/README.txt | 19 | ||||
-rw-r--r-- | python/dietc.py | 141 | ||||
-rw-r--r-- | python/examples/test_files/zero_init.c | 20 | ||||
-rwxr-xr-x | python/examples/zero_init | 13 |
4 files changed, 193 insertions, 0 deletions
diff --git a/python/README.txt b/python/README.txt new file mode 100644 index 0000000..8129c48 --- /dev/null +++ b/python/README.txt @@ -0,0 +1,19 @@ +Python library for building dietcc passes + +Example of using a pass: + + The examples/zero_init pass zero-initializes all local variables. It can be + run by passing a --dietc-pass flag to dietcc: + + $ /path/to/dietcc examples/test_files/zero_init.c -Wuninitialized + ... warning: _xyz_3 is used ininitialized ... + $ ./a.out + Foo return value: 0 + Foo return value: 32765 + Foo return value: 32765 + $ /path/to/dietcc examples/test_files/zero_init.c --dietc-pass $PWD/examples/zero_init -Wuninitialized + [no warnings] + $ ./a.out + Foo return value: 0 + Foo return value: 0 + Foo return value: 0 diff --git a/python/dietc.py b/python/dietc.py new file mode 100644 index 0000000..b3a0326 --- /dev/null +++ b/python/dietc.py @@ -0,0 +1,141 @@ +import sys +from dataclasses import * +import re + +@dataclass +class PointerType: + base: any + +@dataclass +class ArrayType: + base: any + length: int + +@dataclass +class FunctionType: + return_type: any + params: [any] + variadic: bool + +@dataclass +class AggregateType: + fields: [any] + +@dataclass +class BasicType: + name: str + +@dataclass +class Function: + decl_line: str + body: [str] + + def locals(self, and_params=False): + assert not and_params + return [line[1] for line in self.body[:self.code_start()]] + + def code_start(self): + return next(i for i, l in enumerate(self.body) + if not l.line.startswith("\tType_")) + + def insert(self, i, string): + self.body.insert(i, Instruction(string)) + + def cfg(self): + assert self.body is not None + labels = {node[0]: i for i, node in enumerate(self.body) + if node[1] == ":"} + cfg = dict() + for i, node in enumerate(self.body): + if node[0] not in ("goto", "return"): + cfg[node] = [i + 1] + if node[0] in ("goto", "if"): + cfg[node] = [labels[node[-2]]] + return cfg + +@dataclass +class Instruction: + line: str + tokens: [str] = field(init=False) + def __post_init__(self): self.tokens = self.line.split() + def __getitem__(self, i): return self.tokens[i] + def lhs_type(self, program): + assert "=" in self.tokens + if self[0] == "*": + return self.program.object_types[self[1]].base + return self.program.object_types[self[0]] + +class Program: + def __init__(self, string): + lines = string.decode("utf-8").split("\n") + self.preamble = [] + self.types = dict() + self.object_types = dict() + while lines and not (lines[1:] and lines[1].startswith("\t")): + self.preamble.append(lines.pop(0)) + for line in self.preamble: + if not line.startswith("typedef"): continue + types = [t for t in line.split() + if any(t.startswith(k) for k in ("Type_", "Struct_", "Union_"))] + if " * " in line: + self.types[types[1]] = PointerType(self.types[types[0]]) + elif " [ " in line: + arrlen = line.split("[")[1].split("]")[0].strip() + self.types[types[1]] = ArrayType(self.types[types[0]], arrlen) + elif " ( " in line: + self.types[types[1]] = FunctionType( + self.types[types[0]], [self.types[t] for t in types[2:]], + "..." in line) + elif " struct " in line or " union " in line: + if " { " in line: + parts = line.split() + # typedef struct ___ + # { Type_ [ident] ; Type_ [ident] ; ... } Type_ ; + field_names = [parts[i] + for i in range(5, len(parts) - 1, 3)] + field_types = [self.types[t] for t in types[:-1]] + assert len(field_names) == len(field_types) + self.types[types[-1]].fields \ + = list(zip(field_names, field_types)) + else: + assert types[-1] not in self.types + self.types[types[-1]] = AggregateType([]) + else: + self.types[types[0]] = BasicType(line.split()[1:-2]) + # assign object types for globals + for line in self.preamble: + if not line.startswith("extern"): continue + parts = line.split() + assert parts[-2] not in self.object_types + self.object_types[parts[-2]] = self.types[parts[-3]] + # recover functions & local object types + self.functions = [] + while lines: + start_line = lines.pop(0) + if not start_line.strip(): continue + params = start_line.split("(")[1].split(")")[0].split(",") + for param in params: + if not param.strip(): continue + if "..." in param: continue + type_, name = param.split() + self.object_types[name] = self.types[type_] + body = [] + while lines[0] != "}": + body_line = lines.pop(0) + if body_line.startswith("\tType_"): + parts = body_line.split() + assert parts[-2] not in self.object_types + self.object_types[parts[1]] = self.types[parts[0]] + body.append(Instruction(body_line)) + lines.pop(0) + self.functions.append(Function(start_line, body)) + + def to_c(self): + c = "\n".join(self.preamble) + "\n" + for function in self.functions: + c += function.decl_line + "\n" + c += "\n".join([l.line for l in function.body]) + "\n}\n" + return c + + def print(self): + print(self.to_c().strip()) diff --git a/python/examples/test_files/zero_init.c b/python/examples/test_files/zero_init.c new file mode 100644 index 0000000..5db109a --- /dev/null +++ b/python/examples/test_files/zero_init.c @@ -0,0 +1,20 @@ +#include <stdio.h> + +int set_xyz() { + int xyz = 1; + return xyz; +} + +int foo() { + int xyz; + return xyz; +} + +int main() { + printf("Foo return value: %d\n", foo()); + set_xyz(); + printf("Foo return value: %d\n", foo()); + set_xyz(); + printf("Foo return value: %d\n", foo()); + return 0; +} diff --git a/python/examples/zero_init b/python/examples/zero_init new file mode 100755 index 0000000..e2ab92e --- /dev/null +++ b/python/examples/zero_init @@ -0,0 +1,13 @@ +#!/bin/python3 +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) +import dietc + +prog = dietc.Program(open(sys.argv[1], "rb").read()) +for function in prog.functions: + fn_locals = function.locals() + start_i = function.code_start() + for local in fn_locals: + function.insert(start_i, f"\tMEMZERO ( {local} ) ;") +prog.print() |