This commit is contained in:
Michael Zhang 2021-10-01 00:15:23 -05:00
parent a1bcb8946a
commit 0fea420749
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B
6 changed files with 139 additions and 154 deletions

4
Pipfile.lock generated
View file

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "3c94591e63af431c312ac30f7f7e19d0f406f796ce447837cc5edd363ff2764d" "sha256": "1a84fb4e4ea0b66ca791e226c4da40bd974bc5eceec5868bb5752aad478c55e8"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@ -259,7 +259,7 @@
"sha256:94078db9184491e15bce0a56d9186e0aec95f16ac20b12d00e06d4e36f1058a6", "sha256:94078db9184491e15bce0a56d9186e0aec95f16ac20b12d00e06d4e36f1058a6",
"sha256:98a535c62a4fcfcc362528592f69b26f7caec587d32cd55688db580be0287ae0" "sha256:98a535c62a4fcfcc362528592f69b26f7caec587d32cd55688db580be0287ae0"
], ],
"markers": "python_version >= '3.6'", "index": "pypi",
"version": "==4.2.0" "version": "==4.2.0"
}, },
"sphinxcontrib-applehelp": { "sphinxcontrib-applehelp": {

View file

@ -34,11 +34,12 @@ class ParseEquation:
class GenResult: class GenResult:
def __init__(self, pd: str = "", ex: str = ""): def __init__(self, program: List[Decl]):
self.program = program
self.literals: Dict[str, str] = dict() self.literals: Dict[str, str] = dict()
self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list) self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list)
self.starts: Set[str] = set() self.starts: Set[str] = set()
self.extra = ex self.extra: str = ""
self.nonterminal_map: Dict[str, str] = dict() self.nonterminal_map: Dict[str, str] = dict()
@property @property
@ -77,160 +78,157 @@ class GenResult:
return "\n".join(s) return "\n".join(s)
def gen(program: List[Decl]) -> GenResult: def gen(self) -> None:
res = GenResult() def v(name: str) -> str:
return f"__ag_{name}"
def v(name: str) -> str: # builtins
return f"__ag_{name}" builtins: Dict[str, str] = {
"parseInt": "",
}
# builtins # collect a list of name -> iface declarations
builtins: Dict[str, str] = { ifaces: Dict[str, Iface] = dict(
"parseInt": "", map(
} lambda c: (c.name, cast(Iface, c)),
filter(lambda c: isinstance(c, Iface), self.program),
# collect a list of name -> iface declarations )
ifaces: Dict[str, Iface] = dict(
map(
lambda c: (c.name, cast(Iface, c)),
filter(lambda c: isinstance(c, Iface), program),
) )
)
# list of node -> iface mappings # list of node -> iface mappings
what_ifaces: Dict[str, Set[str]] = dict() what_ifaces: Dict[str, Set[str]] = dict()
what_fields: Dict[str, Dict[str, str]] = dict() what_fields: Dict[str, Dict[str, str]] = dict()
for node in filter(lambda c: isinstance(c, Node), program): for node in filter(lambda c: isinstance(c, Node), self.program):
node = cast(Node, node) node = cast(Node, node)
# all_fields = dict() # all_fields = dict()
what_ifaces[node.name] = set(node.ifaces) what_ifaces[node.name] = set(node.ifaces)
this_fields = dict() this_fields = dict()
for iface in node.ifaces: for iface in node.ifaces:
fields = ifaces[iface].fields fields = ifaces[iface].fields
for field in fields: for field in fields:
if field.name in this_fields: if field.name in this_fields:
raise Exception("duplicate field name") raise Exception("duplicate field name")
this_fields[field.name] = field.ty this_fields[field.name] = field.ty
what_fields[node.name] = this_fields what_fields[node.name] = this_fields
print("what_ifaces:", what_ifaces) print("what_ifaces:", what_ifaces)
print("what_fields:", what_fields) print("what_fields:", what_fields)
# a high-level dictionary of productions; this has sub-productions that # a high-level dictionary of productions; this has sub-productions that
# should be further expanded at a later step before converting into lark # should be further expanded at a later step before converting into lark
# code # code
productions_hi: Dict[str, Union[str, List[str]]] = dict() productions_hi: Dict[str, Union[str, List[str]]] = dict()
# TODO: this should probably not be inlined here, but i'll move it out once # TODO: this should probably not be inlined here, but i'll move it out once
# i get more info into the 'env' # i get more info into the 'env'
def collect_required_thunks( def collect_required_thunks(
env: List[Tuple[str, NodeRef]], expr: Expr env: List[Tuple[str, NodeRef]], expr: Expr
) -> Dict[str, str]: ) -> Dict[str, str]:
names = dict(env) names = dict(env)
print(f"collect_required_thunks({expr})", expr.__class__) print(f"collect_required_thunks({expr})", expr.__class__)
if isinstance(expr, ExprDot): if isinstance(expr, ExprDot):
return collect_required_thunks(env, expr.left) return collect_required_thunks(env, expr.left)
elif isinstance(expr, ExprMul): elif isinstance(expr, ExprMul):
a = collect_required_thunks(env, expr.left) a = collect_required_thunks(env, expr.left)
b = collect_required_thunks(env, expr.right) b = collect_required_thunks(env, expr.right)
a.update(b) a.update(b)
return a return a
elif isinstance(expr, ExprAdd): elif isinstance(expr, ExprAdd):
a = collect_required_thunks(env, expr.left) a = collect_required_thunks(env, expr.left)
b = collect_required_thunks(env, expr.right) b = collect_required_thunks(env, expr.right)
a.update(b) a.update(b)
return a return a
elif isinstance(expr, ExprCall): elif isinstance(expr, ExprCall):
return collect_required_thunks(env, expr.func) return collect_required_thunks(env, expr.func)
elif isinstance(expr, ExprName): elif isinstance(expr, ExprName):
if expr.name not in names and expr.name not in builtins: if expr.name not in names and expr.name not in builtins:
raise Exception(f"unbound name '{expr.name}'") raise Exception(f"unbound name '{expr.name}'")
return dict() return dict()
raise Exception(f"unhandled {expr.__class__}") raise Exception(f"unhandled {expr.__class__}")
node_map: Dict[str, NodeDesc] = dict() node_map: Dict[str, NodeDesc] = dict()
for _node in filter(lambda c: isinstance(c, Node), program): for _node in filter(lambda c: isinstance(c, Node), self.program):
nd = NodeDesc(cast(Node, _node)) nd = NodeDesc(cast(Node, _node))
node_map[_node.name] = nd node_map[_node.name] = nd
res.nonterminal_map[nd.name] = nd.nonterminal self.nonterminal_map[nd.name] = nd.nonterminal
for node_desc in node_map.values(): for node_desc in node_map.values():
assert isinstance(node_desc, NodeDesc) assert isinstance(node_desc, NodeDesc)
res.starts.add(node_desc.nonterminal) self.starts.add(node_desc.nonterminal)
class_fields = [] class_fields = []
for field_name, field_ty in what_fields[node_desc.name].items(): for field_name, field_ty in what_fields[node_desc.name].items():
class_fields.append(f"{field_name}: Thunk[{field_ty}]") class_fields.append(f"{field_name}: Thunk[{field_ty}]")
g = textwrap.indent("\n".join(class_fields), " ") g = textwrap.indent("\n".join(class_fields), " ")
class_decl = textwrap.dedent(
f"""
class {node_desc.nonterminal}:
{g}
pass
"""
)
res.extra += class_decl
print(node_desc.name, node_desc.node.ifaces)
for variant in node_desc.node.variants:
v_class_name = gensym(f"{node_desc.nonterminal}_var")
class_decl = textwrap.dedent( class_decl = textwrap.dedent(
f""" f"""
class {v_class_name}({node_desc.nonterminal}): pass class {node_desc.nonterminal}:
{g}
pass
""" """
) )
res.extra += class_decl self.extra += class_decl
prod_name = gensym(node_desc.nonterminal + "_") print(node_desc.name, node_desc.node.ifaces)
print("PRODUCTION", prod_name, variant.prod)
# resolving a production just means checking to make sure it's a for variant in node_desc.node.variants:
# type that exists or it's a regex v_class_name = gensym(f"{node_desc.nonterminal}_var")
def resolve_production(sym: Sym) -> str: class_decl = textwrap.dedent(
print(f"resolve_production({sym})") f"""
if isinstance(sym, SymRename): class {v_class_name}({node_desc.nonterminal}): pass
if isinstance(sym.ty, NodeRefByName): """
if sym.ty.name in node_map: )
return node_map[sym.ty.name].nonterminal self.extra += class_decl
else:
raise Exception( prod_name = gensym(node_desc.nonterminal + "_")
f"unresolved name {sym.ty.name} in production" print("PRODUCTION", prod_name, variant.prod)
)
elif isinstance(sym.ty, NodeRegex): # resolving a production just means checking to make sure it's a
sym_name = gensym("sym") # type that exists or it's a regex
res.literals[sym_name] = f"/{sym.ty.pat.pattern}/" def resolve_production(sym: Sym) -> str:
print(f"resolve_production({sym})")
if isinstance(sym, SymRename):
if isinstance(sym.ty, NodeRefByName):
if sym.ty.name in node_map:
return node_map[sym.ty.name].nonterminal
else:
raise Exception(
f"unresolved name {sym.ty.name} in production"
)
elif isinstance(sym.ty, NodeRegex):
sym_name = gensym("sym")
self.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
return sym_name
elif isinstance(sym, SymLit):
sym_name = gensym("lit")
# hack to make repr have double quotes
self.literals[sym_name] = json.dumps(sym.lit)
return sym_name return sym_name
elif isinstance(sym, SymLit): raise Exception(f"unhandled {sym.__class__}")
sym_name = gensym("lit")
# hack to make repr have double quotes
res.literals[sym_name] = json.dumps(sym.lit)
return sym_name
raise Exception(f"unhandled {sym.__class__}")
seq = [] seq = []
for sym in variant.prod: for sym in variant.prod:
n = resolve_production(sym) n = resolve_production(sym)
seq.append(n) seq.append(n)
res.parse_rules[node_desc.nonterminal].append( self.parse_rules[node_desc.nonterminal].append(
ParseEquation(prod_name, seq, v_class_name) ParseEquation(prod_name, seq, v_class_name)
) )
# create an environment for checking the equations based on the # create an environment for checking the equations based on the
# production # production
env: List[Tuple[str, NodeRef]] = list() env: List[Tuple[str, NodeRef]] = list()
for sym in variant.prod: for sym in variant.prod:
if isinstance(sym, SymRename): if isinstance(sym, SymRename):
env.append((sym.name, sym.ty)) env.append((sym.name, sym.ty))
print(env) print(env)
# for each of the equations, find out what the equation is trying # for each of the equations, find out what the equation is trying
# to compute, and generate a thunk corresponding to that value. # to compute, and generate a thunk corresponding to that value.
for eq in variant.equations: for eq in variant.equations:
eq_name = gensym(f"eq_{node_desc.name}") eq_name = gensym(f"eq_{node_desc.name}")
thunk_name = gensym(f"thunk_{node_desc.name}") thunk_name = gensym(f"thunk_{node_desc.name}")
print("RHS", eq.rhs, eq.rhs.id) print("RHS", eq.rhs, eq.rhs.id)
collect_required_thunks(copy.deepcopy(env), eq.rhs) collect_required_thunks(copy.deepcopy(env), eq.rhs)
return res

View file

@ -4,6 +4,7 @@ from os import path
import json import json
import importlib import importlib
from lark import Lark from lark import Lark
import sys
from agtest.ast import * from agtest.ast import *
from agtest.gen import * from agtest.gen import *

0
docs/_static/.gitkeep vendored Normal file
View file

View file

@ -4,14 +4,6 @@ agtest package
Submodules Submodules
---------- ----------
agtest.ast module
-----------------
.. automodule:: agtest.ast
:members:
:undoc-members:
:show-inheritance:
agtest.gen module agtest.gen module
----------------- -----------------
@ -20,14 +12,6 @@ agtest.gen module
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
agtest.main module
------------------
.. automodule:: agtest.main
:members:
:undoc-members:
:show-inheritance:
Module contents Module contents
--------------- ---------------

View file

@ -10,6 +10,8 @@ Welcome to agtest's documentation!
:maxdepth: 2 :maxdepth: 2
:caption: Contents: :caption: Contents:
agtest
Indices and tables Indices and tables