import sys
from dataclasses import *
import re
from typing import Any, List, Tuple


# TODO: these hashes are unsafe
# NOTE: every type class below keeps the dataclass-generated, value-based
# __eq__ but hashes by object identity, so two instances that compare equal
# will normally hash differently.  Identity hashing is what lets
# Program.type_to_name use these (mutable) objects as dictionary keys.
@dataclass
class PointerType:
    """Pointer type; ``base`` is the pointed-to type object."""
    base: Any

    def __hash__(self):
        return id(self)


@dataclass
class ArrayType:
    """Array type with element type ``base`` and ``length`` elements."""
    base: Any
    # Program stores the text found between "[" and "]" here, i.e. a str.
    length: int

    def __hash__(self):
        return id(self)


@dataclass
class FunctionType:
    """Function type: return type, parameter types and a variadic flag."""
    return_type: Any
    params: List[Any]
    variadic: bool

    def __hash__(self):
        return id(self)


@dataclass
class AggregateType:
    """Struct/union type.

    ``fields`` starts empty and is filled by Program with
    ``(field_name, field_type)`` pairs once the definition line is seen.
    """
    fields: List[Tuple[str, Any]]

    def __hash__(self):
        return id(self)


@dataclass
class BasicType:
    """Named basic type; ``name`` is the spelling taken from the typedef."""
    name: str

    def __hash__(self):
        return id(self)


@dataclass
class Function:
    """One function: its declaration line plus its body instructions."""
    decl_line: str
    body: List["Instruction"]

    def locals(self, and_params=False):
        """Return the names declared at the top of the body.

        Only ``and_params=False`` is supported (asserted).
        """
        assert not and_params
        return [decl[1] for decl in self.body[:self.code_start()]]

    def code_start(self):
        """Return the index of the first body line that is not a declaration.

        A declaration is a line whose first token starts with ``Type_``.
        When every line is a declaration (or the body is empty) the body
        length is returned, so slicing and inserting at the result stay valid.
        """
        for index, instruction in enumerate(self.body):
            if not instruction[0].startswith("Type_"):
                return index
        return len(self.body)

    def insert(self, i, string):
        """Insert ``string`` as a new Instruction at body position ``i``."""
        self.body.insert(i, Instruction(string))

    def cfg(self):
        """Return the control-flow successors of each body instruction.

        The result maps an instruction's index in ``self.body`` to the list
        of indices control may reach next:

        * ``goto``   -> [label target]
        * ``if``     -> [fall-through, label target]
        * ``return`` -> no entry
        * otherwise  -> [fall-through]

        The jump label is read from the second-to-last token of the line.
        Keys are indices rather than Instruction objects because Instruction
        is an unhashable dataclass (and equal lines would collide anyway).
        The fall-through index of the last instruction is ``len(self.body)``.

        Raises:
            KeyError: a jump names a label that is not defined in the body.
        """
        assert self.body is not None
        # A label line has ":" as its second token; the slice avoids an
        # IndexError on lines that consist of a single token.
        labels = {
            node[0]: index
            for index, node in enumerate(self.body)
            if node.tokens[1:2] == [":"]
        }
        successors = {}
        for index, node in enumerate(self.body):
            opcode = node[0]
            targets = []
            if opcode not in ("goto", "return"):
                targets.append(index + 1)
            if opcode in ("goto", "if"):
                targets.append(labels[node[-2]])
            if targets:
                successors[index] = targets
        return successors


@dataclass
class Instruction:
    """A single source line together with its whitespace-split tokens."""
    line: str
    tokens: List[str] = field(init=False)

    def __post_init__(self):
        self.tokens = self.line.split()

    def __getitem__(self, i):
        """Index into the token list (``instr[0]`` is the first token)."""
        return self.tokens[i]

    def lhs_type(self, program):
        """Return the type of the assignment target of this instruction.

        For a line starting with ``*`` the target is a dereference of the
        second token, so the ``base`` of that object's type is returned;
        otherwise the type registered for the first token is returned.
        """
        assert "=" in self.tokens
        if self[0] == "*":
            return program.object_types[self[1]].base
        return program.object_types[self[0]]

    def replace_token(self, old, new):
        """Return a new Instruction with every token ``old`` replaced by ``new``.

        This instruction is left untouched.  The copy's ``line`` is rebuilt
        by joining the tokens with single spaces, so the original spacing
        (including any leading tab) is not preserved in the copy.
        """
        clone = Instruction(self.line)
        clone.tokens = [new if token == old else token
                        for token in clone.tokens]
        clone.line = ' '.join(clone.tokens)
        return clone


class Program:
    """Parsed program: preamble lines, type table, object types, functions."""

    def __init__(self, string):
        """Parse ``string`` (UTF-8 bytes, or already-decoded text).

        Attributes set:
            preamble: lines before the first function.
            types: type name -> type object.
            type_to_name: reverse of ``types``.
            object_types: object name (global, parameter, local) -> type.
            functions: list of Function, in source order.

        Raises:
            ValueError: a function body is not closed by a ``}`` line.
        """
        if isinstance(string, (bytes, bytearray)):
            string = string.decode("utf-8")
        lines = string.split("\n")

        self.types = dict()
        self.object_types = dict()

        # The preamble ends at the first line that is directly followed by a
        # tab-indented line; that line is the first function's declaration.
        split = 0
        while split < len(lines) and not (
                split + 1 < len(lines)
                and lines[split + 1].startswith("\t")):
            split += 1
        self.preamble = lines[:split]

        self._parse_types()
        self.type_to_name = {v: k for k, v in self.types.items()}
        self._parse_globals()

        self.functions = []
        self._parse_functions(lines[split:])

    def _parse_types(self):
        """Fill ``self.types`` from the typedef/struct/union preamble lines."""
        for line in self.preamble:
            if not line.startswith(("typedef", "struct", "union")):
                continue
            types = [t for t in line.split()
                     if t.startswith(("Type_", "Struct_", "Union_"))]
            if " * " in line:
                self.types[types[1]] = PointerType(self.types[types[0]])
            elif " [ " in line:
                arrlen = line.split("[")[1].split("]")[0].strip()
                self.types[types[1]] = ArrayType(self.types[types[0]], arrlen)
            elif " ( " in line:
                self.types[types[1]] = FunctionType(
                    self.types[types[0]],
                    [self.types[t] for t in types[2:]],
                    "..." in line)
            elif "struct " in line or "union " in line:
                if " { " in line:
                    parts = line.split()
                    # Definition line.  Field names are read from tokens
                    # 4, 7, 10, ... (every third token, stopping before the
                    # final one) and field types from every Type_/Struct_/
                    # Union_ token after the first.
                    field_names = [parts[i]
                                   for i in range(4, len(parts) - 1, 3)]
                    field_types = [self.types[t] for t in types[1:]]
                    assert len(field_names) == len(field_types)
                    # The aggregate was registered earlier under "Type_<n>",
                    # where <n> is taken from this line's first tag token.
                    name = "Type_" + types[0].split("_")[1]
                    self.types[name].fields \
                        = list(zip(field_names, field_types))
                else:
                    # Declaration without a body: register an empty
                    # aggregate to be filled in by its definition line.
                    assert types[-1] not in self.types
                    self.types[types[-1]] = AggregateType([])
            else:
                self.types[types[0]] = BasicType(
                    ' '.join(line.split()[1:-2]))

    def _parse_globals(self):
        """Record the type of every ``extern`` object in the preamble."""
        # assign object types for globals
        for line in self.preamble:
            if not line.startswith("extern"):
                continue
            parts = line.split()
            assert parts[-2] not in self.object_types
            self.object_types[parts[-2]] = self.types[parts[-3]]

    def _parse_functions(self, lines):
        """Recover functions and parameter/local object types from ``lines``."""
        # recover functions & local object types
        remaining = iter(lines)
        for start_line in remaining:
            if not start_line.strip():
                continue
            params = start_line.split("(")[1].split(")")[0].split(",")
            for param in params:
                if not param.strip() or "..." in param:
                    continue
                type_, name = param.split()
                self.object_types[name] = self.types[type_]
            body = []
            # Consume lines from the shared iterator up to the closing "}".
            for body_line in remaining:
                if body_line == "}":
                    break
                if body_line.startswith("\tType_"):
                    parts = body_line.split()
                    # For a three-token declaration parts[-2] and parts[1]
                    # are the same token (the declared name).
                    assert parts[-2] not in self.object_types
                    self.object_types[parts[1]] = self.types[parts[0]]
                body.append(Instruction(body_line))
            else:
                raise ValueError(
                    "unterminated function body: " + start_line)
            self.functions.append(Function(start_line, body))

    def to_c(self):
        """Return the program text: preamble, then each function.

        Each function is emitted as its declaration line, its body lines and
        a closing ``}`` line.
        """
        chunks = ["\n".join(self.preamble), "\n"]
        for function in self.functions:
            chunks.append(function.decl_line + "\n")
            chunks.append("\n".join(instr.line for instr in function.body))
            chunks.append("\n}\n")
        return "".join(chunks)

    def print(self):
        """Print ``to_c()`` with surrounding whitespace stripped."""
        print(self.to_c().strip())


def count():
    """Return 1 on the first call, 2 on the second, and so on.

    The counter is stored on the function object as ``count.COUNT``.
    """
    count.COUNT = getattr(count, "COUNT", 0) + 1
    return count.COUNT