diff options
Diffstat (limited to 'python/dietc.py')
-rw-r--r-- | python/dietc.py | 141 |
1 files changed, 141 insertions, 0 deletions
diff --git a/python/dietc.py b/python/dietc.py new file mode 100644 index 0000000..b3a0326 --- /dev/null +++ b/python/dietc.py @@ -0,0 +1,141 @@ +import sys +from dataclasses import * +import re + +@dataclass +class PointerType: + base: any + +@dataclass +class ArrayType: + base: any + length: int + +@dataclass +class FunctionType: + return_type: any + params: [any] + variadic: bool + +@dataclass +class AggregateType: + fields: [any] + +@dataclass +class BasicType: + name: str + +@dataclass +class Function: + decl_line: str + body: [str] + + def locals(self, and_params=False): + assert not and_params + return [line[1] for line in self.body[:self.code_start()]] + + def code_start(self): + return next(i for i, l in enumerate(self.body) + if not l.line.startswith("\tType_")) + + def insert(self, i, string): + self.body.insert(i, Instruction(string)) + + def cfg(self): + assert self.body is not None + labels = {node[0]: i for i, node in enumerate(self.body) + if node[1] == ":"} + cfg = dict() + for i, node in enumerate(self.body): + if node[0] not in ("goto", "return"): + cfg[node] = [i + 1] + if node[0] in ("goto", "if"): + cfg[node] = [labels[node[-2]]] + return cfg + +@dataclass +class Instruction: + line: str + tokens: [str] = field(init=False) + def __post_init__(self): self.tokens = self.line.split() + def __getitem__(self, i): return self.tokens[i] + def lhs_type(self, program): + assert "=" in self.tokens + if self[0] == "*": + return self.program.object_types[self[1]].base + return self.program.object_types[self[0]] + +class Program: + def __init__(self, string): + lines = string.decode("utf-8").split("\n") + self.preamble = [] + self.types = dict() + self.object_types = dict() + while lines and not (lines[1:] and lines[1].startswith("\t")): + self.preamble.append(lines.pop(0)) + for line in self.preamble: + if not line.startswith("typedef"): continue + types = [t for t in line.split() + if any(t.startswith(k) for k in ("Type_", "Struct_", "Union_"))] + if " * " in line: + self.types[types[1]] = PointerType(self.types[types[0]]) + elif " [ " in line: + arrlen = line.split("[")[1].split("]")[0].strip() + self.types[types[1]] = ArrayType(self.types[types[0]], arrlen) + elif " ( " in line: + self.types[types[1]] = FunctionType( + self.types[types[0]], [self.types[t] for t in types[2:]], + "..." in line) + elif " struct " in line or " union " in line: + if " { " in line: + parts = line.split() + # typedef struct ___ + # { Type_ [ident] ; Type_ [ident] ; ... } Type_ ; + field_names = [parts[i] + for i in range(5, len(parts) - 1, 3)] + field_types = [self.types[t] for t in types[:-1]] + assert len(field_names) == len(field_types) + self.types[types[-1]].fields \ + = list(zip(field_names, field_types)) + else: + assert types[-1] not in self.types + self.types[types[-1]] = AggregateType([]) + else: + self.types[types[0]] = BasicType(line.split()[1:-2]) + # assign object types for globals + for line in self.preamble: + if not line.startswith("extern"): continue + parts = line.split() + assert parts[-2] not in self.object_types + self.object_types[parts[-2]] = self.types[parts[-3]] + # recover functions & local object types + self.functions = [] + while lines: + start_line = lines.pop(0) + if not start_line.strip(): continue + params = start_line.split("(")[1].split(")")[0].split(",") + for param in params: + if not param.strip(): continue + if "..." in param: continue + type_, name = param.split() + self.object_types[name] = self.types[type_] + body = [] + while lines[0] != "}": + body_line = lines.pop(0) + if body_line.startswith("\tType_"): + parts = body_line.split() + assert parts[-2] not in self.object_types + self.object_types[parts[1]] = self.types[parts[0]] + body.append(Instruction(body_line)) + lines.pop(0) + self.functions.append(Function(start_line, body)) + + def to_c(self): + c = "\n".join(self.preamble) + "\n" + for function in self.functions: + c += function.decl_line + "\n" + c += "\n".join([l.line for l in function.body]) + "\n}\n" + return c + + def print(self): + print(self.to_c().strip()) |