From 500520b466e07ab7c80ffac79ff20eae1cd51727 Mon Sep 17 00:00:00 2001 From: Michael Zhang Date: Thu, 30 Sep 2021 15:32:32 -0500 Subject: [PATCH] produce a parse tree --- Justfile | 7 ++++- agast.py | 33 ++++++++++++++++++++++ aggen.py | 78 ++++++++++++++++++++++++++++++++++++++++++++-------- agmain.py | 12 ++++---- arith.ag | 2 +- grammar.lark | 10 +++++-- 6 files changed, 119 insertions(+), 23 deletions(-) diff --git a/Justfile b/Justfile index 49f7016..5592315 100644 --- a/Justfile +++ b/Justfile @@ -1,2 +1,7 @@ watch: - watchexec -ce py,lark,ag -i gen 'mypy *.py && python agmain.py && mypy gen/*.py' + watchexec -ce py,lark,ag -i gen 'just run' + +run: + mypy *.py + python agmain.py + mypy gen/*.py diff --git a/agast.py b/agast.py index 2e7e496..c92fa51 100644 --- a/agast.py +++ b/agast.py @@ -1,5 +1,27 @@ from typing import * from lark import Transformer, Tree, Token +import re +from re import Pattern + +def unescape(s: str) -> str: + q = s[0] + t = "" + i = 1 + escaped = False + while i < len(s): + c = s[i] + if escaped: + if c == q: t += q + elif c == 'n': t += '\n' + elif c == 't': t += '\t' + if c == q: break + if c == '\\': + escaped = True + i += 1 + continue + t += c + i += 1 + return t T = TypeVar("T") @@ -40,8 +62,16 @@ class NodeRefByName(NodeRef, str): def __init__(self, name: str): self.name = name def __repr__(self) -> str: return f"NodeRefByName({self.name})" +class NodeRegex(NodeRef): + def __init__(self, pat: str): + self.pat = re.compile(unescape(pat)) + def __repr__(self) -> str: return f"NodeRegex({self.pat.pattern})" class Sym: pass +class SymLit(Sym): + def __init__(self, s: str): + self.lit = unescape(s) + def __repr__(self) -> str: return f"SymLit({repr(self.lit)})" class SymRename(Sym): def __init__(self, name: str, ty: NodeRef): self.name = name @@ -113,6 +143,7 @@ class Parser(Transformer[List[Decl]]): [name, ifaces, variants] = items return Node(name, ifaces, variants) def node_ref_name(self, items: List[str]) -> NodeRefByName: return NodeRefByName(items[0]) + def node_regex(self, items: List[str]) -> NodeRegex: return NodeRegex(items[0]) # variants def variants(self, items: List[Variant]) -> List[Variant]: return items @@ -121,6 +152,8 @@ class Parser(Transformer[List[Decl]]): return Variant(prod, equations) def prod(self, items: List[Sym]) -> List[Sym]: return items + # symbols in productions + def sym_lit(self, items: List[str]) -> Sym: return SymLit(items[0]) def sym_rename(self, items: List[Any]) -> Sym: return SymRename(items[0], items[1]) # equations diff --git a/aggen.py b/aggen.py index e7db731..8b33d7a 100644 --- a/aggen.py +++ b/aggen.py @@ -2,6 +2,9 @@ from typing import * import textwrap import re import copy +import json +from collections import defaultdict +from re import Pattern from agast import * @@ -10,17 +13,33 @@ i = 0 class GenResult: def __init__(self, pd: str = "", ex: str = ""): - self.parser_data = pd + self.literals: Dict[str, str] = dict() + self.parse_rules: defaultdict[str, List[str]] = defaultdict(list) + self.starts: Set[str] = set() self.extra = ex + @property + def parser_data(self) -> str: + s = [] + for sym, pat in self.literals.items(): + s.append(f"{sym}: {pat}") + for name, rules in self.parse_rules.items(): + srules = " | ".join(rules) + s.append(f"{name}: {srules}") + s.append("%import common.WS") + s.append("%ignore WS") + return "\n".join(s) + def gen(program: List[Decl]) -> GenResult: res = GenResult() + def gen(prefix: str = "", suffix: str = "") -> str: global i presan = re.sub("[^0-9a-zA-Z]+", "_", prefix) sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix) i += 1 return f"{presan}{i}{sufsan}" + def v(name: str) -> str: return f"__ag_{name}" @@ -34,7 +53,7 @@ def gen(program: List[Decl]) -> GenResult: map(lambda c: (c.name, cast(Iface, c)), filter(lambda c: isinstance(c, Iface), program))) - + # list of node -> iface mappings what_ifaces: Dict[str, Set[str]] = dict() what_fields: Dict[str, Dict[str, str]] = dict() @@ -83,27 +102,62 @@ def gen(program: List[Decl]) -> GenResult: return dict() raise Exception(f"unhandled {expr.__class__}") - for node in filter(lambda c: isinstance(c, Node), program): + node_map = dict(map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program))) + node_name_map = dict(map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items())) + + for node in node_map.values(): node = cast(Node, node) - n_class_name = gen(node.name) + node_name_lower = node.name.lower() + node_name = node_name_map[node.name] + res.parse_rules[f"?{node_name_lower}"].append(node_name) + res.starts.add(node_name_lower) + class_decl = textwrap.dedent(f""" - class {v(n_class_name)}: pass + class {v(node_name)}: pass """) res.extra += class_decl print(node.name, node.ifaces) for variant in node.variants: - v_class_name = gen(f"{n_class_name}_var") + v_class_name = gen(f"{node_name}_var") class_decl = textwrap.dedent(f""" - class {v(v_class_name)}({v(n_class_name)}): + class {v(v_class_name)}({v(node_name)}): ''' ''' pass """) res.extra += class_decl - prod_name = gen(node.name) - print(prod_name) + prod_name = gen(node_name + "_") + res.parse_rules[node_name].append(prod_name) + print("PRODUCTION", prod_name, variant.prod) + + # resolving a production just means checking to make sure it's a + # type that exists or it's a regex + def resolve_production(sym: Sym) -> str: + print(f"resolve_production({sym})") + if isinstance(sym, SymRename): + if isinstance(sym.ty, NodeRefByName): + if sym.ty.name in node_name_map: + return node_name_map[sym.ty.name] + else: + raise Exception(f"unresolved name {sym.ty.name} in production") + elif isinstance(sym.ty, NodeRegex): + sym_name = gen("sym") + res.literals[sym_name] = f"/{sym.ty.pat.pattern}/" + return sym_name + elif isinstance(sym, SymLit): + sym_name = gen("lit") + # hack to make repr have double quotes + res.literals[sym_name] = json.dumps(sym.lit) + return sym_name + raise Exception(f"unhandled {sym.__class__}") + + seq = [] + for sym in variant.prod: + n = resolve_production(sym) + seq.append(n) + res.parse_rules[prod_name].append(" ".join(seq)) # create an environment for checking the equations based on # the production @@ -112,7 +166,7 @@ def gen(program: List[Decl]) -> GenResult: if isinstance(sym, SymRename): env.append((sym.name, sym.ty)) print(env) - + # for each of the equations, find out what the equation is # trying to compute, and generate a thunk corresponding to # that value. @@ -135,6 +189,6 @@ def gen(program: List[Decl]) -> GenResult: # this is a "type alias" that connects it to one of the generated # names above - res.extra += f"{node.name} = {v(n_class_name)}" + res.extra += f"{node.name} = {v(node_name)}" - return res \ No newline at end of file + return res diff --git a/agmain.py b/agmain.py index ff08031..19056d9 100644 --- a/agmain.py +++ b/agmain.py @@ -41,15 +41,15 @@ if __name__ == "__main__": if self.value is None: self.value = self.func() return self.value - parser = Lark('''start: - {pd}''') + parser = Lark('''{pd}''', parser='lalr', start={starts}, debug=True) class Trans(Transformer[None]): pass {ex} - def parse(input: str) -> None: - print(input) + def parse(input: str, start: Optional[str] = None) -> Any: + return parser.parse(input, start) """) - f.write(fmt_str.format(pd=res.parser_data, ex=res.extra)) + f.write(fmt_str.format(pd=res.parser_data, ex=res.extra, starts=list(res.starts))) mod = importlib.import_module("gen.arith") - mod.parse("1 + 2 * 3") # type: ignore + print(mod.parse("1 + 2 * 3", start="expr")) # type: ignore + diff --git a/arith.ag b/arith.ag index f592fcd..95b2143 100644 --- a/arith.ag +++ b/arith.ag @@ -9,5 +9,5 @@ node Expr : HasValue { "*" => { self.val = l.val * r.val; } - => { self.val = parseInt(n); } + => { self.val = parseInt(n); } } diff --git a/grammar.lark b/grammar.lark index a876b6c..4fc70c6 100644 --- a/grammar.lark +++ b/grammar.lark @@ -18,11 +18,13 @@ variants: variant* variant: prod "=>" "{" equations "}" prod: sym* ?sym: sym_rename - | STRING + | sym_lit +sym_lit: ESCAPED_STRING sym_rename: "<" ident ":" node_ref ">" ?node_ref: node_ref_name - | STRING + | node_regex node_ref_name: ident +node_regex: ESCAPED_STRING equations: equation_semi* equation_semi: equation ";" // TODO: the left side should really be a separate type @@ -56,5 +58,7 @@ IDENT: /([a-zA-Z][a-zA-Z0-9_]*)|(_[a-zA-Z0-9_]+)/ %import python.STRING %import common.WS +%import common.ESCAPED_STRING + %ignore WS -%ignore COMMENT \ No newline at end of file +%ignore COMMENT