wish python had lexical scoping

This commit is contained in:
Michael Zhang 2021-09-30 16:37:43 -05:00
parent 3b9c12b512
commit 0a5738e8cd
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B
5 changed files with 40 additions and 36 deletions

View file

@ -3,6 +3,7 @@ watch:
fmt: fmt:
pipenv run black . pipenv run black .
dos2unix *
run: run:
mypy *.py mypy *.py

View file

@ -12,6 +12,21 @@ global i
i = 0 i = 0
def gensym(prefix: str = "", suffix: str = "") -> str:
    """Return a fresh symbol name of the form ``<prefix><counter><suffix>``.

    Every run of characters outside ``[0-9a-zA-Z]`` in *prefix* and *suffix*
    is collapsed into a single underscore.  The module-level counter ``i`` is
    incremented on each call and its new value is embedded between the two
    sanitized parts, so no two calls share the same counter value.

    Args:
        prefix: Text placed before the counter (sanitized first).
        suffix: Text placed after the counter (sanitized first).

    Returns:
        The generated name as a string.
    """
    global i

    # One pattern for both parts; raw string so future escapes stay literal.
    unsafe_run = r"[^0-9a-zA-Z]+"
    presan = re.sub(unsafe_run, "_", prefix)
    sufsan = re.sub(unsafe_run, "_", suffix)

    # Bump the counter only after both substitutions have succeeded.
    i += 1
    return f"{presan}{i}{sufsan}"
class NodeDesc:
    """Pair a ``Node`` with the names derived from it.

    Instance fields set by the constructor:
      node        -- the ``Node`` passed in, kept as-is
      name        -- a copy of ``node.name``
      nonterminal -- the result of ``gensym`` called with the lowercased name
    """

    def __init__(self, node: Node):
        declared_name = node.name
        self.node = node
        self.name = declared_name
        # gensym appends a counter, so each descriptor gets its own symbol.
        self.nonterminal = gensym(declared_name.lower())
class GenResult: class GenResult:
def __init__(self, pd: str = "", ex: str = ""): def __init__(self, pd: str = "", ex: str = ""):
self.literals: Dict[str, str] = dict() self.literals: Dict[str, str] = dict()
@ -43,13 +58,6 @@ class GenResult:
def gen(program: List[Decl]) -> GenResult: def gen(program: List[Decl]) -> GenResult:
res = GenResult() res = GenResult()
def gen(prefix: str = "", suffix: str = "") -> str:
    """Build a symbol name from *prefix*, the global counter and *suffix*.

    Runs of characters other than ASCII letters and digits in either part
    are replaced by a single underscore; the module-level ``i`` is then
    incremented and its new value placed between the two parts.
    """
    global i
    # Sanitize both parts before touching the counter.
    clean_prefix, clean_suffix = (
        re.sub("[^0-9a-zA-Z]+", "_", part) for part in (prefix, suffix)
    )
    i += 1
    return "".join((clean_prefix, str(i), clean_suffix))
def v(name: str) -> str: def v(name: str) -> str:
return f"__ag_{name}" return f"__ag_{name}"
@ -116,42 +124,41 @@ def gen(program: List[Decl]) -> GenResult:
return dict() return dict()
raise Exception(f"unhandled {expr.__class__}") raise Exception(f"unhandled {expr.__class__}")
node_map = dict( node_map: Dict[str, NodeDesc] = dict(
map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program)) map(
lambda n: (n.name, NodeDesc(cast(Node, n))),
filter(lambda c: isinstance(c, Node), program),
) )
node_name_map = dict(
map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items())
) )
for node in node_map.values(): for node_desc in node_map.values():
node = cast(Node, node) assert isinstance(node_desc, NodeDesc)
node_name_lower = node.name.lower()
node_name = node_name_map[node.name] res.starts.add(node_desc.name.lower())
res.parse_rules[f"?{node_name_lower}"].append(node_name) res.parse_rules[f"?{node_desc.name.lower()}"].append(node_desc.nonterminal)
res.starts.add(node_name_lower)
class_decl = textwrap.dedent( class_decl = textwrap.dedent(
f""" f"""
class {v(node_name)}: pass class {node_desc.nonterminal}: pass
""" """
) )
res.extra += class_decl res.extra += class_decl
print(node.name, node.ifaces) print(node_desc.name, node_desc.node.ifaces)
for variant in node.variants: for variant in node_desc.node.variants:
v_class_name = gen(f"{node_name}_var") v_class_name = gensym(f"{node_desc.nonterminal}_var")
class_decl = textwrap.dedent( class_decl = textwrap.dedent(
f""" f"""
class {v(v_class_name)}({v(node_name)}): class {v_class_name}({node_desc.nonterminal}):
''' ''' ''' '''
pass pass
""" """
) )
res.extra += class_decl res.extra += class_decl
prod_name = gen(node_name + "_") prod_name = gensym(node_desc.nonterminal + "_")
res.parse_rules[node_name].append(prod_name) res.parse_rules[node_desc.nonterminal].append(prod_name)
print("PRODUCTION", prod_name, variant.prod) print("PRODUCTION", prod_name, variant.prod)
# resolving a production just means checking to make sure it's a # resolving a production just means checking to make sure it's a
@ -160,18 +167,18 @@ def gen(program: List[Decl]) -> GenResult:
print(f"resolve_production({sym})") print(f"resolve_production({sym})")
if isinstance(sym, SymRename): if isinstance(sym, SymRename):
if isinstance(sym.ty, NodeRefByName): if isinstance(sym.ty, NodeRefByName):
if sym.ty.name in node_name_map: if sym.ty.name in node_map:
return node_name_map[sym.ty.name] return node_map[sym.ty.name].nonterminal
else: else:
raise Exception( raise Exception(
f"unresolved name {sym.ty.name} in production" f"unresolved name {sym.ty.name} in production"
) )
elif isinstance(sym.ty, NodeRegex): elif isinstance(sym.ty, NodeRegex):
sym_name = gen("sym") sym_name = gensym("sym")
res.literals[sym_name] = f"/{sym.ty.pat.pattern}/" res.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
return sym_name return sym_name
elif isinstance(sym, SymLit): elif isinstance(sym, SymLit):
sym_name = gen("lit") sym_name = gensym("lit")
# hack to make repr have double quotes # hack to make repr have double quotes
res.literals[sym_name] = json.dumps(sym.lit) res.literals[sym_name] = json.dumps(sym.lit)
return sym_name return sym_name
@ -195,8 +202,8 @@ def gen(program: List[Decl]) -> GenResult:
# trying to compute, and generate a thunk corresponding to # trying to compute, and generate a thunk corresponding to
# that value. # that value.
for eq in variant.equations: for eq in variant.equations:
eq_name = gen(f"eq_{node.name}") eq_name = gensym(f"eq_{node_desc.name}")
thunk_name = gen(f"thunk_{node.name}") thunk_name = gensym(f"thunk_{node_desc.name}")
print("RHS", eq.rhs, eq.rhs.id) print("RHS", eq.rhs, eq.rhs.id)
collect_required_thunks(copy.deepcopy(env), eq.rhs) collect_required_thunks(copy.deepcopy(env), eq.rhs)
@ -213,8 +220,4 @@ def gen(program: List[Decl]) -> GenResult:
print(f"```py\n{func_impl}\n```") print(f"```py\n{func_impl}\n```")
res.extra += func_impl res.extra += func_impl
# this is a "type alias" that connects it to one of the generated
# names above
res.extra += f"{node.name} = {v(node_name)}"
return res return res

View file

@ -9,7 +9,7 @@ from aggen import *
p = Lark(open("grammar.lark").read(), start="program", parser="lalr") p = Lark(open("grammar.lark").read(), start="program", parser="lalr")
if __name__ == "__main__": if __name__ == "__main__":
with open("arith.ag") as f: with open("test/arith.ag") as f:
data = f.read() data = f.read()
cst = p.parse(data) cst = p.parse(data)

View file