fuck you black

Michael Zhang 2021-09-30 16:08:01 -05:00
parent 61bfd16b81
commit 3b9c12b512
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B

aggen.py (440 lines changed)

@@ -1,220 +1,220 @@
from typing import *
import textwrap
import re
import copy
import json
from collections import defaultdict
from re import Pattern
from agast import *

global i
i = 0
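
# NOTE: `global` at module scope is a no-op; the declaration that matters is
# the `global i` inside the inner gen() helper below, which bumps this
# counter to mint fresh names for generated rules and classes.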


class GenResult:
    def __init__(self, pd: str = "", ex: str = ""):
        self.literals: Dict[str, str] = dict()
        self.parse_rules: defaultdict[str, List[str]] = defaultdict(list)
        self.starts: Set[str] = set()
        self.extra = ex
        self.trans: List[str] = list()

    @property
    def transdef(self) -> str:
        s = self.trans
        if not s:
            s = ["pass"]
        return "\n" + "\n".join(map(lambda c: f"    {c}", s))

    @property
    def parser_data(self) -> str:
        s = []
        for sym, pat in self.literals.items():
            s.append(f"{sym}: {pat}")
        for name, rules in self.parse_rules.items():
            srules = " | ".join(rules)
            s.append(f"{name}: {srules}")
        s.append("%import common.WS")
        s.append("%ignore WS")
        return "\n".join(s)


def gen(program: List[Decl]) -> GenResult:
    res = GenResult()

    def gen(prefix: str = "", suffix: str = "") -> str:
        global i
        presan = re.sub("[^0-9a-zA-Z]+", "_", prefix)
        sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix)
        i += 1
        return f"{presan}{i}{sufsan}"

    def v(name: str) -> str:
        return f"__ag_{name}"

    # builtins
    builtins: Dict[str, str] = {
        "parseInt": "",
    }

    # collect a list of name -> iface declarations
    ifaces: Dict[str, Iface] = dict(
        map(
            lambda c: (c.name, cast(Iface, c)),
            filter(lambda c: isinstance(c, Iface), program),
        )
    )

    # list of node -> iface mappings
    what_ifaces: Dict[str, Set[str]] = dict()
    what_fields: Dict[str, Dict[str, str]] = dict()
    for node in filter(lambda c: isinstance(c, Node), program):
        node = cast(Node, node)
        # all_fields = dict()
        what_ifaces[node.name] = set(node.ifaces)
        this_fields = dict()
        for iface in node.ifaces:
            fields = ifaces[iface].fields
            for field in fields:
                if field.name in this_fields:
                    raise Exception("duplicate field name")
                this_fields[field.name] = field.ty
        what_fields[node.name] = this_fields
    print("what_ifaces:", what_ifaces)
    print("what_fields:", what_fields)

    # a high-level dictionary of productions; this has sub-productions
    # that should be further expanded at a later step before converting
    # into lark code
    productions_hi: Dict[str, Union[str, List[str]]] = dict()

    # TODO: this should probably not be inlined here, but i'll move it
    # out once i get more info into the 'env'
    def collect_required_thunks(
        env: List[Tuple[str, NodeRef]], expr: Expr
    ) -> Dict[str, str]:
        names = dict(env)
        print(f"collect_required_thunks({expr})", expr.__class__)
        if isinstance(expr, ExprDot):
            return collect_required_thunks(env, expr.left)
        elif isinstance(expr, ExprMul):
            a = collect_required_thunks(env, expr.left)
            b = collect_required_thunks(env, expr.right)
            a.update(b)
            return a
        elif isinstance(expr, ExprAdd):
            a = collect_required_thunks(env, expr.left)
            b = collect_required_thunks(env, expr.right)
            a.update(b)
            return a
        elif isinstance(expr, ExprCall):
            return collect_required_thunks(env, expr.func)
        elif isinstance(expr, ExprName):
            if expr.name not in names and expr.name not in builtins:
                raise Exception(f"unbound name '{expr.name}'")
            return dict()
        raise Exception(f"unhandled {expr.__class__}")

    node_map = dict(
        map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program))
    )
    node_name_map = dict(
        map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items())
    )

    for node in node_map.values():
        node = cast(Node, node)
        node_name_lower = node.name.lower()
        node_name = node_name_map[node.name]
        res.parse_rules[f"?{node_name_lower}"].append(node_name)
        res.starts.add(node_name_lower)
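
        # The "?" prefix is Lark's inline marker: a rule like "?expr: ..."
        # is collapsed into its parent whenever it has exactly one child,
        # which keeps the generated parse tree flat.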

        class_decl = textwrap.dedent(
            f"""
            class {v(node_name)}: pass
            """
        )
        res.extra += class_decl
        print(node.name, node.ifaces)

        for variant in node.variants:
            v_class_name = gen(f"{node_name}_var")
            class_decl = textwrap.dedent(
                f"""
                class {v(v_class_name)}({v(node_name)}):
                    ''' '''
                    pass
                """
            )
            res.extra += class_decl
            prod_name = gen(node_name + "_")
            res.parse_rules[node_name].append(prod_name)
            print("PRODUCTION", prod_name, variant.prod)

            # resolving a production just means checking to make sure it's a
            # type that exists or it's a regex
            def resolve_production(sym: Sym) -> str:
                print(f"resolve_production({sym})")
                if isinstance(sym, SymRename):
                    if isinstance(sym.ty, NodeRefByName):
                        if sym.ty.name in node_name_map:
                            return node_name_map[sym.ty.name]
                        else:
                            raise Exception(
                                f"unresolved name {sym.ty.name} in production"
                            )
                    elif isinstance(sym.ty, NodeRegex):
                        sym_name = gen("sym")
                        res.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
                        return sym_name
                elif isinstance(sym, SymLit):
                    sym_name = gen("lit")
                    # hack to make repr have double quotes
                    res.literals[sym_name] = json.dumps(sym.lit)
                    return sym_name
                raise Exception(f"unhandled {sym.__class__}")

            seq = []
            for sym in variant.prod:
                n = resolve_production(sym)
                seq.append(n)
            res.parse_rules[prod_name].append(" ".join(seq))

            # create an environment for checking the equations based on
            # the production
            env: List[Tuple[str, NodeRef]] = list()
            for sym in variant.prod:
                if isinstance(sym, SymRename):
                    env.append((sym.name, sym.ty))
            print(env)

            # for each of the equations, find out what the equation is
            # trying to compute, and generate a thunk corresponding to
            # that value.
            for eq in variant.equations:
                eq_name = gen(f"eq_{node.name}")
                thunk_name = gen(f"thunk_{node.name}")
                print("RHS", eq.rhs, eq.rhs.id)
                collect_required_thunks(copy.deepcopy(env), eq.rhs)
                func_impl = textwrap.dedent(
                    f"""
                    def {eq_name}() -> None:
                        ''' {repr(eq)} '''
                        pass

                    def {thunk_name}() -> Thunk[None]:
                        return Thunk({eq_name})
                    """
                )
                print(f"```py\n{func_impl}\n```")
                res.extra += func_impl
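
                # Thunk is referenced only inside the generated source text,
                # so no import is needed here; the emitted code is presumably
                # compiled against a runtime that provides a Thunk class.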

        # this is a "type alias" that connects it to one of the generated
        # names above
        res.extra += f"{node.name} = {v(node_name)}"

    return res