This commit is contained in:
Michael Zhang 2021-09-30 16:07:36 -05:00
parent 18df271749
commit 61bfd16b81
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B
3 changed files with 236 additions and 216 deletions

View file

@ -1,6 +1,9 @@
watch: watch:
watchexec -ce py,lark,ag -i gen 'just run' watchexec -ce py,lark,ag -i gen 'just run'
fmt:
pipenv run black .
run: run:
mypy *.py mypy *.py
python agmain.py python agmain.py

432
aggen.py
View file

@ -1,212 +1,220 @@
from typing import * from typing import *
import textwrap import textwrap
import re import re
import copy import copy
import json import json
from collections import defaultdict from collections import defaultdict
from re import Pattern from re import Pattern
from agast import * from agast import *
global i global i
i = 0 i = 0
class GenResult: class GenResult:
def __init__(self, pd: str = "", ex: str = ""): def __init__(self, pd: str = "", ex: str = ""):
self.literals: Dict[str, str] = dict() self.literals: Dict[str, str] = dict()
self.parse_rules: defaultdict[str, List[str]] = defaultdict(list) self.parse_rules: defaultdict[str, List[str]] = defaultdict(list)
self.starts: Set[str] = set() self.starts: Set[str] = set()
self.extra = ex self.extra = ex
self.trans: List[str] = list()
@property
def parser_data(self) -> str: @property
s = [] def transdef(self) -> str:
for sym, pat in self.literals.items(): s = self.trans
s.append(f"{sym}: {pat}") if not s:
for name, rules in self.parse_rules.items(): s = ["pass"]
srules = " | ".join(rules) return "\n" + "\n".join(map(lambda c: f" {c}", s))
s.append(f"{name}: {srules}")
s.append("%import common.WS") @property
s.append("%ignore WS") def parser_data(self) -> str:
return "\n".join(s) s = []
for sym, pat in self.literals.items():
s.append(f"{sym}: {pat}")
def gen(program: List[Decl]) -> GenResult: for name, rules in self.parse_rules.items():
res = GenResult() srules = " | ".join(rules)
s.append(f"{name}: {srules}")
def gen(prefix: str = "", suffix: str = "") -> str: s.append("%import common.WS")
global i s.append("%ignore WS")
presan = re.sub("[^0-9a-zA-Z]+", "_", prefix) return "\n".join(s)
sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix)
i += 1
return f"{presan}{i}{sufsan}" def gen(program: List[Decl]) -> GenResult:
res = GenResult()
def v(name: str) -> str:
return f"__ag_{name}" def gen(prefix: str = "", suffix: str = "") -> str:
global i
# builtins presan = re.sub("[^0-9a-zA-Z]+", "_", prefix)
builtins: Dict[str, str] = { sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix)
"parseInt": "", i += 1
} return f"{presan}{i}{sufsan}"
# collect a list of name -> iface declarations def v(name: str) -> str:
ifaces: Dict[str, Iface] = dict( return f"__ag_{name}"
map(
lambda c: (c.name, cast(Iface, c)), # builtins
filter(lambda c: isinstance(c, Iface), program), builtins: Dict[str, str] = {
) "parseInt": "",
) }
# list of node -> iface mappings # collect a list of name -> iface declarations
what_ifaces: Dict[str, Set[str]] = dict() ifaces: Dict[str, Iface] = dict(
what_fields: Dict[str, Dict[str, str]] = dict() map(
for node in filter(lambda c: isinstance(c, Node), program): lambda c: (c.name, cast(Iface, c)),
node = cast(Node, node) filter(lambda c: isinstance(c, Iface), program),
# all_fields = dict() )
what_ifaces[node.name] = set(node.ifaces) )
this_fields = dict()
for iface in node.ifaces: # list of node -> iface mappings
fields = ifaces[iface].fields what_ifaces: Dict[str, Set[str]] = dict()
for field in fields: what_fields: Dict[str, Dict[str, str]] = dict()
if field.name in this_fields: for node in filter(lambda c: isinstance(c, Node), program):
raise Exception("duplicate field name") node = cast(Node, node)
this_fields[field.name] = field.ty # all_fields = dict()
what_fields[node.name] = this_fields what_ifaces[node.name] = set(node.ifaces)
print("what_ifaces:", what_ifaces) this_fields = dict()
print("what_fields:", what_fields) for iface in node.ifaces:
fields = ifaces[iface].fields
# a high-level dictionary of productions; this has sub-productions for field in fields:
# that should be further expanded at a later step before converting if field.name in this_fields:
# into lark code raise Exception("duplicate field name")
productions_hi: Dict[str, Union[str, List[str]]] = dict() this_fields[field.name] = field.ty
what_fields[node.name] = this_fields
# TODO: this should probably not be inlined here, but i'll move it print("what_ifaces:", what_ifaces)
# out once i get more info into the 'env' print("what_fields:", what_fields)
def collect_required_thunks(
env: List[Tuple[str, NodeRef]], expr: Expr # a high-level dictionary of productions; this has sub-productions
) -> Dict[str, str]: # that should be further expanded at a later step before converting
names = dict(env) # into lark code
print(f"collect_required_thunks({expr})", expr.__class__) productions_hi: Dict[str, Union[str, List[str]]] = dict()
if isinstance(expr, ExprDot):
return collect_required_thunks(env, expr.left) # TODO: this should probably not be inlined here, but i'll move it
elif isinstance(expr, ExprMul): # out once i get more info into the 'env'
a = collect_required_thunks(env, expr.left) def collect_required_thunks(
b = collect_required_thunks(env, expr.right) env: List[Tuple[str, NodeRef]], expr: Expr
a.update(b) ) -> Dict[str, str]:
return a names = dict(env)
elif isinstance(expr, ExprAdd): print(f"collect_required_thunks({expr})", expr.__class__)
a = collect_required_thunks(env, expr.left) if isinstance(expr, ExprDot):
b = collect_required_thunks(env, expr.right) return collect_required_thunks(env, expr.left)
a.update(b) elif isinstance(expr, ExprMul):
return a a = collect_required_thunks(env, expr.left)
elif isinstance(expr, ExprCall): b = collect_required_thunks(env, expr.right)
return collect_required_thunks(env, expr.func) a.update(b)
elif isinstance(expr, ExprName): return a
if expr.name not in names and expr.name not in builtins: elif isinstance(expr, ExprAdd):
raise Exception(f"unbound name '{expr.name}'") a = collect_required_thunks(env, expr.left)
return dict() b = collect_required_thunks(env, expr.right)
raise Exception(f"unhandled {expr.__class__}") a.update(b)
return a
node_map = dict( elif isinstance(expr, ExprCall):
map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program)) return collect_required_thunks(env, expr.func)
) elif isinstance(expr, ExprName):
node_name_map = dict( if expr.name not in names and expr.name not in builtins:
map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items()) raise Exception(f"unbound name '{expr.name}'")
) return dict()
raise Exception(f"unhandled {expr.__class__}")
for node in node_map.values():
node = cast(Node, node) node_map = dict(
node_name_lower = node.name.lower() map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program))
node_name = node_name_map[node.name] )
res.parse_rules[f"?{node_name_lower}"].append(node_name) node_name_map = dict(
res.starts.add(node_name_lower) map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items())
)
class_decl = textwrap.dedent(
f""" for node in node_map.values():
class {v(node_name)}: pass node = cast(Node, node)
""" node_name_lower = node.name.lower()
) node_name = node_name_map[node.name]
res.extra += class_decl res.parse_rules[f"?{node_name_lower}"].append(node_name)
res.starts.add(node_name_lower)
print(node.name, node.ifaces)
class_decl = textwrap.dedent(
for variant in node.variants: f"""
v_class_name = gen(f"{node_name}_var") class {v(node_name)}: pass
class_decl = textwrap.dedent( """
f""" )
class {v(v_class_name)}({v(node_name)}): res.extra += class_decl
''' '''
pass print(node.name, node.ifaces)
"""
) for variant in node.variants:
res.extra += class_decl v_class_name = gen(f"{node_name}_var")
class_decl = textwrap.dedent(
prod_name = gen(node_name + "_") f"""
res.parse_rules[node_name].append(prod_name) class {v(v_class_name)}({v(node_name)}):
print("PRODUCTION", prod_name, variant.prod) ''' '''
pass
# resolving a production just means checking to make sure it's a """
# type that exists or it's a regex )
def resolve_production(sym: Sym) -> str: res.extra += class_decl
print(f"resolve_production({sym})")
if isinstance(sym, SymRename): prod_name = gen(node_name + "_")
if isinstance(sym.ty, NodeRefByName): res.parse_rules[node_name].append(prod_name)
if sym.ty.name in node_name_map: print("PRODUCTION", prod_name, variant.prod)
return node_name_map[sym.ty.name]
else: # resolving a production just means checking to make sure it's a
raise Exception( # type that exists or it's a regex
f"unresolved name {sym.ty.name} in production" def resolve_production(sym: Sym) -> str:
) print(f"resolve_production({sym})")
elif isinstance(sym.ty, NodeRegex): if isinstance(sym, SymRename):
sym_name = gen("sym") if isinstance(sym.ty, NodeRefByName):
res.literals[sym_name] = f"/{sym.ty.pat.pattern}/" if sym.ty.name in node_name_map:
return sym_name return node_name_map[sym.ty.name]
elif isinstance(sym, SymLit): else:
sym_name = gen("lit") raise Exception(
# hack to make repr have double quotes f"unresolved name {sym.ty.name} in production"
res.literals[sym_name] = json.dumps(sym.lit) )
return sym_name elif isinstance(sym.ty, NodeRegex):
raise Exception(f"unhandled {sym.__class__}") sym_name = gen("sym")
res.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
seq = [] return sym_name
for sym in variant.prod: elif isinstance(sym, SymLit):
n = resolve_production(sym) sym_name = gen("lit")
seq.append(n) # hack to make repr have double quotes
res.parse_rules[prod_name].append(" ".join(seq)) res.literals[sym_name] = json.dumps(sym.lit)
return sym_name
# create an environment for checking the equations based on raise Exception(f"unhandled {sym.__class__}")
# the production
env: List[Tuple[str, NodeRef]] = list() seq = []
for sym in variant.prod: for sym in variant.prod:
if isinstance(sym, SymRename): n = resolve_production(sym)
env.append((sym.name, sym.ty)) seq.append(n)
print(env) res.parse_rules[prod_name].append(" ".join(seq))
# for each of the equations, find out what the equation is # create an environment for checking the equations based on
# trying to compute, and generate a thunk corresponding to # the production
# that value. env: List[Tuple[str, NodeRef]] = list()
for eq in variant.equations: for sym in variant.prod:
eq_name = gen(f"eq_{node.name}") if isinstance(sym, SymRename):
thunk_name = gen(f"thunk_{node.name}") env.append((sym.name, sym.ty))
print(env)
print("RHS", eq.rhs, eq.rhs.id)
collect_required_thunks(copy.deepcopy(env), eq.rhs) # for each of the equations, find out what the equation is
# trying to compute, and generate a thunk corresponding to
func_impl = textwrap.dedent( # that value.
f""" for eq in variant.equations:
def {eq_name}() -> None: eq_name = gen(f"eq_{node.name}")
''' {repr(eq)} ''' thunk_name = gen(f"thunk_{node.name}")
pass
def {thunk_name}() -> Thunk[None]: print("RHS", eq.rhs, eq.rhs.id)
return Thunk({eq_name}) collect_required_thunks(copy.deepcopy(env), eq.rhs)
"""
) func_impl = textwrap.dedent(
print(f"```py\n{func_impl}\n```") f"""
res.extra += func_impl def {eq_name}() -> None:
''' {repr(eq)} '''
# this is a "type alias" that connects it to one of the generated pass
# names above def {thunk_name}() -> Thunk[None]:
res.extra += f"{node.name} = {v(node_name)}" return Thunk({eq_name})
"""
return res )
print(f"```py\n{func_impl}\n```")
res.extra += func_impl
# this is a "type alias" that connects it to one of the generated
# names above
res.extra += f"{node.name} = {v(node_name)}"
return res

View file

@ -25,6 +25,8 @@ if __name__ == "__main__":
with open("gen/arith.py", "w") as f: with open("gen/arith.py", "w") as f:
fmt_str = textwrap.dedent( fmt_str = textwrap.dedent(
""" """
# This documented generated by agtest.
__all__ = ["parse"] __all__ = ["parse"]
from typing import Generic, TypeVar, Optional, Callable, Dict, Any from typing import Generic, TypeVar, Optional, Callable, Dict, Any
from lark import Lark, Transformer from lark import Lark, Transformer
@ -43,15 +45,22 @@ if __name__ == "__main__":
self.value = self.func() self.value = self.func()
return self.value return self.value
parser = Lark('''{pd}''', parser='lalr', start={starts}, debug=True) parser = Lark('''{pd}''', parser='lalr', start={starts}, debug=True)
class Trans(Transformer[None]): class Trans(Transformer[None]): {transdef}
pass
{ex} {ex}
def parse(input: str, start: Optional[str] = None) -> Any: def parse(input: str, start: Optional[str] = None) -> Any:
return parser.parse(input, start) tree = parser.parse(input, start)
trans = Trans()
res = trans.transform(tree)
return res
""" """
) )
f.write( f.write(
fmt_str.format(pd=res.parser_data, ex=res.extra, starts=list(res.starts)) fmt_str.format(
pd=res.parser_data,
ex=res.extra,
starts=list(res.starts),
transdef=res.transdef,
)
) )
mod = importlib.import_module("gen.arith") mod = importlib.import_module("gen.arith")