diff --git a/aggen.py b/aggen.py index f05f1a9..70b038d 100644 --- a/aggen.py +++ b/aggen.py @@ -24,13 +24,20 @@ class NodeDesc: def __init__(self, node: Node): self.node = node self.name = node.name - self.nonterminal = gensym(node.name.lower()) + self.nonterminal = node.name.lower() + + +class ParseEquation: + def __init__(self, name: str, syms: List[str], ty: str): + self.name = name + self.syms = syms + self.ty = ty class GenResult: def __init__(self, pd: str = "", ex: str = ""): self.literals: Dict[str, str] = dict() - self.parse_rules: defaultdict[str, List[str]] = defaultdict(list) + self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list) self.starts: Set[str] = set() self.extra = ex self.trans: List[str] = list() @@ -38,6 +45,15 @@ class GenResult: @property def transdef(self) -> str: s = self.trans + for name, rules in self.parse_rules.items(): + n = name.lstrip("?") + for equation in rules: + code = f""" + def {equation.name}(self, items: Any) -> Thunk[{equation.ty}]: + return Thunk(lambda: {equation.ty}()) + """.strip().replace("\n", "") + code = re.sub(r"\s+", " ", code) + s.append(code) if not s: s = ["pass"] return "\n" + "\n".join(map(lambda c: f" {c}", s)) @@ -48,8 +64,11 @@ class GenResult: for sym, pat in self.literals.items(): s.append(f"{sym}: {pat}") for name, rules in self.parse_rules.items(): - srules = " | ".join(rules) - s.append(f"{name}: {srules}") + names = [] + for rule in rules: + names.append(rule.name) + s.append(f"{rule.name}: {' '.join(rule.syms)}") + s.append(f"{name}: {' | '.join(names)}") s.append("%import common.WS") s.append("%ignore WS") return "\n".join(s) @@ -135,12 +154,16 @@ def gen(program: List[Decl]) -> GenResult: assert isinstance(node_desc, NodeDesc) res.starts.add(node_desc.name.lower()) - res.parse_rules[f"?{node_desc.name.lower()}"].append(node_desc.nonterminal) + # res.parse_rules[f"?{node_desc.name.lower()}"].append( + # ParseEquation( + # node_desc.name.lower(), [node_desc.nonterminal], node_desc.nonterminal + # ) + # ) class_decl = textwrap.dedent( f""" class {node_desc.nonterminal}: pass - """ + """ ) res.extra += class_decl @@ -153,12 +176,11 @@ def gen(program: List[Decl]) -> GenResult: class {v_class_name}({node_desc.nonterminal}): ''' ''' pass - """ + """ ) res.extra += class_decl prod_name = gensym(node_desc.nonterminal + "_") - res.parse_rules[node_desc.nonterminal].append(prod_name) print("PRODUCTION", prod_name, variant.prod) # resolving a production just means checking to make sure it's a @@ -188,7 +210,7 @@ def gen(program: List[Decl]) -> GenResult: for sym in variant.prod: n = resolve_production(sym) seq.append(n) - res.parse_rules[prod_name].append(" ".join(seq)) + res.parse_rules[node_desc.nonterminal].append(ParseEquation(prod_name, seq, v_class_name)) # create an environment for checking the equations based on # the production @@ -208,16 +230,4 @@ def gen(program: List[Decl]) -> GenResult: print("RHS", eq.rhs, eq.rhs.id) collect_required_thunks(copy.deepcopy(env), eq.rhs) - func_impl = textwrap.dedent( - f""" - def {eq_name}() -> None: - ''' {repr(eq)} ''' - pass - def {thunk_name}() -> Thunk[None]: - return Thunk({eq_name}) - """ - ) - print(f"```py\n{func_impl}\n```") - res.extra += func_impl - return res diff --git a/agmain.py b/agmain.py index a2c1302..e6908c5 100644 --- a/agmain.py +++ b/agmain.py @@ -16,16 +16,17 @@ if __name__ == "__main__": trans = Parser() ast = trans.transform(cst) - print("ast", ast) res = gen(ast) + print("Grammar:") + print(res.parser_data) if not os.path.exists("gen"): os.makedirs("gen") with open("gen/arith.py", "w") as f: fmt_str = textwrap.dedent( """ - # This documented generated by agtest. + # This document is generated by agtest. __all__ = ["parse"] from typing import Generic, TypeVar, Optional, Callable, Dict, Any @@ -34,6 +35,7 @@ if __name__ == "__main__": builtins: Dict[str, Any] = {{ "parseInt": lambda s: int(s) }} + class Thunk(Generic[T]): ''' A thunk represents a value that may be computed lazily. ''' value: Optional[T] @@ -44,9 +46,15 @@ if __name__ == "__main__": if self.value is None: self.value = self.func() return self.value - parser = Lark('''{pd}''', parser='lalr', start={starts}, debug=True) - class Trans(Transformer[None]): {transdef} + + parser = Lark(''' + {pd} + ''', parser='lalr', start={starts}, debug=True) + {ex} + + class Trans(Transformer[None]): {transdef} + def parse(input: str, start: Optional[str] = None) -> Any: tree = parser.parse(input, start) trans = Trans() diff --git a/test/agtest.ag b/test/agtest.ag new file mode 100644 index 0000000..5ca816f --- /dev/null +++ b/test/agtest.ag @@ -0,0 +1,9 @@ +iface Pycode { + pycode: str, +} + +node Program : Pycode { +} + +node Decl : Pycode { +} diff --git a/test/arith.ag b/test/arith.ag index 95b2143..4142e94 100644 --- a/test/arith.ag +++ b/test/arith.ag @@ -4,10 +4,13 @@ iface HasValue { node Expr : HasValue { "+" => { - self.val = l.val + r.val * l.val; + self.val = l.val + r.val; } + "*" => { self.val = l.val * r.val; } + => { self.val = parseInt(n); } } +