from typing import *
import textwrap
import re
import copy
import json
from collections import defaultdict
from re import Pattern

from agast import *

# Monotonically increasing counter shared by all gensym() calls so every
# generated name in the output grammar/code is unique for this process.
i = 0


def gensym(prefix: str = "", suffix: str = "") -> str:
    """Return a unique identifier of the form ``<prefix><counter><suffix>``.

    Both prefix and suffix are sanitized to ``[0-9a-zA-Z_]`` so the result is
    always a valid Python identifier / lark rule-name fragment.
    """
    global i
    presan = re.sub("[^0-9a-zA-Z]+", "_", prefix)
    sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix)
    i += 1
    return f"{presan}{i}{sufsan}"


class NodeDesc:
    """Pairs an AST ``Node`` declaration with its generated nonterminal name."""

    def __init__(self, node: Node):
        self.node = node
        self.name = node.name
        # Fresh lowercase nonterminal so grammar rule names never collide.
        self.nonterminal = gensym(node.name.lower())


class ParseEquation:
    """One production alternative: rule name, symbol sequence, and the Python
    class (``pyty``) its parse result should be constructed as."""

    def __init__(self, name: str, syms: List[str], pyty: str):
        self.name = name
        self.syms = syms
        self.ty = pyty


class GenResult:
    """Accumulates everything ``gen`` produces: terminal definitions, parse
    rules, start symbols, generated Python source, and the node-name →
    nonterminal mapping."""

    def __init__(self, pd: str = "", ex: str = ""):
        # terminal symbol name -> lark pattern (regex or quoted literal)
        self.literals: Dict[str, str] = dict()
        # nonterminal -> list of production alternatives
        self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list)
        # nonterminals usable as grammar start symbols
        self.starts: Set[str] = set()
        # generated Python class declarations, appended as source text
        self.extra = ex
        # original node name -> generated nonterminal
        self.nonterminal_map: Dict[str, str] = dict()

    @property
    def transdef(self) -> str:
        """Render transformer method stubs (one per production) as Python
        source, indented for splicing into a class body."""
        s = []
        for name, rules in self.parse_rules.items():
            for equation in rules:
                # Each generated method returns a Thunk that lazily builds the
                # production's result class.
                code = textwrap.dedent(
                    f"""
                    def {equation.name}(self, items: Any) -> Thunk[{equation.ty}]:
                        def inner() -> {equation.ty}:
                            res = {equation.ty}()
                            return res
                        return Thunk(inner)
                    """
                )
                s.append(code)
        if not s:
            # Keep the spliced class body syntactically valid when empty.
            s = ["pass"]
        # Indent one level for inclusion inside a class definition.
        return textwrap.indent("\n".join(s), "    ")

    @property
    def parser_data(self) -> str:
        """Render the accumulated grammar as lark source text."""
        s = []
        for sym, pat in self.literals.items():
            s.append(f"{sym}: {pat}")
        for name, rules in self.parse_rules.items():
            names = []
            for rule in rules:
                names.append(rule.name)
                # One concrete production line per alternative. This must be
                # inside the loop: emitting it after the loop would output only
                # the final alternative's symbols.
                s.append(f"{rule.name}: {' '.join(rule.syms)}")
            # One alternation line tying the nonterminal to its alternatives.
            s.append(f"{name}: {' | '.join(names)}")
        s.append("%import common.WS")
        s.append("%ignore WS")
        return "\n".join(s)


def gen(program: List[Decl]) -> GenResult:
    """Translate a list of attribute-grammar declarations into a GenResult
    holding a lark grammar plus generated Python support code.

    Raises ``Exception`` on duplicate interface field names, unresolved names
    in productions/equations, and unhandled symbol/expression kinds.
    """
    res = GenResult()

    def v(name: str) -> str:
        # Mangle a user name into the generated code's private namespace.
        return f"__ag_{name}"

    # builtins available to equations without being bound in the env
    builtins: Dict[str, str] = {
        "parseInt": "",
    }

    # collect a list of name -> iface declarations
    ifaces: Dict[str, Iface] = dict(
        map(
            lambda c: (c.name, cast(Iface, c)),
            filter(lambda c: isinstance(c, Iface), program),
        )
    )

    # list of node -> iface mappings
    what_ifaces: Dict[str, Set[str]] = dict()
    what_fields: Dict[str, Dict[str, str]] = dict()
    for node in filter(lambda c: isinstance(c, Node), program):
        node = cast(Node, node)
        what_ifaces[node.name] = set(node.ifaces)
        # Merge the fields of every interface the node implements, rejecting
        # duplicates across interfaces.
        this_fields = dict()
        for iface in node.ifaces:
            fields = ifaces[iface].fields
            for field in fields:
                if field.name in this_fields:
                    raise Exception("duplicate field name")
                this_fields[field.name] = field.ty
        what_fields[node.name] = this_fields
    print("what_ifaces:", what_ifaces)
    print("what_fields:", what_fields)

    # a high-level dictionary of productions; this has sub-productions
    # that should be further expanded at a later step before converting
    # into lark code
    productions_hi: Dict[str, Union[str, List[str]]] = dict()

    # TODO: this should probably not be inlined here, but i'll move it
    # out once i get more info into the 'env'
    def collect_required_thunks(
        env: List[Tuple[str, NodeRef]], expr: Expr
    ) -> Dict[str, str]:
        """Walk an equation RHS, checking every name is bound in ``env`` or is
        a builtin. Currently returns an empty mapping; the traversal's real
        product is the unbound-name check."""
        names = dict(env)
        print(f"collect_required_thunks({expr})", expr.__class__)
        if isinstance(expr, ExprDot):
            return collect_required_thunks(env, expr.left)
        elif isinstance(expr, ExprMul):
            a = collect_required_thunks(env, expr.left)
            b = collect_required_thunks(env, expr.right)
            a.update(b)
            return a
        elif isinstance(expr, ExprAdd):
            a = collect_required_thunks(env, expr.left)
            b = collect_required_thunks(env, expr.right)
            a.update(b)
            return a
        elif isinstance(expr, ExprCall):
            return collect_required_thunks(env, expr.func)
        elif isinstance(expr, ExprName):
            if expr.name not in names and expr.name not in builtins:
                raise Exception(f"unbound name '{expr.name}'")
            return dict()
        raise Exception(f"unhandled {expr.__class__}")

    # Assign every node a fresh nonterminal up front so productions can refer
    # to nodes declared later in the program.
    node_map: Dict[str, NodeDesc] = dict()
    for _node in filter(lambda c: isinstance(c, Node), program):
        nd = NodeDesc(cast(Node, _node))
        node_map[_node.name] = nd
        res.nonterminal_map[nd.name] = nd.nonterminal

    for node_desc in node_map.values():
        assert isinstance(node_desc, NodeDesc)
        res.starts.add(node_desc.nonterminal)

        # Base class for all of this node's variants.
        class_decl = textwrap.dedent(
            f"""
            class {node_desc.nonterminal}: pass
            """
        )
        res.extra += class_decl

        print(node_desc.name, node_desc.node.ifaces)
        for variant in node_desc.node.variants:
            # Each variant gets its own subclass of the node's base class.
            v_class_name = gensym(f"{node_desc.nonterminal}_var")
            class_decl = textwrap.dedent(
                f"""
                class {v_class_name}({node_desc.nonterminal}):
                    ''' '''
                    pass
                """
            )
            res.extra += class_decl

            prod_name = gensym(node_desc.nonterminal + "_")
            print("PRODUCTION", prod_name, variant.prod)

            # resolving a production just means checking to make sure it's a
            # type that exists or it's a regex
            def resolve_production(sym: Sym) -> str:
                print(f"resolve_production({sym})")
                if isinstance(sym, SymRename):
                    if isinstance(sym.ty, NodeRefByName):
                        if sym.ty.name in node_map:
                            return node_map[sym.ty.name].nonterminal
                        else:
                            raise Exception(
                                f"unresolved name {sym.ty.name} in production"
                            )
                    elif isinstance(sym.ty, NodeRegex):
                        sym_name = gensym("sym")
                        res.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
                        return sym_name
                elif isinstance(sym, SymLit):
                    sym_name = gensym("lit")
                    # hack to make repr have double quotes
                    res.literals[sym_name] = json.dumps(sym.lit)
                    return sym_name
                raise Exception(f"unhandled {sym.__class__}")

            seq = []
            for sym in variant.prod:
                n = resolve_production(sym)
                seq.append(n)
            res.parse_rules[node_desc.nonterminal].append(
                ParseEquation(prod_name, seq, v_class_name)
            )

            # create an environment for checking the equations based on
            # the production
            env: List[Tuple[str, NodeRef]] = list()
            for sym in variant.prod:
                if isinstance(sym, SymRename):
                    env.append((sym.name, sym.ty))
            print(env)

            # for each of the equations, find out what the equation is
            # trying to compute, and generate a thunk corresponding to
            # that value.
            for eq in variant.equations:
                eq_name = gensym(f"eq_{node_desc.name}")
                thunk_name = gensym(f"thunk_{node_desc.name}")
                print("RHS", eq.rhs, eq.rhs.id)
                # deepcopy so the recursive check cannot mutate the variant's env
                collect_required_thunks(copy.deepcopy(env), eq.rhs)

    return res