From 61bfd16b815e7aa5749545c66fd855ba6254fa8f Mon Sep 17 00:00:00 2001
From: Michael Zhang
Date: Thu, 30 Sep 2021 16:07:36 -0500
Subject: [PATCH] a

---
 Justfile  |   3 +
 aggen.py  | 432 +++++++++++++++++++++++++++---------------------------
 agmain.py |  17 ++-
 3 files changed, 236 insertions(+), 216 deletions(-)

diff --git a/Justfile b/Justfile
index 5592315..dbf1a57 100644
--- a/Justfile
+++ b/Justfile
@@ -1,6 +1,9 @@
 watch:
 	watchexec -ce py,lark,ag -i gen 'just run'
 
+fmt:
+	pipenv run black .
+
 run:
 	mypy *.py
 	python agmain.py
diff --git a/aggen.py b/aggen.py
index 924661d..e9fa66f 100644
--- a/aggen.py
+++ b/aggen.py
@@ -1,212 +1,220 @@
-from typing import *
-import textwrap
-import re
-import copy
-import json
-from collections import defaultdict
-from re import Pattern
-
-from agast import *
-
-global i
-i = 0
-
-
-class GenResult:
-    def __init__(self, pd: str = "", ex: str = ""):
-        self.literals: Dict[str, str] = dict()
-        self.parse_rules: defaultdict[str, List[str]] = defaultdict(list)
-        self.starts: Set[str] = set()
-        self.extra = ex
-
-    @property
-    def parser_data(self) -> str:
-        s = []
-        for sym, pat in self.literals.items():
-            s.append(f"{sym}: {pat}")
-        for name, rules in self.parse_rules.items():
-            srules = " | ".join(rules)
-            s.append(f"{name}: {srules}")
-        s.append("%import common.WS")
-        s.append("%ignore WS")
-        return "\n".join(s)
-
-
-def gen(program: List[Decl]) -> GenResult:
-    res = GenResult()
-
-    def gen(prefix: str = "", suffix: str = "") -> str:
-        global i
-        presan = re.sub("[^0-9a-zA-Z]+", "_", prefix)
-        sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix)
-        i += 1
-        return f"{presan}{i}{sufsan}"
-
-    def v(name: str) -> str:
-        return f"__ag_{name}"
-
-    # builtins
-    builtins: Dict[str, str] = {
-        "parseInt": "",
-    }
-
-    # collect a list of name -> iface declarations
-    ifaces: Dict[str, Iface] = dict(
-        map(
-            lambda c: (c.name, cast(Iface, c)),
-            filter(lambda c: isinstance(c, Iface), program),
-        )
-    )
-
-    # list of node -> iface mappings
-    what_ifaces: Dict[str, Set[str]] = dict()
-    what_fields: Dict[str, Dict[str, str]] = dict()
-    for node in filter(lambda c: isinstance(c, Node), program):
-        node = cast(Node, node)
-        # all_fields = dict()
-        what_ifaces[node.name] = set(node.ifaces)
-        this_fields = dict()
-        for iface in node.ifaces:
-            fields = ifaces[iface].fields
-            for field in fields:
-                if field.name in this_fields:
-                    raise Exception("duplicate field name")
-                this_fields[field.name] = field.ty
-        what_fields[node.name] = this_fields
-    print("what_ifaces:", what_ifaces)
-    print("what_fields:", what_fields)
-
-    # a high-level dictionary of productions; this has sub-productions
-    # that should be further expanded at a later step before converting
-    # into lark code
-    productions_hi: Dict[str, Union[str, List[str]]] = dict()
-
-    # TODO: this should probably not be inlined here, but i'll move it
-    # out once i get more info into the 'env'
-    def collect_required_thunks(
-        env: List[Tuple[str, NodeRef]], expr: Expr
-    ) -> Dict[str, str]:
-        names = dict(env)
-        print(f"collect_required_thunks({expr})", expr.__class__)
-        if isinstance(expr, ExprDot):
-            return collect_required_thunks(env, expr.left)
-        elif isinstance(expr, ExprMul):
-            a = collect_required_thunks(env, expr.left)
-            b = collect_required_thunks(env, expr.right)
-            a.update(b)
-            return a
-        elif isinstance(expr, ExprAdd):
-            a = collect_required_thunks(env, expr.left)
-            b = collect_required_thunks(env, expr.right)
-            a.update(b)
-            return a
-        elif isinstance(expr, ExprCall):
-            return collect_required_thunks(env, expr.func)
-        elif isinstance(expr, ExprName):
-            if expr.name not in names and expr.name not in builtins:
-                raise Exception(f"unbound name '{expr.name}'")
-            return dict()
-        raise Exception(f"unhandled {expr.__class__}")
-
-    node_map = dict(
-        map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program))
-    )
-    node_name_map = dict(
-        map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items())
-    )
-
-    for node in node_map.values():
-        node = cast(Node, node)
-        node_name_lower = node.name.lower()
-        node_name = node_name_map[node.name]
-        res.parse_rules[f"?{node_name_lower}"].append(node_name)
-        res.starts.add(node_name_lower)
-
-        class_decl = textwrap.dedent(
-            f"""
-            class {v(node_name)}: pass
-            """
-        )
-        res.extra += class_decl
-
-        print(node.name, node.ifaces)
-
-        for variant in node.variants:
-            v_class_name = gen(f"{node_name}_var")
-            class_decl = textwrap.dedent(
-                f"""
-                class {v(v_class_name)}({v(node_name)}):
-                    ''' '''
-                    pass
-                """
-            )
-            res.extra += class_decl
-
-            prod_name = gen(node_name + "_")
-            res.parse_rules[node_name].append(prod_name)
-            print("PRODUCTION", prod_name, variant.prod)
-
-            # resolving a production just means checking to make sure it's a
-            # type that exists or it's a regex
-            def resolve_production(sym: Sym) -> str:
-                print(f"resolve_production({sym})")
-                if isinstance(sym, SymRename):
-                    if isinstance(sym.ty, NodeRefByName):
-                        if sym.ty.name in node_name_map:
-                            return node_name_map[sym.ty.name]
-                        else:
-                            raise Exception(
-                                f"unresolved name {sym.ty.name} in production"
-                            )
-                    elif isinstance(sym.ty, NodeRegex):
-                        sym_name = gen("sym")
-                        res.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
-                        return sym_name
-                elif isinstance(sym, SymLit):
-                    sym_name = gen("lit")
-                    # hack to make repr have double quotes
-                    res.literals[sym_name] = json.dumps(sym.lit)
-                    return sym_name
-                raise Exception(f"unhandled {sym.__class__}")
-
-            seq = []
-            for sym in variant.prod:
-                n = resolve_production(sym)
-                seq.append(n)
-            res.parse_rules[prod_name].append(" ".join(seq))
-
-            # create an environment for checking the equations based on
-            # the production
-            env: List[Tuple[str, NodeRef]] = list()
-            for sym in variant.prod:
-                if isinstance(sym, SymRename):
-                    env.append((sym.name, sym.ty))
-            print(env)
-
-            # for each of the equations, find out what the equation is
-            # trying to compute, and generate a thunk corresponding to
-            # that value.
-            for eq in variant.equations:
-                eq_name = gen(f"eq_{node.name}")
-                thunk_name = gen(f"thunk_{node.name}")
-
-                print("RHS", eq.rhs, eq.rhs.id)
-                collect_required_thunks(copy.deepcopy(env), eq.rhs)
-
-                func_impl = textwrap.dedent(
-                    f"""
-                    def {eq_name}() -> None:
-                        ''' {repr(eq)} '''
-                        pass
-                    def {thunk_name}() -> Thunk[None]:
-                        return Thunk({eq_name})
-                    """
-                )
-                print(f"```py\n{func_impl}\n```")
-                res.extra += func_impl
-
-        # this is a "type alias" that connects it to one of the generated
-        # names above
-        res.extra += f"{node.name} = {v(node_name)}"
-
-    return res
+from typing import *
+import textwrap
+import re
+import copy
+import json
+from collections import defaultdict
+from re import Pattern
+
+from agast import *
+
+global i
+i = 0
+
+
+class GenResult:
+    def __init__(self, pd: str = "", ex: str = ""):
+        self.literals: Dict[str, str] = dict()
+        self.parse_rules: defaultdict[str, List[str]] = defaultdict(list)
+        self.starts: Set[str] = set()
+        self.extra = ex
+        self.trans: List[str] = list()
+
+    @property
+    def transdef(self) -> str:
+        s = self.trans
+        if not s:
+            s = ["pass"]
+        return "\n" + "\n".join(map(lambda c: f" {c}", s))
+
+    @property
+    def parser_data(self) -> str:
+        s = []
+        for sym, pat in self.literals.items():
+            s.append(f"{sym}: {pat}")
+        for name, rules in self.parse_rules.items():
+            srules = " | ".join(rules)
+            s.append(f"{name}: {srules}")
+        s.append("%import common.WS")
+        s.append("%ignore WS")
+        return "\n".join(s)
+
+
+def gen(program: List[Decl]) -> GenResult:
+    res = GenResult()
+
+    def gen(prefix: str = "", suffix: str = "") -> str:
+        global i
+        presan = re.sub("[^0-9a-zA-Z]+", "_", prefix)
+        sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix)
+        i += 1
+        return f"{presan}{i}{sufsan}"
+
+    def v(name: str) -> str:
+        return f"__ag_{name}"
+
+    # builtins
+    builtins: Dict[str, str] = {
+        "parseInt": "",
+    }
+
+    # collect a list of name -> iface declarations
+    ifaces: Dict[str, Iface] = dict(
+        map(
+            lambda c: (c.name, cast(Iface, c)),
+            filter(lambda c: isinstance(c, Iface), program),
+        )
+    )
+
+    # list of node -> iface mappings
+    what_ifaces: Dict[str, Set[str]] = dict()
+    what_fields: Dict[str, Dict[str, str]] = dict()
+    for node in filter(lambda c: isinstance(c, Node), program):
+        node = cast(Node, node)
+        # all_fields = dict()
+        what_ifaces[node.name] = set(node.ifaces)
+        this_fields = dict()
+        for iface in node.ifaces:
+            fields = ifaces[iface].fields
+            for field in fields:
+                if field.name in this_fields:
+                    raise Exception("duplicate field name")
+                this_fields[field.name] = field.ty
+        what_fields[node.name] = this_fields
+    print("what_ifaces:", what_ifaces)
+    print("what_fields:", what_fields)
+
+    # a high-level dictionary of productions; this has sub-productions
+    # that should be further expanded at a later step before converting
+    # into lark code
+    productions_hi: Dict[str, Union[str, List[str]]] = dict()
+
+    # TODO: this should probably not be inlined here, but i'll move it
+    # out once i get more info into the 'env'
+    def collect_required_thunks(
+        env: List[Tuple[str, NodeRef]], expr: Expr
+    ) -> Dict[str, str]:
+        names = dict(env)
+        print(f"collect_required_thunks({expr})", expr.__class__)
+        if isinstance(expr, ExprDot):
+            return collect_required_thunks(env, expr.left)
+        elif isinstance(expr, ExprMul):
+            a = collect_required_thunks(env, expr.left)
+            b = collect_required_thunks(env, expr.right)
+            a.update(b)
+            return a
+        elif isinstance(expr, ExprAdd):
+            a = collect_required_thunks(env, expr.left)
+            b = collect_required_thunks(env, expr.right)
+            a.update(b)
+            return a
+        elif isinstance(expr, ExprCall):
+            return collect_required_thunks(env, expr.func)
+        elif isinstance(expr, ExprName):
+            if expr.name not in names and expr.name not in builtins:
+                raise Exception(f"unbound name '{expr.name}'")
+            return dict()
+        raise Exception(f"unhandled {expr.__class__}")
+
+    node_map = dict(
+        map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program))
+    )
+    node_name_map = dict(
+        map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items())
+    )
+
+    for node in node_map.values():
+        node = cast(Node, node)
+        node_name_lower = node.name.lower()
+        node_name = node_name_map[node.name]
+        res.parse_rules[f"?{node_name_lower}"].append(node_name)
+        res.starts.add(node_name_lower)
+
+        class_decl = textwrap.dedent(
+            f"""
+            class {v(node_name)}: pass
+            """
+        )
+        res.extra += class_decl
+
+        print(node.name, node.ifaces)
+
+        for variant in node.variants:
+            v_class_name = gen(f"{node_name}_var")
+            class_decl = textwrap.dedent(
+                f"""
+                class {v(v_class_name)}({v(node_name)}):
+                    ''' '''
+                    pass
+                """
+            )
+            res.extra += class_decl
+
+            prod_name = gen(node_name + "_")
+            res.parse_rules[node_name].append(prod_name)
+            print("PRODUCTION", prod_name, variant.prod)

+            # resolving a production just means checking to make sure it's a
+            # type that exists or it's a regex
+            def resolve_production(sym: Sym) -> str:
+                print(f"resolve_production({sym})")
+                if isinstance(sym, SymRename):
+                    if isinstance(sym.ty, NodeRefByName):
+                        if sym.ty.name in node_name_map:
+                            return node_name_map[sym.ty.name]
+                        else:
+                            raise Exception(
+                                f"unresolved name {sym.ty.name} in production"
+                            )
+                    elif isinstance(sym.ty, NodeRegex):
+                        sym_name = gen("sym")
+                        res.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
+                        return sym_name
+                elif isinstance(sym, SymLit):
+                    sym_name = gen("lit")
+                    # hack to make repr have double quotes
+                    res.literals[sym_name] = json.dumps(sym.lit)
+                    return sym_name
+                raise Exception(f"unhandled {sym.__class__}")
+
+            seq = []
+            for sym in variant.prod:
+                n = resolve_production(sym)
+                seq.append(n)
+            res.parse_rules[prod_name].append(" ".join(seq))
+
+            # create an environment for checking the equations based on
+            # the production
+            env: List[Tuple[str, NodeRef]] = list()
+            for sym in variant.prod:
+                if isinstance(sym, SymRename):
+                    env.append((sym.name, sym.ty))
+            print(env)
+
+            # for each of the equations, find out what the equation is
+            # trying to compute, and generate a thunk corresponding to
+            # that value.
+            for eq in variant.equations:
+                eq_name = gen(f"eq_{node.name}")
+                thunk_name = gen(f"thunk_{node.name}")
+
+                print("RHS", eq.rhs, eq.rhs.id)
+                collect_required_thunks(copy.deepcopy(env), eq.rhs)
+
+                func_impl = textwrap.dedent(
+                    f"""
+                    def {eq_name}() -> None:
+                        ''' {repr(eq)} '''
+                        pass
+                    def {thunk_name}() -> Thunk[None]:
+                        return Thunk({eq_name})
+                    """
+                )
+                print(f"```py\n{func_impl}\n```")
+                res.extra += func_impl
+
+        # this is a "type alias" that connects it to one of the generated
+        # names above
+        res.extra += f"{node.name} = {v(node_name)}"
+
+    return res
diff --git a/agmain.py b/agmain.py
index 644e68b..6a8232f 100644
--- a/agmain.py
+++ b/agmain.py
@@ -25,6 +25,8 @@ if __name__ == "__main__":
     with open("gen/arith.py", "w") as f:
         fmt_str = textwrap.dedent(
             """
+            # This document was generated by agtest.
+            __all__ = ["parse"]
 
             from typing import Generic, TypeVar, Optional, Callable, Dict, Any
             from lark import Lark, Transformer
@@ -43,15 +45,22 @@ if __name__ == "__main__":
                         self.value = self.func()
                     return self.value
 
             parser = Lark('''{pd}''', parser='lalr', start={starts}, debug=True)
-            class Trans(Transformer[None]):
-                pass
+            class Trans(Transformer[None]): {transdef}
             {ex}
             def parse(input: str, start: Optional[str] = None) -> Any:
-                return parser.parse(input, start)
+                tree = parser.parse(input, start)
+                trans = Trans()
+                res = trans.transform(tree)
+                return res
             """
         )
         f.write(
-            fmt_str.format(pd=res.parser_data, ex=res.extra, starts=list(res.starts))
+            fmt_str.format(
+                pd=res.parser_data,
+                ex=res.extra,
+                starts=list(res.starts),
+                transdef=res.transdef,
+            )
         )
         mod = importlib.import_module("gen.arith")