From e43175c53ba423f67347c95962a3af9ce4209002 Mon Sep 17 00:00:00 2001 From: Michael Zhang Date: Fri, 1 Oct 2021 14:35:42 -0500 Subject: [PATCH] added some more shit --- docs/conf.py | 19 +++++----- flake.nix | 6 +++ src/agtest/driver.py | 3 ++ src/agtest/gen.py | 77 +++++++++++++++++++++++++------------- src/agtest/runtime.tmpl.py | 2 +- 5 files changed, 71 insertions(+), 36 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index bee8c80..c300f45 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,14 +12,15 @@ # import os import sys -sys.path.insert(0, os.path.abspath('..')) + +sys.path.insert(0, os.path.abspath("..")) # -- Project information ----------------------------------------------------- -project = 'agtest' -copyright = '2021, Michael Zhang' -author = 'Michael Zhang ' +project = "agtest" +copyright = "2021, Michael Zhang" +author = "Michael Zhang " # -- General configuration --------------------------------------------------- @@ -27,15 +28,15 @@ author = 'Michael Zhang ' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -43,9 +44,9 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'furo' +html_theme = "furo" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] diff --git a/flake.nix b/flake.nix index 9501193..8bd8fae 100644 --- a/flake.nix +++ b/flake.nix @@ -17,5 +17,11 @@ in { packages = flake-utils.lib.flattenTree myPkgs; + + devShell = pkgs.mkShell { + packages = with pythonPkgs; [ + black + ]; + }; }); } diff --git a/src/agtest/driver.py b/src/agtest/driver.py index a8e54cc..4843267 100644 --- a/src/agtest/driver.py +++ b/src/agtest/driver.py @@ -16,6 +16,7 @@ grammar_path = path.join(src_dir, "grammar.lark") runtime_path = path.join(src_dir, "runtime.tmpl.py") p = lark.Lark(open(grammar_path).read(), start="program", parser="lalr") + @click.command() @click.option("--show-only", is_flag=True) @click.argument("input", type=click.File("r")) @@ -53,8 +54,10 @@ def run(input: TextIO, show_only: bool) -> None: print("Call parse(str) to parse something.") import imp + mod = imp.new_module("mod") exec(s.getvalue(), mod.__dict__) import code + code.InteractiveConsole(locals=mod.__dict__).interact() diff --git a/src/agtest/gen.py b/src/agtest/gen.py index 0c20ace..cc42db0 100644 --- a/src/agtest/gen.py +++ b/src/agtest/gen.py @@ -3,10 +3,16 @@ import textwrap import re import copy import json +import sys from collections import defaultdict from agtest.ast import * + +def eprint(*args, **kwargs): + print(*args, file=sys.stderr, **kwargs) + + global i i = 0 @@ -27,10 +33,11 @@ class NodeDesc: class ParseEquation: - def __init__(self, name: str, syms: List[str], pyty: str): + def __init__(self, name: str, syms: List[str], pyty: str, pycode: str): self.name = name self.syms = syms - self.ty = pyty + self.pyty = pyty + self.pycode = pycode class GenResult: @@ -55,13 +62,23 @@ class GenResult: def trans_def(self) -> str: s = [] for name, rules in self.parse_rules.items(): + possible_returns = ", ".join(map(lambda e: e.pyty, rules)) + n = name.lstrip("?") + code = textwrap.dedent( + f""" + def {n}(self, items: List[Union[{possible_returns}]]) -> {n}: + return items[0] + """ + ) + s.append(code) + for equation in rules: code = textwrap.dedent( f""" - def {equation.name}(self, items: Any) -> Thunk[{equation.ty}]: - def inner() -> {equation.ty}: - res = {equation.ty}() + def {equation.name}(self, items: Any) -> Thunk[{equation.pyty}]: + def inner() -> {equation.pyty}: + res = {equation.pyty}() return res return Thunk(inner) """ @@ -86,9 +103,8 @@ class GenResult: s.append("%ignore WS") return "\n".join(s) - def _collect_ifaces(self) -> None: - """ collect a list of name -> iface declarations""" + """collect a list of name -> iface declarations""" self.ifaces = dict( map( lambda c: (c.name, cast(Iface, c)), @@ -97,7 +113,7 @@ class GenResult: ) def _create_iface_mappings(self) -> None: - """ list of node -> iface mappings """ + """list of node -> iface mappings""" self.what_ifaces = dict() self.what_fields = dict() for node in filter(lambda c: isinstance(c, Node), self.program): @@ -137,25 +153,23 @@ class GenResult: return dict() raise Exception(f"unhandled {expr.__class__}") - def _resolve_production(self, sym: Sym) -> str: - """ resolving a production just means checking to make sure it's a type that exists or it's a regex""" + def _resolve_production(self, sym: Sym) -> Tuple[bool, str]: + """resolving a production just means checking to make sure it's a type that exists or it's a regex""" if isinstance(sym, SymRename): if isinstance(sym.ty, NodeRefByName): if sym.ty.name in self.node_map: - return self.node_map[sym.ty.name].nonterminal + return True, self.node_map[sym.ty.name].nonterminal else: - raise Exception( - f"unresolved name {sym.ty.name} in production" - ) + raise Exception(f"unresolved name {sym.ty.name} in production") elif isinstance(sym.ty, NodeRegex): sym_name = gensym("sym") self.literals[sym_name] = f"/{sym.ty.pat.pattern}/" - return sym_name + return True, sym_name elif isinstance(sym, SymLit): sym_name = gensym("lit") # hack to make repr have double quotes self.literals[sym_name] = json.dumps(sym.lit) - return sym_name + return False, sym_name raise Exception(f"unhandled {sym.__class__}") def _build_node_map(self) -> None: @@ -170,6 +184,9 @@ class GenResult: self._collect_ifaces() self._create_iface_mappings() + self._build_node_map() + + eprint("IFACE MAPS", self.what_fields, self.what_ifaces) # a high-level dictionary of productions; this has sub-productions that # should be further expanded at a later step before converting into lark @@ -186,12 +203,12 @@ class GenResult: g = textwrap.indent("\n".join(class_fields), " ") class_decl = textwrap.dedent( - f""" - class {node_desc.nonterminal}: + """ + class {nonterminal}: {g} pass """ - ) + ).format(nonterminal=node_desc.nonterminal, g=g) self.extra += class_decl # print(node_desc.name, node_desc.node.ifaces) @@ -199,22 +216,31 @@ class GenResult: for variant in node_desc.node.variants: v_class_name = gensym(f"{node_desc.nonterminal}_var") class_decl = textwrap.dedent( - f""" - class {v_class_name}({node_desc.nonterminal}): pass """ + class {v_class_name}({nonterminal}): + {g} + pass + """ + ).format( + v_class_name=v_class_name, nonterminal=node_desc.nonterminal, g=g ) self.extra += class_decl prod_name = gensym(node_desc.nonterminal + "_") # print("PRODUCTION", prod_name, variant.prod) - seq = [] - for sym in variant.prod: - n = self._resolve_production(sym) + inputs = [] + for i, sym in enumerate(variant.prod): + isInput, n = self._resolve_production(sym) + if isInput: + inputs.append((i, n)) seq.append(n) + eprint("INPUTS", node_desc.nonterminal, inputs) + + pycode = "" self.parse_rules[node_desc.nonterminal].append( - ParseEquation(prod_name, seq, v_class_name) + ParseEquation(prod_name, seq, v_class_name, pycode) ) # create an environment for checking the equations based on the @@ -229,4 +255,3 @@ class GenResult: # value. for eq in variant.equations: self._collect_required_thunks(copy.deepcopy(env), eq.rhs) - diff --git a/src/agtest/runtime.tmpl.py b/src/agtest/runtime.tmpl.py index 4abcb45..b784a1a 100644 --- a/src/agtest/runtime.tmpl.py +++ b/src/agtest/runtime.tmpl.py @@ -3,7 +3,7 @@ __all__ = ["parse"] import re -from typing import Generic, TypeVar, Optional, Callable, Dict, Any +from typing import Generic, TypeVar, Optional, Callable, Dict, Any, Union, List from lark import Lark, Transformer T = TypeVar("T")