produce a parse tree

This commit is contained in:
Michael Zhang 2021-09-30 15:32:32 -05:00
parent 6e7563c84c
commit 500520b466
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B
6 changed files with 119 additions and 23 deletions

View file

@ -1,2 +1,7 @@
watch:
watchexec -ce py,lark,ag -i gen 'mypy *.py && python agmain.py && mypy gen/*.py'
watchexec -ce py,lark,ag -i gen 'just run'
run:
mypy *.py
python agmain.py
mypy gen/*.py

View file

@ -1,5 +1,27 @@
from typing import *
from lark import Transformer, Tree, Token
import re
from re import Pattern
def unescape(s: str) -> str:
q = s[0]
t = ""
i = 1
escaped = False
while i < len(s):
c = s[i]
if escaped:
if c == q: t += q
elif c == 'n': t += '\n'
elif c == 't': t += '\t'
if c == q: break
if c == '\\':
escaped = True
i += 1
continue
t += c
i += 1
return t
T = TypeVar("T")
@ -40,8 +62,16 @@ class NodeRefByName(NodeRef, str):
def __init__(self, name: str):
self.name = name
def __repr__(self) -> str: return f"NodeRefByName({self.name})"
class NodeRegex(NodeRef):
def __init__(self, pat: str):
self.pat = re.compile(unescape(pat))
def __repr__(self) -> str: return f"NodeRegex({self.pat.pattern})"
class Sym: pass
class SymLit(Sym):
def __init__(self, s: str):
self.lit = unescape(s)
def __repr__(self) -> str: return f"SymLit({repr(self.lit)})"
class SymRename(Sym):
def __init__(self, name: str, ty: NodeRef):
self.name = name
@ -113,6 +143,7 @@ class Parser(Transformer[List[Decl]]):
[name, ifaces, variants] = items
return Node(name, ifaces, variants)
def node_ref_name(self, items: List[str]) -> NodeRefByName: return NodeRefByName(items[0])
def node_regex(self, items: List[str]) -> NodeRegex: return NodeRegex(items[0])
# variants
def variants(self, items: List[Variant]) -> List[Variant]: return items
@ -121,6 +152,8 @@ class Parser(Transformer[List[Decl]]):
return Variant(prod, equations)
def prod(self, items: List[Sym]) -> List[Sym]: return items
# symbols in productions
def sym_lit(self, items: List[str]) -> Sym: return SymLit(items[0])
def sym_rename(self, items: List[Any]) -> Sym: return SymRename(items[0], items[1])
# equations

View file

@ -2,6 +2,9 @@ from typing import *
import textwrap
import re
import copy
import json
from collections import defaultdict
from re import Pattern
from agast import *
@ -10,17 +13,33 @@ i = 0
class GenResult:
def __init__(self, pd: str = "", ex: str = ""):
self.parser_data = pd
self.literals: Dict[str, str] = dict()
self.parse_rules: defaultdict[str, List[str]] = defaultdict(list)
self.starts: Set[str] = set()
self.extra = ex
@property
def parser_data(self) -> str:
s = []
for sym, pat in self.literals.items():
s.append(f"{sym}: {pat}")
for name, rules in self.parse_rules.items():
srules = " | ".join(rules)
s.append(f"{name}: {srules}")
s.append("%import common.WS")
s.append("%ignore WS")
return "\n".join(s)
def gen(program: List[Decl]) -> GenResult:
res = GenResult()
def gen(prefix: str = "", suffix: str = "") -> str:
global i
presan = re.sub("[^0-9a-zA-Z]+", "_", prefix)
sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix)
i += 1
return f"{presan}{i}{sufsan}"
def v(name: str) -> str:
return f"__ag_{name}"
@ -34,7 +53,7 @@ def gen(program: List[Decl]) -> GenResult:
map(lambda c: (c.name, cast(Iface, c)),
filter(lambda c: isinstance(c, Iface),
program)))
# list of node -> iface mappings
what_ifaces: Dict[str, Set[str]] = dict()
what_fields: Dict[str, Dict[str, str]] = dict()
@ -83,27 +102,62 @@ def gen(program: List[Decl]) -> GenResult:
return dict()
raise Exception(f"unhandled {expr.__class__}")
for node in filter(lambda c: isinstance(c, Node), program):
node_map = dict(map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program)))
node_name_map = dict(map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items()))
for node in node_map.values():
node = cast(Node, node)
n_class_name = gen(node.name)
node_name_lower = node.name.lower()
node_name = node_name_map[node.name]
res.parse_rules[f"?{node_name_lower}"].append(node_name)
res.starts.add(node_name_lower)
class_decl = textwrap.dedent(f"""
class {v(n_class_name)}: pass
class {v(node_name)}: pass
""")
res.extra += class_decl
print(node.name, node.ifaces)
for variant in node.variants:
v_class_name = gen(f"{n_class_name}_var")
v_class_name = gen(f"{node_name}_var")
class_decl = textwrap.dedent(f"""
class {v(v_class_name)}({v(n_class_name)}):
class {v(v_class_name)}({v(node_name)}):
''' '''
pass
""")
res.extra += class_decl
prod_name = gen(node.name)
print(prod_name)
prod_name = gen(node_name + "_")
res.parse_rules[node_name].append(prod_name)
print("PRODUCTION", prod_name, variant.prod)
# resolving a production just means checking to make sure it's a
# type that exists or it's a regex
def resolve_production(sym: Sym) -> str:
print(f"resolve_production({sym})")
if isinstance(sym, SymRename):
if isinstance(sym.ty, NodeRefByName):
if sym.ty.name in node_name_map:
return node_name_map[sym.ty.name]
else:
raise Exception(f"unresolved name {sym.ty.name} in production")
elif isinstance(sym.ty, NodeRegex):
sym_name = gen("sym")
res.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
return sym_name
elif isinstance(sym, SymLit):
sym_name = gen("lit")
# hack to make repr have double quotes
res.literals[sym_name] = json.dumps(sym.lit)
return sym_name
raise Exception(f"unhandled {sym.__class__}")
seq = []
for sym in variant.prod:
n = resolve_production(sym)
seq.append(n)
res.parse_rules[prod_name].append(" ".join(seq))
# create an environment for checking the equations based on
# the production
@ -112,7 +166,7 @@ def gen(program: List[Decl]) -> GenResult:
if isinstance(sym, SymRename):
env.append((sym.name, sym.ty))
print(env)
# for each of the equations, find out what the equation is
# trying to compute, and generate a thunk corresponding to
# that value.
@ -135,6 +189,6 @@ def gen(program: List[Decl]) -> GenResult:
# this is a "type alias" that connects it to one of the generated
# names above
res.extra += f"{node.name} = {v(n_class_name)}"
res.extra += f"{node.name} = {v(node_name)}"
return res
return res

View file

@ -41,15 +41,15 @@ if __name__ == "__main__":
if self.value is None:
self.value = self.func()
return self.value
parser = Lark('''start:
{pd}''')
parser = Lark('''{pd}''', parser='lalr', start={starts}, debug=True)
class Trans(Transformer[None]):
pass
{ex}
def parse(input: str) -> None:
print(input)
def parse(input: str, start: Optional[str] = None) -> Any:
return parser.parse(input, start)
""")
f.write(fmt_str.format(pd=res.parser_data, ex=res.extra))
f.write(fmt_str.format(pd=res.parser_data, ex=res.extra, starts=list(res.starts)))
mod = importlib.import_module("gen.arith")
mod.parse("1 + 2 * 3") # type: ignore
print(mod.parse("1 + 2 * 3", start="expr")) # type: ignore

View file

@ -9,5 +9,5 @@ node Expr : HasValue {
<l:Expr> "*" <r:Expr> => {
self.val = l.val * r.val;
}
<n:r"[0-9]+"> => { self.val = parseInt(n); }
<n:"[0-9]+"> => { self.val = parseInt(n); }
}

View file

@ -18,11 +18,13 @@ variants: variant*
variant: prod "=>" "{" equations "}"
prod: sym*
?sym: sym_rename
| STRING
| sym_lit
sym_lit: ESCAPED_STRING
sym_rename: "<" ident ":" node_ref ">"
?node_ref: node_ref_name
| STRING
| node_regex
node_ref_name: ident
node_regex: ESCAPED_STRING
equations: equation_semi*
equation_semi: equation ";"
// TODO: the left side should really be a separate type
@ -56,5 +58,7 @@ IDENT: /([a-zA-Z][a-zA-Z0-9_]*)|(_[a-zA-Z0-9_]+)/
%import python.STRING
%import common.WS
%import common.ESCAPED_STRING
%ignore WS
%ignore COMMENT
%ignore COMMENT