produce a parse tree
This commit is contained in:
parent
6e7563c84c
commit
500520b466
6 changed files with 119 additions and 23 deletions
7
Justfile
7
Justfile
|
@ -1,2 +1,7 @@
|
|||
watch:
|
||||
watchexec -ce py,lark,ag -i gen 'mypy *.py && python agmain.py && mypy gen/*.py'
|
||||
watchexec -ce py,lark,ag -i gen 'just run'
|
||||
|
||||
run:
|
||||
mypy *.py
|
||||
python agmain.py
|
||||
mypy gen/*.py
|
||||
|
|
33
agast.py
33
agast.py
|
@ -1,5 +1,27 @@
|
|||
from typing import *
|
||||
from lark import Transformer, Tree, Token
|
||||
import re
|
||||
from re import Pattern
|
||||
|
||||
def unescape(s: str) -> str:
|
||||
q = s[0]
|
||||
t = ""
|
||||
i = 1
|
||||
escaped = False
|
||||
while i < len(s):
|
||||
c = s[i]
|
||||
if escaped:
|
||||
if c == q: t += q
|
||||
elif c == 'n': t += '\n'
|
||||
elif c == 't': t += '\t'
|
||||
if c == q: break
|
||||
if c == '\\':
|
||||
escaped = True
|
||||
i += 1
|
||||
continue
|
||||
t += c
|
||||
i += 1
|
||||
return t
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -40,8 +62,16 @@ class NodeRefByName(NodeRef, str):
|
|||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
def __repr__(self) -> str: return f"NodeRefByName({self.name})"
|
||||
class NodeRegex(NodeRef):
|
||||
def __init__(self, pat: str):
|
||||
self.pat = re.compile(unescape(pat))
|
||||
def __repr__(self) -> str: return f"NodeRegex({self.pat.pattern})"
|
||||
|
||||
class Sym: pass
|
||||
class SymLit(Sym):
|
||||
def __init__(self, s: str):
|
||||
self.lit = unescape(s)
|
||||
def __repr__(self) -> str: return f"SymLit({repr(self.lit)})"
|
||||
class SymRename(Sym):
|
||||
def __init__(self, name: str, ty: NodeRef):
|
||||
self.name = name
|
||||
|
@ -113,6 +143,7 @@ class Parser(Transformer[List[Decl]]):
|
|||
[name, ifaces, variants] = items
|
||||
return Node(name, ifaces, variants)
|
||||
def node_ref_name(self, items: List[str]) -> NodeRefByName: return NodeRefByName(items[0])
|
||||
def node_regex(self, items: List[str]) -> NodeRegex: return NodeRegex(items[0])
|
||||
|
||||
# variants
|
||||
def variants(self, items: List[Variant]) -> List[Variant]: return items
|
||||
|
@ -121,6 +152,8 @@ class Parser(Transformer[List[Decl]]):
|
|||
return Variant(prod, equations)
|
||||
def prod(self, items: List[Sym]) -> List[Sym]: return items
|
||||
|
||||
# symbols in productions
|
||||
def sym_lit(self, items: List[str]) -> Sym: return SymLit(items[0])
|
||||
def sym_rename(self, items: List[Any]) -> Sym: return SymRename(items[0], items[1])
|
||||
|
||||
# equations
|
||||
|
|
78
aggen.py
78
aggen.py
|
@ -2,6 +2,9 @@ from typing import *
|
|||
import textwrap
|
||||
import re
|
||||
import copy
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from re import Pattern
|
||||
|
||||
from agast import *
|
||||
|
||||
|
@ -10,17 +13,33 @@ i = 0
|
|||
|
||||
class GenResult:
|
||||
def __init__(self, pd: str = "", ex: str = ""):
|
||||
self.parser_data = pd
|
||||
self.literals: Dict[str, str] = dict()
|
||||
self.parse_rules: defaultdict[str, List[str]] = defaultdict(list)
|
||||
self.starts: Set[str] = set()
|
||||
self.extra = ex
|
||||
|
||||
@property
|
||||
def parser_data(self) -> str:
|
||||
s = []
|
||||
for sym, pat in self.literals.items():
|
||||
s.append(f"{sym}: {pat}")
|
||||
for name, rules in self.parse_rules.items():
|
||||
srules = " | ".join(rules)
|
||||
s.append(f"{name}: {srules}")
|
||||
s.append("%import common.WS")
|
||||
s.append("%ignore WS")
|
||||
return "\n".join(s)
|
||||
|
||||
def gen(program: List[Decl]) -> GenResult:
|
||||
res = GenResult()
|
||||
|
||||
def gen(prefix: str = "", suffix: str = "") -> str:
|
||||
global i
|
||||
presan = re.sub("[^0-9a-zA-Z]+", "_", prefix)
|
||||
sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix)
|
||||
i += 1
|
||||
return f"{presan}{i}{sufsan}"
|
||||
|
||||
def v(name: str) -> str:
|
||||
return f"__ag_{name}"
|
||||
|
||||
|
@ -34,7 +53,7 @@ def gen(program: List[Decl]) -> GenResult:
|
|||
map(lambda c: (c.name, cast(Iface, c)),
|
||||
filter(lambda c: isinstance(c, Iface),
|
||||
program)))
|
||||
|
||||
|
||||
# list of node -> iface mappings
|
||||
what_ifaces: Dict[str, Set[str]] = dict()
|
||||
what_fields: Dict[str, Dict[str, str]] = dict()
|
||||
|
@ -83,27 +102,62 @@ def gen(program: List[Decl]) -> GenResult:
|
|||
return dict()
|
||||
raise Exception(f"unhandled {expr.__class__}")
|
||||
|
||||
for node in filter(lambda c: isinstance(c, Node), program):
|
||||
node_map = dict(map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program)))
|
||||
node_name_map = dict(map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items()))
|
||||
|
||||
for node in node_map.values():
|
||||
node = cast(Node, node)
|
||||
n_class_name = gen(node.name)
|
||||
node_name_lower = node.name.lower()
|
||||
node_name = node_name_map[node.name]
|
||||
res.parse_rules[f"?{node_name_lower}"].append(node_name)
|
||||
res.starts.add(node_name_lower)
|
||||
|
||||
class_decl = textwrap.dedent(f"""
|
||||
class {v(n_class_name)}: pass
|
||||
class {v(node_name)}: pass
|
||||
""")
|
||||
res.extra += class_decl
|
||||
|
||||
print(node.name, node.ifaces)
|
||||
|
||||
for variant in node.variants:
|
||||
v_class_name = gen(f"{n_class_name}_var")
|
||||
v_class_name = gen(f"{node_name}_var")
|
||||
class_decl = textwrap.dedent(f"""
|
||||
class {v(v_class_name)}({v(n_class_name)}):
|
||||
class {v(v_class_name)}({v(node_name)}):
|
||||
''' '''
|
||||
pass
|
||||
""")
|
||||
res.extra += class_decl
|
||||
|
||||
prod_name = gen(node.name)
|
||||
print(prod_name)
|
||||
prod_name = gen(node_name + "_")
|
||||
res.parse_rules[node_name].append(prod_name)
|
||||
print("PRODUCTION", prod_name, variant.prod)
|
||||
|
||||
# resolving a production just means checking to make sure it's a
|
||||
# type that exists or it's a regex
|
||||
def resolve_production(sym: Sym) -> str:
|
||||
print(f"resolve_production({sym})")
|
||||
if isinstance(sym, SymRename):
|
||||
if isinstance(sym.ty, NodeRefByName):
|
||||
if sym.ty.name in node_name_map:
|
||||
return node_name_map[sym.ty.name]
|
||||
else:
|
||||
raise Exception(f"unresolved name {sym.ty.name} in production")
|
||||
elif isinstance(sym.ty, NodeRegex):
|
||||
sym_name = gen("sym")
|
||||
res.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
|
||||
return sym_name
|
||||
elif isinstance(sym, SymLit):
|
||||
sym_name = gen("lit")
|
||||
# hack to make repr have double quotes
|
||||
res.literals[sym_name] = json.dumps(sym.lit)
|
||||
return sym_name
|
||||
raise Exception(f"unhandled {sym.__class__}")
|
||||
|
||||
seq = []
|
||||
for sym in variant.prod:
|
||||
n = resolve_production(sym)
|
||||
seq.append(n)
|
||||
res.parse_rules[prod_name].append(" ".join(seq))
|
||||
|
||||
# create an environment for checking the equations based on
|
||||
# the production
|
||||
|
@ -112,7 +166,7 @@ def gen(program: List[Decl]) -> GenResult:
|
|||
if isinstance(sym, SymRename):
|
||||
env.append((sym.name, sym.ty))
|
||||
print(env)
|
||||
|
||||
|
||||
# for each of the equations, find out what the equation is
|
||||
# trying to compute, and generate a thunk corresponding to
|
||||
# that value.
|
||||
|
@ -135,6 +189,6 @@ def gen(program: List[Decl]) -> GenResult:
|
|||
|
||||
# this is a "type alias" that connects it to one of the generated
|
||||
# names above
|
||||
res.extra += f"{node.name} = {v(n_class_name)}"
|
||||
res.extra += f"{node.name} = {v(node_name)}"
|
||||
|
||||
return res
|
||||
return res
|
||||
|
|
12
agmain.py
12
agmain.py
|
@ -41,15 +41,15 @@ if __name__ == "__main__":
|
|||
if self.value is None:
|
||||
self.value = self.func()
|
||||
return self.value
|
||||
parser = Lark('''start:
|
||||
{pd}''')
|
||||
parser = Lark('''{pd}''', parser='lalr', start={starts}, debug=True)
|
||||
class Trans(Transformer[None]):
|
||||
pass
|
||||
{ex}
|
||||
def parse(input: str) -> None:
|
||||
print(input)
|
||||
def parse(input: str, start: Optional[str] = None) -> Any:
|
||||
return parser.parse(input, start)
|
||||
""")
|
||||
f.write(fmt_str.format(pd=res.parser_data, ex=res.extra))
|
||||
f.write(fmt_str.format(pd=res.parser_data, ex=res.extra, starts=list(res.starts)))
|
||||
|
||||
mod = importlib.import_module("gen.arith")
|
||||
mod.parse("1 + 2 * 3") # type: ignore
|
||||
print(mod.parse("1 + 2 * 3", start="expr")) # type: ignore
|
||||
|
||||
|
|
2
arith.ag
2
arith.ag
|
@ -9,5 +9,5 @@ node Expr : HasValue {
|
|||
<l:Expr> "*" <r:Expr> => {
|
||||
self.val = l.val * r.val;
|
||||
}
|
||||
<n:r"[0-9]+"> => { self.val = parseInt(n); }
|
||||
<n:"[0-9]+"> => { self.val = parseInt(n); }
|
||||
}
|
||||
|
|
10
grammar.lark
10
grammar.lark
|
@ -18,11 +18,13 @@ variants: variant*
|
|||
variant: prod "=>" "{" equations "}"
|
||||
prod: sym*
|
||||
?sym: sym_rename
|
||||
| STRING
|
||||
| sym_lit
|
||||
sym_lit: ESCAPED_STRING
|
||||
sym_rename: "<" ident ":" node_ref ">"
|
||||
?node_ref: node_ref_name
|
||||
| STRING
|
||||
| node_regex
|
||||
node_ref_name: ident
|
||||
node_regex: ESCAPED_STRING
|
||||
equations: equation_semi*
|
||||
equation_semi: equation ";"
|
||||
// TODO: the left side should really be a separate type
|
||||
|
@ -56,5 +58,7 @@ IDENT: /([a-zA-Z][a-zA-Z0-9_]*)|(_[a-zA-Z0-9_]+)/
|
|||
|
||||
%import python.STRING
|
||||
%import common.WS
|
||||
%import common.ESCAPED_STRING
|
||||
|
||||
%ignore WS
|
||||
%ignore COMMENT
|
||||
%ignore COMMENT
|
||||
|
|
Loading…
Reference in a new issue