diff --git a/Justfile b/Justfile index dbf1a57..116ef8a 100644 --- a/Justfile +++ b/Justfile @@ -3,6 +3,7 @@ watch: fmt: pipenv run black . + dos2unix * run: mypy *.py diff --git a/aggen.py b/aggen.py index 8c666a2..f05f1a9 100644 --- a/aggen.py +++ b/aggen.py @@ -12,6 +12,21 @@ global i i = 0 +def gensym(prefix: str = "", suffix: str = "") -> str: + global i + presan = re.sub("[^0-9a-zA-Z]+", "_", prefix) + sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix) + i += 1 + return f"{presan}{i}{sufsan}" + + +class NodeDesc: + def __init__(self, node: Node): + self.node = node + self.name = node.name + self.nonterminal = gensym(node.name.lower()) + + class GenResult: def __init__(self, pd: str = "", ex: str = ""): self.literals: Dict[str, str] = dict() @@ -43,13 +58,6 @@ class GenResult: def gen(program: List[Decl]) -> GenResult: res = GenResult() - def gen(prefix: str = "", suffix: str = "") -> str: - global i - presan = re.sub("[^0-9a-zA-Z]+", "_", prefix) - sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix) - i += 1 - return f"{presan}{i}{sufsan}" - def v(name: str) -> str: return f"__ag_{name}" @@ -116,42 +124,41 @@ def gen(program: List[Decl]) -> GenResult: return dict() raise Exception(f"unhandled {expr.__class__}") - node_map = dict( - map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program)) - ) - node_name_map = dict( - map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items()) + node_map: Dict[str, NodeDesc] = dict( + map( + lambda n: (n.name, NodeDesc(cast(Node, n))), + filter(lambda c: isinstance(c, Node), program), + ) ) - for node in node_map.values(): - node = cast(Node, node) - node_name_lower = node.name.lower() - node_name = node_name_map[node.name] - res.parse_rules[f"?{node_name_lower}"].append(node_name) - res.starts.add(node_name_lower) + for node_desc in node_map.values(): + assert isinstance(node_desc, NodeDesc) + + res.starts.add(node_desc.name.lower()) + res.parse_rules[f"?{node_desc.name.lower()}"].append(node_desc.nonterminal) class_decl = textwrap.dedent( f""" - class {v(node_name)}: pass + class {node_desc.nonterminal}: pass """ ) res.extra += class_decl - print(node.name, node.ifaces) + print(node_desc.name, node_desc.node.ifaces) - for variant in node.variants: - v_class_name = gen(f"{node_name}_var") + for variant in node_desc.node.variants: + v_class_name = gensym(f"{node_desc.nonterminal}_var") class_decl = textwrap.dedent( f""" - class {v(v_class_name)}({v(node_name)}): + class {v_class_name}({node_desc.nonterminal}): ''' ''' pass """ ) res.extra += class_decl - prod_name = gen(node_name + "_") - res.parse_rules[node_name].append(prod_name) + prod_name = gensym(node_desc.nonterminal + "_") + res.parse_rules[node_desc.nonterminal].append(prod_name) print("PRODUCTION", prod_name, variant.prod) # resolving a production just means checking to make sure it's a @@ -160,18 +167,18 @@ def gen(program: List[Decl]) -> GenResult: print(f"resolve_production({sym})") if isinstance(sym, SymRename): if isinstance(sym.ty, NodeRefByName): - if sym.ty.name in node_name_map: - return node_name_map[sym.ty.name] + if sym.ty.name in node_map: + return node_map[sym.ty.name].nonterminal else: raise Exception( f"unresolved name {sym.ty.name} in production" ) elif isinstance(sym.ty, NodeRegex): - sym_name = gen("sym") + sym_name = gensym("sym") res.literals[sym_name] = f"/{sym.ty.pat.pattern}/" return sym_name elif isinstance(sym, SymLit): - sym_name = gen("lit") + sym_name = gensym("lit") # hack to make repr have double quotes res.literals[sym_name] = json.dumps(sym.lit) return sym_name @@ -195,8 +202,8 @@ def gen(program: List[Decl]) -> GenResult: # trying to compute, and generate a thunk corresponding to # that value. for eq in variant.equations: - eq_name = gen(f"eq_{node.name}") - thunk_name = gen(f"thunk_{node.name}") + eq_name = gensym(f"eq_{node_desc.name}") + thunk_name = gensym(f"thunk_{node_desc.name}") print("RHS", eq.rhs, eq.rhs.id) collect_required_thunks(copy.deepcopy(env), eq.rhs) @@ -213,8 +220,4 @@ def gen(program: List[Decl]) -> GenResult: print(f"```py\n{func_impl}\n```") res.extra += func_impl - # this is a "type alias" that connects it to one of the generated - # names above - res.extra += f"{node.name} = {v(node_name)}" - return res diff --git a/agmain.py b/agmain.py index 6a8232f..a2c1302 100644 --- a/agmain.py +++ b/agmain.py @@ -9,7 +9,7 @@ from aggen import * p = Lark(open("grammar.lark").read(), start="program", parser="lalr") if __name__ == "__main__": - with open("arith.ag") as f: + with open("test/arith.ag") as f: data = f.read() cst = p.parse(data) diff --git a/arith.ag b/test/arith.ag similarity index 100% rename from arith.ag rename to test/arith.ag diff --git a/let.ag b/test/let.ag similarity index 100% rename from let.ag rename to test/let.ag