agtest/aggen.py

223 lines
7.5 KiB
Python

from typing import *
import textwrap
import re
import copy
import json
from collections import defaultdict
from re import Pattern
from agast import *
# Counter shared by every gensym() call. It keeps its historical module-level
# name ``i`` in case other code reads or resets it.
i = 0

# Runs of characters that are not allowed in a generated symbol name.
_NON_ALNUM = re.compile(r"[^0-9a-zA-Z]+")


def gensym(prefix: str = "", suffix: str = "") -> str:
    """Return a fresh symbol name of the form ``<prefix><n><suffix>``.

    ``n`` is the module-wide counter ``i``, incremented on every call.  Each
    run of non-alphanumeric characters in ``prefix`` and ``suffix`` is
    replaced by a single ``_``.

    If the sanitized prefix ends with a digit, an ``_`` is inserted before the
    counter; if the sanitized suffix starts with a digit, an ``_`` is inserted
    after it.  Without this, ``gensym("a1")`` at count 5 and ``gensym("a")``
    at count 15 would both produce ``a15``.

    Args:
        prefix: text placed before the counter.
        suffix: text placed after the counter.

    Returns:
        The generated name.
    """
    global i
    presan = _NON_ALNUM.sub("_", prefix)
    sufsan = _NON_ALNUM.sub("_", suffix)
    # keep the counter's digits separated from digits in the prefix/suffix
    if presan[-1:].isdigit():
        presan += "_"
    if sufsan[:1].isdigit():
        sufsan = "_" + sufsan
    i += 1
    return f"{presan}{i}{sufsan}"
class NodeDesc:
    """A ``Node`` declaration together with the generated nonterminal name for it."""

    def __init__(self, node: Node):
        # the declaration itself, plus a shortcut to its name
        self.node = node
        self.name = self.node.name
        # a fresh grammar symbol derived from the lower-cased node name
        self.nonterminal = gensym(self.name.lower())
class GenResult:
    """Accumulator for everything ``gen`` produces for one program.

    The generator fills in the attributes.  ``parser_data`` renders the
    grammar text and ``transdef`` renders ``trans`` as an indented body.
    """

    def __init__(self, pd: str = "", ex: str = ""):
        # NOTE(review): ``pd`` is accepted but never stored or used; confirm
        # whether it was meant to seed the grammar text.
        # symbol name -> pattern text, rendered as "name: pattern"
        self.literals: Dict[str, str] = {}
        # rule name -> alternatives, rendered joined with " | "
        self.parse_rules: defaultdict[str, List[str]] = defaultdict(list)
        # start symbol names collected by the generator
        self.starts: Set[str] = set()
        # extra source text the generator appends to
        self.extra = ex
        # statements rendered by ``transdef``
        self.trans: List[str] = []

    @property
    def transdef(self) -> str:
        """Return ``trans`` as newline-prefixed lines, each indented by one space.

        An empty ``trans`` renders as a single ``pass`` line.
        """
        statements = self.trans or ["pass"]
        return "".join(f"\n {stmt}" for stmt in statements)

    @property
    def parser_data(self) -> str:
        """Render the grammar text: literals, then rules, then whitespace directives."""
        literal_lines = [f"{name}: {pattern}" for name, pattern in self.literals.items()]
        rule_lines = [
            f"{name}: {' | '.join(alternatives)}"
            for name, alternatives in self.parse_rules.items()
        ]
        footer = ["%import common.WS", "%ignore WS"]
        return "\n".join(literal_lines + rule_lines + footer)
def gen(program: List[Decl]) -> GenResult:
    """Translate a parsed attribute-grammar program into generator output.

    For every ``Node`` declaration this:

    * records a start symbol and parse rules in ``res.parse_rules``, with
      regex and string-literal symbols stored in ``res.literals``;
    * appends a stub class for the node and one per variant to ``res.extra``;
    * name-checks each equation's right-hand side, then appends a placeholder
      ``eq``/``thunk`` function pair to ``res.extra``.  The emitted code
      refers to a ``Thunk`` class that is not defined in this function.

    Args:
        program: top-level declarations.  ``Iface`` and ``Node`` entries are
            used; everything else is ignored.

    Returns:
        The populated ``GenResult``.

    Raises:
        Exception: on a duplicate iface or node name, a node implementing an
            undeclared iface, two of a node's ifaces declaring the same field,
            a production referring to an unknown node, an equation using an
            unbound name, or an unsupported expression/symbol kind.
    """
    res = GenResult()

    # names that may appear free in equations without being bound by the
    # production
    builtins: Dict[str, str] = {
        "parseInt": "",
    }

    iface_decls = [d for d in program if isinstance(d, Iface)]
    node_decls = [d for d in program if isinstance(d, Node)]

    # iface name -> iface declaration
    ifaces: Dict[str, Iface] = dict()
    for decl in iface_decls:
        if decl.name in ifaces:
            raise Exception(f"duplicate iface name '{decl.name}'")
        ifaces[decl.name] = decl

    # node name -> implemented iface names, and node name -> field name -> type
    what_ifaces: Dict[str, Set[str]] = dict()
    what_fields: Dict[str, Dict[str, str]] = dict()
    for node in node_decls:
        if node.name in what_ifaces:
            raise Exception(f"duplicate node name '{node.name}'")
        what_ifaces[node.name] = set(node.ifaces)
        this_fields: Dict[str, str] = dict()
        for iface_name in node.ifaces:
            if iface_name not in ifaces:
                raise Exception(
                    f"node '{node.name}' implements unknown iface '{iface_name}'"
                )
            for field in ifaces[iface_name].fields:
                if field.name in this_fields:
                    raise Exception(
                        f"duplicate field name '{field.name}' in node '{node.name}'"
                    )
                this_fields[field.name] = field.ty
        what_fields[node.name] = this_fields
    print("what_ifaces:", what_ifaces)
    print("what_fields:", what_fields)

    # TODO: this should probably not be inlined here, but i'll move it
    # out once i get more info into the 'env'
    def collect_required_thunks(
        env: List[Tuple[str, NodeRef]], expr: Expr
    ) -> Dict[str, str]:
        """Check that every name used in ``expr`` is bound by ``env`` or ``builtins``.

        This currently always returns an empty dict.  The return value is a
        placeholder for the thunks the expression will need.  ``env`` is not
        modified.
        """
        names = dict(env)
        print(f"collect_required_thunks({expr})", expr.__class__)
        if isinstance(expr, ExprDot):
            # only the projected object can contain names
            return collect_required_thunks(env, expr.left)
        if isinstance(expr, (ExprMul, ExprAdd)):
            required = collect_required_thunks(env, expr.left)
            required.update(collect_required_thunks(env, expr.right))
            return required
        if isinstance(expr, ExprCall):
            # NOTE(review): only the callee is checked; names used in the
            # call's arguments are never validated.  Confirm the argument
            # field name on ExprCall and recurse into it.
            return collect_required_thunks(env, expr.func)
        if isinstance(expr, ExprName):
            if expr.name not in names and expr.name not in builtins:
                raise Exception(f"unbound name '{expr.name}'")
            return dict()
        raise Exception(f"unhandled {expr.__class__}")

    # NodeDesc calls gensym, so nonterminals are allocated in declaration
    # order before any other symbol
    node_map: Dict[str, NodeDesc] = {node.name: NodeDesc(node) for node in node_decls}

    def resolve_production(sym: Sym) -> str:
        """Return the grammar symbol for one production element.

        A node reference resolves to that node's nonterminal.  Regexes and
        string literals get a freshly named entry in ``res.literals``.
        """
        print(f"resolve_production({sym})")
        if isinstance(sym, SymRename):
            if isinstance(sym.ty, NodeRefByName):
                if sym.ty.name not in node_map:
                    raise Exception(f"unresolved name {sym.ty.name} in production")
                return node_map[sym.ty.name].nonterminal
            if isinstance(sym.ty, NodeRegex):
                sym_name = gensym("sym")
                res.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
                return sym_name
            raise Exception(f"unhandled {sym.ty.__class__} in production")
        if isinstance(sym, SymLit):
            sym_name = gensym("lit")
            # json.dumps (rather than repr) gives a double-quoted, escaped literal
            res.literals[sym_name] = json.dumps(sym.lit)
            return sym_name
        raise Exception(f"unhandled {sym.__class__}")

    for node_desc in node_map.values():
        lowered = node_desc.name.lower()
        res.starts.add(lowered)
        # "?<name>" rule whose alternative is the node's nonterminal
        # (presumably lark's inline-when-single-child form)
        res.parse_rules[f"?{lowered}"].append(node_desc.nonterminal)
        res.extra += textwrap.dedent(
            f"""
            class {node_desc.nonterminal}: pass
            """
        )
        print(node_desc.name, node_desc.node.ifaces)

        for variant in node_desc.node.variants:
            v_class_name = gensym(f"{node_desc.nonterminal}_var")
            res.extra += textwrap.dedent(
                f"""
                class {v_class_name}({node_desc.nonterminal}):
                    ''' '''
                    pass
                """
            )

            prod_name = gensym(node_desc.nonterminal + "_")
            res.parse_rules[node_desc.nonterminal].append(prod_name)
            print("PRODUCTION", prod_name, variant.prod)

            seq = [resolve_production(sym) for sym in variant.prod]
            res.parse_rules[prod_name].append(" ".join(seq))

            # environment for checking the equations: the names bound by the
            # production's renamed symbols
            env: List[Tuple[str, NodeRef]] = [
                (sym.name, sym.ty) for sym in variant.prod if isinstance(sym, SymRename)
            ]
            print(env)

            # for each equation, check its right-hand side and emit a
            # placeholder function plus a thunk wrapping it
            for eq in variant.equations:
                eq_name = gensym(f"eq_{node_desc.name}")
                thunk_name = gensym(f"thunk_{node_desc.name}")
                print("RHS", eq.rhs, eq.rhs.id)
                collect_required_thunks(env, eq.rhs)
                func_impl = textwrap.dedent(
                    f"""
                    def {eq_name}() -> None:
                        ''' {repr(eq)} '''
                        pass
                    def {thunk_name}() -> Thunk[None]:
                        return Thunk({eq_name})
                    """
                )
                print(f"```py\n{func_impl}\n```")
                res.extra += func_impl
    return res