import tempfile
import subprocess
import hashlib
from glob import glob
import pathlib
import struct

DIR = pathlib.Path(__file__).parent.resolve()


class Parser:
    def __init__(self, parser_dir):
        assert parser_dir and parser_dir != '/'
        # Hash every input that influences the generated parser so we can
        # skip rebuilding when nothing has changed.
        files = sorted([f"{parser_dir}/grammar.earlpy",
                        *glob(f"{parser_dir}/*.c"),
                        f"{DIR}/parser.c",
                        __file__])
        if f"{parser_dir}/parser.c" in files:
            files.remove(f"{parser_dir}/parser.c")
        hashes = ' '.join(
            hashlib.sha256(b''.join(open(f, "rb").readlines())).hexdigest()
            for f in files)
        already_built = False
        lex_path = f"{parser_dir}/parser.l"
        if glob(lex_path) and glob(f"{parser_dir}/parser"):
            # The first line of parser.l is "/* <hashes> */".
            if open(lex_path, "r").readline()[3:][:-3].strip() == hashes:
                already_built = True
        # parse_grammar also populates id_to_symbol and productions, so it
        # must run even when the binary is already built.
        lines = self.parse_grammar(f"{parser_dir}/grammar.earlpy")
        if not already_built:
            if glob(f"{parser_dir}/parser"):
                subprocess.run(f"rm {parser_dir}/parser", shell=True)
            with open(f"{parser_dir}/parser.l", "w") as f:
                f.write(f"/* {hashes} */\n")
                for line in lines:
                    f.write(line + "\n")
                f.write(open(f"{DIR}/parser.c", "r").read())
                for path in glob(f"{parser_dir}/*.c"):
                    if path == f"{parser_dir}/parser.c":
                        continue
                    f.write(open(path, "r").read())
            res = subprocess.run(
                f"flex -o {parser_dir}/parser.c {parser_dir}/parser.l",
                shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            if res.returncode:
                print(res.stderr.decode("utf-8"))
            assert res.returncode == 0
            res = subprocess.run(
                f"gcc -g -O3 {parser_dir}/parser.c -ljemalloc"
                f" -o {parser_dir}/parser",
                shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            if res.returncode:
                print(res.stderr.decode("utf-8"))
            assert res.returncode == 0
        self.parser = f"{parser_dir}/parser"

    def parse_string(self, string):
        with tempfile.NamedTemporaryFile() as f:
            f.write(string.encode("utf-8"))
            f.flush()
            result = self.parse_file(f.name)
        return result

    def parse_file(self, path):
        res = subprocess.run([self.parser, path],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if res.returncode:
            print("FAIL:", res.stderr.decode("utf-8"))
            raise ValueError
        contents = open(path, "r").read()
        # Map byte offsets to 1-based line numbers for error reporting.
        offset_to_line = dict()
        line = 1
        for i, c in enumerate(open(path, "rb").read()):
            offset_to_line[i] = line
            if chr(c) == '\n':
                line += 1
        n_tokens, = struct.unpack("Q", res.stdout[:8])
        # Each token record is (symbol id, start offset, length).
        tokens = list(struct.iter_unpack("QQQ", res.stdout[8:8 + (8 * 3 * n_tokens)]))
        tokens = [Token(self.id_to_symbol[symbol],
                        contents[offset:offset + length],
                        offset_to_line[offset],
                        path)
                  for (symbol, offset, length) in tokens]
        # The remainder of the output is a preorder list of production ids.
        nodes = [t[0] for t in struct.iter_unpack("Q", res.stdout[8 + (8 * 3 * n_tokens):])]
        # REPARSE the nodes: rebuild the tree by walking the preorder
        # production list, pulling tokens off the stream for terminals.
        root = Node(self.productions[nodes[0]][1], self.productions[nodes[0]][0])
        nodes.pop(0)
        stack = [root]
        while stack:
            node = stack[-1]
            if (isinstance(node, Token)
                    or len(node.production) == len(node.contents)):
                stack.pop()
                if stack:
                    stack[-1].contents.append(node)
            else:
                symbol = node.production[len(node.contents)]
                if symbol.kind == "nonterm":
                    prod_id = nodes.pop(0)
                    stack.append(Node(self.productions[prod_id][1],
                                      self.productions[prod_id][0]))
                else:
                    stack.append(tokens.pop(0))
        return root

    def parse_grammar(self, grammar_file):
        grammar = open(grammar_file, "r")
        # (0) PARSE symbols from the grammar file
        self.name_to_symbol = dict()
        ordered_symbols = []
        last_symbol = None
        for line in grammar:
            if line.strip().startswith("#"):
                continue
            elif line[0] in ' \t':
                last_symbol.process_subline(line.strip())
            elif line.strip():
                last_symbol = Symbol(line)
                self.name_to_symbol[last_symbol.name] = last_symbol
                ordered_symbols.append(last_symbol)
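
        # For reference, a hypothetical grammar.earlpy in the format the
        # loop above accepts (every name here is invented for illustration,
        # not taken from a real grammar):
        #
        #     num regex
        #         [0-9]+
        #     op list
        #         + - * /
        #     expr nonterm .start
        #         expr op expr .name BINOP
        #         num
        #
        # Unindented lines declare "name kind [.start] [.poison]"; indented
        # lines are handed to Symbol.process_subline below.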

        # We allow mixing of concrete tokens and symbol names in the
        # nonterminal patterns; this undoes that.
        # (1a) map each concrete token to the list it belongs to
        concrete_to_symbol = dict()
        for symbol in ordered_symbols:
            if symbol.kind != "list":
                continue
            for token in symbol.contents:
                assert token not in concrete_to_symbol
                concrete_to_symbol[token] = symbol
        # (1b) rewrite any rule involving concrete 'x' from list 'y' to 'y::x'
        used_concretes = set()
        for symbol in ordered_symbols:
            if symbol.kind != "nonterm":
                continue
            new_contents = []
            for rule in symbol.contents:
                new_rule = []
                for token in rule:
                    if token in self.name_to_symbol:
                        new_rule.append(token)
                    else:
                        assert token in concrete_to_symbol, \
                            f"Token '{token}' is not in a list"
                        new_rule.append(f"{concrete_to_symbol[token].name}::{token}")
                        used_concretes.add(token)
                new_contents.append(new_rule)
            symbol.contents = new_contents
        # (1c) if 'y::x' appeared, turn 'y' into a nonterminal
        new_ordered_symbols = []
        for symbol in ordered_symbols.copy():
            if symbol.kind != "list":
                new_ordered_symbols.append(symbol)
                continue
            split_out = set(symbol.contents) & used_concretes
            if not split_out:
                new_ordered_symbols.append(symbol)
                continue
            new_rules = []
            for token in sorted(split_out):
                name = f"{symbol.name}::{token}"
                self.name_to_symbol[name] = Symbol(name + " list")
                self.name_to_symbol[name].contents = [token]
                new_ordered_symbols.append(self.name_to_symbol[name])
                new_rules.append([name])
            left_in = set(symbol.contents) - used_concretes
            if left_in:
                name = f"{symbol.name}::__rest__"
                self.name_to_symbol[name] = Symbol(name + " list")
                self.name_to_symbol[name].contents = sorted(left_in)
                new_ordered_symbols.append(self.name_to_symbol[name])
                new_rules.append([name])
            symbol.kind = "nonterm"
            symbol.contents = new_rules
            symbol.production_names = [None for _ in new_rules]
            symbol.is_pseudo_node = True
            new_ordered_symbols.append(symbol)
        ordered_symbols = new_ordered_symbols
        # Done!

        ##### DESCRIBE the lexer and the symbols
        lines = []

        def put(x):
            lines[-1] += x

        def putl(*x):
            lines.extend(x)

        putl("%option noyywrap",
             "%option reentrant",
             "%{",
             "typedef size_t prod_id_t;",
             "typedef size_t symbol_id_t;",
             "int OFFSET;",
             # https://stackoverflow.com/questions/47094667/getting-the-current-index-in-the-input-string-flex-lexer
             "#define YY_USER_ACTION OFFSET += yyleng;",
             )
        self.max_n_productions = max(len(symbol.contents) + 1
                                     for symbol in ordered_symbols
                                     if symbol.kind == "nonterm")
        putl(f"#define MAX_N_PRODUCTIONS {self.max_n_productions}")
        self.max_production_len = max(max(map(len, symbol.contents)) + 1
                                      for symbol in ordered_symbols
                                      if symbol.kind == "nonterm")
        putl(f"#define MAX_PRODUCTION_LEN {self.max_production_len}")
        n_nonterms = len([symbol for symbol in ordered_symbols
                          if symbol.kind == "nonterm"])
        putl(f"#define N_NONTERMS {n_nonterms}")
        putl(f"#define N_SYMBOLS {len(ordered_symbols) + 1}")
        putl(f"#define DONE_SYMBOL 0")
        # symbol ids: 0, nonterm1, nonterm2, ..., nontermN, term, ...
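        # Design note: nonterminal ids are packed into the contiguous range
        # 1..N_NONTERMS (the sort below moves them to the front), so the
        # generated C code can answer "is this a nonterminal?" with the
        # single range check in IS_NONTERM rather than a lookup table.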
putl(f"#define IS_NONTERM(x) ((0 < (x)) && ((x) <= N_NONTERMS))") # put all the nonterminals at the beginning ordered_symbols = sorted(ordered_symbols, key=lambda s: (s.kind == "nonterm"), reverse=True) self.id_to_symbol = dict() putl("char *SYMBOL_ID_TO_NAME[] = { \"DONE\"") for i, symbol in enumerate(ordered_symbols): symbol.id = i + 1 self.id_to_symbol[symbol.id] = symbol put(", \"" + symbol.name + "\"") put(" };") for symbol in ordered_symbols: if symbol.name.replace("_", "").isalnum(): putl(f"#define SYMBOL_{symbol.name} {symbol.id}") if symbol.is_start: putl(f"#define START_SYMBOL {symbol.id}") putl("char SYMBOL_TO_POISON[] = { 0") for symbol in ordered_symbols: put(", " + ("1" if symbol.poisoned else "0")) put(" };") putl("prod_id_t SYMBOL_ID_TO_PRODUCTION_IDS[N_SYMBOLS][MAX_N_PRODUCTIONS] = { {0}") # [(production, Symbol), ...] self.productions = [([], None, None)] for symbol in ordered_symbols: if symbol.kind == "nonterm": start_idx = len(self.productions) assert isinstance(symbol.contents[0], list) for i, rule in enumerate(symbol.contents): rule = [self.name_to_symbol[x] for x in rule] self.productions.append((rule, symbol, symbol.production_names[i])) prods = ', '.join(map(str, range(start_idx, len(self.productions)))) if prods: put(", {" + prods + ", 0}") else: put(", {0}") else: self.productions.append(([], symbol, None)) put(", {0}") put(" };") putl(f"#define N_PRODUCTIONS {len(self.productions)}") for i, (_, _, name) in enumerate(self.productions): if name: putl(f"#define PRODUCTION_{name} {i}") putl("symbol_id_t PRODUCTION_ID_TO_PRODUCTION[N_PRODUCTIONS][MAX_PRODUCTION_LEN] = { {0}") for i, (production, _, _) in enumerate(self.productions): if i == 0: continue production = ', '.join(str(symbol.id) for symbol in production) if production: put(", {" + production + ", 0}") else: put(", {0}") put(" };") putl("symbol_id_t PRODUCTION_ID_TO_SYMBOL[N_PRODUCTIONS] = { 0") for i, (_, symbol, _) in enumerate(self.productions): if i != 0: put(f", {symbol.id}") put(" };") # Production hints: for this production, what does the leading symbol # need to be? # symbol -> symbol | True (multiple) symbol_to_first = {symbol: symbol for symbol in self.id_to_symbol.values() if symbol.kind != "nonterm"} fixedpoint = False while not fixedpoint: fixedpoint = True for symbol in self.id_to_symbol.values(): if symbol.kind != "nonterm": continue head_symbols = [self.name_to_symbol[production[0]] for production in symbol.contents] firsts = [symbol_to_first.get(head, None) for head in head_symbols] new_first = (firsts[0] if all(f == firsts[0] for f in firsts) else True) if symbol_to_first.get(symbol, None) != new_first: symbol_to_first[symbol] = new_first fixedpoint = False putl("symbol_id_t PRODUCTION_ID_TO_FIRST[N_PRODUCTIONS] = { 0") for i, (production, _, _) in enumerate(self.productions): if i == 0: continue if not production or symbol_to_first.get(production[0], True) is True: put(", 0") else: put(f", {symbol_to_first[production[0]].id}") put(" };") ##### DONE: output the lexer putl("void lex_symbol(symbol_id_t);") putl("%}") putl("%%") # Spit out the lexer! def escape_literal(lit): return '"' + lit.replace('\\', '\\\\') + '"' for symbol in ordered_symbols: if symbol.kind == "nonterm": continue if symbol.kind == "list": for token in symbol.contents: putl(escape_literal(token) + f" {{ lex_symbol({symbol.id}); }}") elif symbol.kind == "regex": putl(symbol.contents + f" {{ lex_symbol({symbol.id}); }}") else: raise NotImplementedError putl(". 
{ }") putl("\\n { }") putl("%%") return lines class Symbol: def __init__(self, declaration): parts = declaration.split() self.name = parts[0] self.kind = parts[1] self.is_start = ".start" in parts[2:] self.poisoned = ".poison" in parts[2:] self.contents = [] self.production_names = [] self.id = None self.is_pseudo_node = False def process_subline(self, line): if self.kind == "list": self.contents.extend(line.split()) elif self.kind == "regex": assert not self.contents self.contents = line.strip() elif self.kind == "nonterm": self.contents.append(line.split()) self.production_names.append(None) for i, part in enumerate(self.contents[-1]): if part.startswith("."): args = self.contents[-1][i:] self.contents[-1] = self.contents[-1][:i] for arg, value in zip(args[::2], args[1::2]): if arg == ".name": self.production_names[-1] = value else: raise NotImplementedError class Node: def __init__(self, symbol, production): self.symbol = symbol self.production = production self.contents = [] def line_numbers(self): return self.contents[0].line_numbers() def max_line_numbers(self): return self.contents[-1].max_line_numbers() def file_name(self): return self.contents[-1].file_name() def pprint(self): def pprint(other): if isinstance(other, Node): return other.pprint() return other.pprint() if len(self.contents) == 1: return pprint(self.contents[0]) return '(' + ' '.join(map(pprint, self.contents)) + ')' def print_tree(self, depth=0): print((' ' * depth) + self.symbol.name) for arg in self.contents: arg.print_tree(depth + 2) def isa(self, *patterns): for pattern in patterns: if "->" in pattern: symbol, production = pattern.split("->") symbol = symbol.strip() if symbol != self.symbol.name: continue production = production.split() if production[-1] != "..." and len(production) != len(self.pprint_production().split()[2:]): continue for desired, real in zip(production, self.pprint_production().split()[2:]): if desired == "...": return True if desired != real: break else: return True else: symbol = pattern.strip() if symbol == self.symbol.name: return True return False def hasa(self, symbol): return any(sub.name == symbol for sub in self.production) def pprint_production(self): parts = [] for s in self.production: if "::" in s.name: parts.append(s.name[s.name.index("::")+2:]) else: parts.append(s.name) return f"{self.symbol.name} -> {' '.join(parts)}" def find(self, kind, which=0, total=1): found = [] for s in self.subtrees(): if s.symbol.name == kind: found.append(s) if len(found) != total: raise ValueError return found[which] def subtrees(self): return self.contents def __getitem__(self, i): return self.contents[i] class Token: def __init__(self, symbol, string, line_number, file_name): self.symbol = symbol self.string = string self.line_number = line_number self.file_name_ = file_name def pprint(self): return self.string def line_numbers(self): return {self.line_number} def file_name(self): return self.file_name_ def max_line_numbers(self): return self.line_numbers() def print_tree(self, depth=0): print((' ' * depth) + self.symbol.name , self.string , self.line_number)