From 9dbb557d003af6b9b307faf34b0325bae153d745 Mon Sep 17 00:00:00 2001 From: Michael Zhang Date: Thu, 30 Sep 2021 17:49:26 -0500 Subject: [PATCH] more reorganizing --- aggen.py | 46 ++++++++++++++++++++++------------------------ agmain.py | 6 +++++- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/aggen.py b/aggen.py index 70b038d..49c8160 100644 --- a/aggen.py +++ b/aggen.py @@ -24,14 +24,14 @@ class NodeDesc: def __init__(self, node: Node): self.node = node self.name = node.name - self.nonterminal = node.name.lower() + self.nonterminal = gensym(node.name.lower()) class ParseEquation: - def __init__(self, name: str, syms: List[str], ty: str): + def __init__(self, name: str, syms: List[str], pyty: str): self.name = name self.syms = syms - self.ty = ty + self.ty = pyty class GenResult: @@ -40,23 +40,25 @@ class GenResult: self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list) self.starts: Set[str] = set() self.extra = ex - self.trans: List[str] = list() + self.nonterminal_map: Dict[str, str] = dict() @property def transdef(self) -> str: - s = self.trans + s = [] for name, rules in self.parse_rules.items(): n = name.lstrip("?") for equation in rules: - code = f""" + code = textwrap.dedent(f""" def {equation.name}(self, items: Any) -> Thunk[{equation.ty}]: - return Thunk(lambda: {equation.ty}()) - """.strip().replace("\n", "") - code = re.sub(r"\s+", " ", code) + def inner() -> {equation.ty}: + res = {equation.ty}() + return res + return Thunk(inner) + """) s.append(code) if not s: s = ["pass"] - return "\n" + "\n".join(map(lambda c: f" {c}", s)) + return textwrap.indent("\n".join(s), " ") @property def parser_data(self) -> str: @@ -143,22 +145,16 @@ def gen(program: List[Decl]) -> GenResult: return dict() raise Exception(f"unhandled {expr.__class__}") - node_map: Dict[str, NodeDesc] = dict( - map( - lambda n: (n.name, NodeDesc(cast(Node, n))), - filter(lambda c: isinstance(c, Node), program), - ) - ) + node_map: Dict[str, NodeDesc] = dict() + + for _node in filter(lambda c: isinstance(c, Node), program): + nd = NodeDesc(cast(Node, _node)) + node_map[_node.name] = nd + res.nonterminal_map[nd.name] = nd.nonterminal for node_desc in node_map.values(): assert isinstance(node_desc, NodeDesc) - - res.starts.add(node_desc.name.lower()) - # res.parse_rules[f"?{node_desc.name.lower()}"].append( - # ParseEquation( - # node_desc.name.lower(), [node_desc.nonterminal], node_desc.nonterminal - # ) - # ) + res.starts.add(node_desc.nonterminal) class_decl = textwrap.dedent( f""" @@ -210,7 +206,9 @@ def gen(program: List[Decl]) -> GenResult: for sym in variant.prod: n = resolve_production(sym) seq.append(n) - res.parse_rules[node_desc.nonterminal].append(ParseEquation(prod_name, seq, v_class_name)) + res.parse_rules[node_desc.nonterminal].append( + ParseEquation(prod_name, seq, v_class_name) + ) # create an environment for checking the equations based on # the production diff --git a/agmain.py b/agmain.py index e6908c5..cd49b2e 100644 --- a/agmain.py +++ b/agmain.py @@ -1,5 +1,6 @@ import textwrap import os +import json import importlib from lark import Lark @@ -55,7 +56,9 @@ if __name__ == "__main__": class Trans(Transformer[None]): {transdef} + __agNonterminals = {ntmap} def parse(input: str, start: Optional[str] = None) -> Any: + if start is not None: start = __agNonterminals[start] tree = parser.parse(input, start) trans = Trans() res = trans.transform(tree) @@ -68,8 +71,9 @@ if __name__ == "__main__": ex=res.extra, starts=list(res.starts), transdef=res.transdef, + ntmap=json.dumps(res.nonterminal_map), ) ) mod = importlib.import_module("gen.arith") - print(mod.parse("1 + 2 * 3", start="expr")) # type: ignore + print(mod.parse("1 + 2 * 3", start="Expr")) # type: ignore