diff --git a/flake.nix b/flake.nix index 41903ef..9501193 100644 --- a/flake.nix +++ b/flake.nix @@ -9,8 +9,10 @@ flake-utils.lib.eachDefaultSystem (system: let pkgs = nixpkgs.legacyPackages.${system}; + pythonPkgs = pkgs.python39Packages; + myPkgs = rec { - agtest = pkgs.python39Packages.callPackage ./. {}; + agtest = pythonPkgs.callPackage ./. {}; }; in { diff --git a/setup.cfg b/setup.cfg index c198c76..176d750 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ name = click [options.entry_points] console_scripts = - agt = agtest.driver:run + agtest = agtest.driver:run [options] packages = find: diff --git a/src/agtest/driver.py b/src/agtest/driver.py index 3630d51..85cd2c6 100644 --- a/src/agtest/driver.py +++ b/src/agtest/driver.py @@ -1,6 +1,7 @@ import importlib import json import os +import sys from io import StringIO from os import path @@ -17,8 +18,9 @@ runtime_path = path.join(src_dir, "runtime.tmpl.py") p = lark.Lark(open(grammar_path).read(), start="program", parser="lalr") @click.command() +@click.option("--show-only", is_flag=True) @click.argument("input", type=click.File("r")) -def run(input): +def run(input, show_only): data = input.read() input.close() @@ -44,6 +46,10 @@ def run(input): ) ) + if show_only: + print(s.getvalue()) + sys.exit(0) + print("Dropping you in a Python shell...") print("Call parse(str) to parse something.") diff --git a/src/agtest/gen.py b/src/agtest/gen.py index e114346..0c20ace 100644 --- a/src/agtest/gen.py +++ b/src/agtest/gen.py @@ -42,6 +42,15 @@ class GenResult: self.extra: str = "" self.nonterminal_map: Dict[str, str] = dict() + self.ifaces: Dict[str, Iface] = dict() + self.what_ifaces: Dict[str, Set[str]] = dict() + self.what_fields: Dict[str, Dict[str, str]] = dict() + self.node_map: Dict[str, NodeDesc] = dict() + + self.builtins: Dict[str, str] = { + "parseInt": "", + } + @property def trans_def(self) -> str: s = [] @@ -78,83 +87,101 @@ class GenResult: return "\n".join(s) - def build(self) -> None: - def v(name: str) -> str: - return f"__ag_{name}" - - # builtins - builtins: Dict[str, str] = { - "parseInt": "", - } - - # collect a list of name -> iface declarations - ifaces: Dict[str, Iface] = dict( + def _collect_ifaces(self) -> None: + """ collect a list of name -> iface declarations""" + self.ifaces = dict( map( lambda c: (c.name, cast(Iface, c)), filter(lambda c: isinstance(c, Iface), self.program), ) ) - # list of node -> iface mappings - what_ifaces: Dict[str, Set[str]] = dict() - what_fields: Dict[str, Dict[str, str]] = dict() + def _create_iface_mappings(self) -> None: + """ list of node -> iface mappings """ + self.what_ifaces = dict() + self.what_fields = dict() for node in filter(lambda c: isinstance(c, Node), self.program): node = cast(Node, node) # all_fields = dict() - what_ifaces[node.name] = set(node.ifaces) + self.what_ifaces[node.name] = set(node.ifaces) this_fields = dict() for iface in node.ifaces: - fields = ifaces[iface].fields + fields = self.ifaces[iface].fields for field in fields: if field.name in this_fields: raise Exception("duplicate field name") this_fields[field.name] = field.ty - what_fields[node.name] = this_fields + self.what_fields[node.name] = this_fields + + def _collect_required_thunks( + self, env: List[Tuple[str, NodeRef]], expr: Expr + ) -> Dict[str, str]: + names = dict(env) + if isinstance(expr, ExprDot): + return self._collect_required_thunks(env, expr.left) + elif isinstance(expr, ExprMul): + a = self._collect_required_thunks(env, expr.left) + b = self._collect_required_thunks(env, expr.right) + a.update(b) + return a + elif isinstance(expr, ExprAdd): + a = self._collect_required_thunks(env, expr.left) + b = self._collect_required_thunks(env, expr.right) + a.update(b) + return a + elif isinstance(expr, ExprCall): + return self._collect_required_thunks(env, expr.func) + elif isinstance(expr, ExprName): + if expr.name not in names and expr.name not in self.builtins: + raise Exception(f"unbound name '{expr.name}'") + return dict() + raise Exception(f"unhandled {expr.__class__}") + + def _resolve_production(self, sym: Sym) -> str: + """ resolving a production just means checking to make sure it's a type that exists or it's a regex""" + if isinstance(sym, SymRename): + if isinstance(sym.ty, NodeRefByName): + if sym.ty.name in self.node_map: + return self.node_map[sym.ty.name].nonterminal + else: + raise Exception( + f"unresolved name {sym.ty.name} in production" + ) + elif isinstance(sym.ty, NodeRegex): + sym_name = gensym("sym") + self.literals[sym_name] = f"/{sym.ty.pat.pattern}/" + return sym_name + elif isinstance(sym, SymLit): + sym_name = gensym("lit") + # hack to make repr have double quotes + self.literals[sym_name] = json.dumps(sym.lit) + return sym_name + raise Exception(f"unhandled {sym.__class__}") + + def _build_node_map(self) -> None: + for _node in filter(lambda c: isinstance(c, Node), self.program): + nd = NodeDesc(cast(Node, _node)) + self.node_map[_node.name] = nd + self.nonterminal_map[nd.name] = nd.nonterminal + + def build(self) -> None: + def v(name: str) -> str: + return f"__ag_{name}" + + self._collect_ifaces() + self._create_iface_mappings() # a high-level dictionary of productions; this has sub-productions that # should be further expanded at a later step before converting into lark # code productions_hi: Dict[str, Union[str, List[str]]] = dict() - # TODO: this should probably not be inlined here, but i'll move it out once - # i get more info into the 'env' - def collect_required_thunks( - env: List[Tuple[str, NodeRef]], expr: Expr - ) -> Dict[str, str]: - names = dict(env) - if isinstance(expr, ExprDot): - return collect_required_thunks(env, expr.left) - elif isinstance(expr, ExprMul): - a = collect_required_thunks(env, expr.left) - b = collect_required_thunks(env, expr.right) - a.update(b) - return a - elif isinstance(expr, ExprAdd): - a = collect_required_thunks(env, expr.left) - b = collect_required_thunks(env, expr.right) - a.update(b) - return a - elif isinstance(expr, ExprCall): - return collect_required_thunks(env, expr.func) - elif isinstance(expr, ExprName): - if expr.name not in names and expr.name not in builtins: - raise Exception(f"unbound name '{expr.name}'") - return dict() - raise Exception(f"unhandled {expr.__class__}") - - node_map: Dict[str, NodeDesc] = dict() - - for _node in filter(lambda c: isinstance(c, Node), self.program): - nd = NodeDesc(cast(Node, _node)) - node_map[_node.name] = nd - self.nonterminal_map[nd.name] = nd.nonterminal - - for node_desc in node_map.values(): + for node_desc in self.node_map.values(): assert isinstance(node_desc, NodeDesc) self.starts.add(node_desc.nonterminal) class_fields = [] - for field_name, field_ty in what_fields[node_desc.name].items(): + for field_name, field_ty in self.what_fields[node_desc.name].items(): class_fields.append(f"{field_name}: Thunk[{field_ty}]") g = textwrap.indent("\n".join(class_fields), " ") @@ -181,31 +208,10 @@ class GenResult: prod_name = gensym(node_desc.nonterminal + "_") # print("PRODUCTION", prod_name, variant.prod) - # resolving a production just means checking to make sure it's a - # type that exists or it's a regex - def resolve_production(sym: Sym) -> str: - if isinstance(sym, SymRename): - if isinstance(sym.ty, NodeRefByName): - if sym.ty.name in node_map: - return node_map[sym.ty.name].nonterminal - else: - raise Exception( - f"unresolved name {sym.ty.name} in production" - ) - elif isinstance(sym.ty, NodeRegex): - sym_name = gensym("sym") - self.literals[sym_name] = f"/{sym.ty.pat.pattern}/" - return sym_name - elif isinstance(sym, SymLit): - sym_name = gensym("lit") - # hack to make repr have double quotes - self.literals[sym_name] = json.dumps(sym.lit) - return sym_name - raise Exception(f"unhandled {sym.__class__}") seq = [] for sym in variant.prod: - n = resolve_production(sym) + n = self._resolve_production(sym) seq.append(n) self.parse_rules[node_desc.nonterminal].append( ParseEquation(prod_name, seq, v_class_name) @@ -217,14 +223,10 @@ class GenResult: for sym in variant.prod: if isinstance(sym, SymRename): env.append((sym.name, sym.ty)) - print(env) - # for each of the equations, find out what the equation is trying - # to compute, and generate a thunk corresponding to that value. + # for each of the equations, find out what the equation is + # trying to compute, and generate a thunk corresponding to that + # value. for eq in variant.equations: - eq_name = gensym(f"eq_{node_desc.name}") - thunk_name = gensym(f"thunk_{node_desc.name}") - - print("RHS", eq.rhs, eq.rhs.id) - collect_required_thunks(copy.deepcopy(env), eq.rhs) + self._collect_required_thunks(copy.deepcopy(env), eq.rhs)