This commit is contained in:
Michael Zhang 2021-10-01 13:38:40 -05:00
parent cf67e78e9a
commit 310f9c2ade
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B
4 changed files with 95 additions and 85 deletions

View file

@ -9,8 +9,10 @@
flake-utils.lib.eachDefaultSystem (system: flake-utils.lib.eachDefaultSystem (system:
let let
pkgs = nixpkgs.legacyPackages.${system}; pkgs = nixpkgs.legacyPackages.${system};
pythonPkgs = pkgs.python39Packages;
myPkgs = rec { myPkgs = rec {
agtest = pkgs.python39Packages.callPackage ./. {}; agtest = pythonPkgs.callPackage ./. {};
}; };
in in
{ {

View file

@ -3,7 +3,7 @@ name = click
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
agt = agtest.driver:run agtest = agtest.driver:run
[options] [options]
packages = find: packages = find:

View file

@ -1,6 +1,7 @@
import importlib import importlib
import json import json
import os import os
import sys
from io import StringIO from io import StringIO
from os import path from os import path
@ -17,8 +18,9 @@ runtime_path = path.join(src_dir, "runtime.tmpl.py")
p = lark.Lark(open(grammar_path).read(), start="program", parser="lalr") p = lark.Lark(open(grammar_path).read(), start="program", parser="lalr")
@click.command() @click.command()
@click.option("--show-only", is_flag=True)
@click.argument("input", type=click.File("r")) @click.argument("input", type=click.File("r"))
def run(input): def run(input, show_only):
data = input.read() data = input.read()
input.close() input.close()
@ -44,6 +46,10 @@ def run(input):
) )
) )
if show_only:
print(s.getvalue())
sys.exit(0)
print("Dropping you in a Python shell...") print("Dropping you in a Python shell...")
print("Call parse(str) to parse something.") print("Call parse(str) to parse something.")

View file

@ -42,6 +42,15 @@ class GenResult:
self.extra: str = "" self.extra: str = ""
self.nonterminal_map: Dict[str, str] = dict() self.nonterminal_map: Dict[str, str] = dict()
self.ifaces: Dict[str, Iface] = dict()
self.what_ifaces: Dict[str, Set[str]] = dict()
self.what_fields: Dict[str, Dict[str, str]] = dict()
self.node_map: Dict[str, NodeDesc] = dict()
self.builtins: Dict[str, str] = {
"parseInt": "",
}
@property @property
def trans_def(self) -> str: def trans_def(self) -> str:
s = [] s = []
@ -78,83 +87,101 @@ class GenResult:
return "\n".join(s) return "\n".join(s)
def build(self) -> None: def _collect_ifaces(self) -> None:
def v(name: str) -> str: """ collect a list of name -> iface declarations"""
return f"__ag_{name}" self.ifaces = dict(
# builtins
builtins: Dict[str, str] = {
"parseInt": "",
}
# collect a list of name -> iface declarations
ifaces: Dict[str, Iface] = dict(
map( map(
lambda c: (c.name, cast(Iface, c)), lambda c: (c.name, cast(Iface, c)),
filter(lambda c: isinstance(c, Iface), self.program), filter(lambda c: isinstance(c, Iface), self.program),
) )
) )
# list of node -> iface mappings def _create_iface_mappings(self) -> None:
what_ifaces: Dict[str, Set[str]] = dict() """ list of node -> iface mappings """
what_fields: Dict[str, Dict[str, str]] = dict() self.what_ifaces = dict()
self.what_fields = dict()
for node in filter(lambda c: isinstance(c, Node), self.program): for node in filter(lambda c: isinstance(c, Node), self.program):
node = cast(Node, node) node = cast(Node, node)
# all_fields = dict() # all_fields = dict()
what_ifaces[node.name] = set(node.ifaces) self.what_ifaces[node.name] = set(node.ifaces)
this_fields = dict() this_fields = dict()
for iface in node.ifaces: for iface in node.ifaces:
fields = ifaces[iface].fields fields = self.ifaces[iface].fields
for field in fields: for field in fields:
if field.name in this_fields: if field.name in this_fields:
raise Exception("duplicate field name") raise Exception("duplicate field name")
this_fields[field.name] = field.ty this_fields[field.name] = field.ty
what_fields[node.name] = this_fields self.what_fields[node.name] = this_fields
def _collect_required_thunks(
self, env: List[Tuple[str, NodeRef]], expr: Expr
) -> Dict[str, str]:
names = dict(env)
if isinstance(expr, ExprDot):
return self._collect_required_thunks(env, expr.left)
elif isinstance(expr, ExprMul):
a = self._collect_required_thunks(env, expr.left)
b = self._collect_required_thunks(env, expr.right)
a.update(b)
return a
elif isinstance(expr, ExprAdd):
a = self._collect_required_thunks(env, expr.left)
b = self._collect_required_thunks(env, expr.right)
a.update(b)
return a
elif isinstance(expr, ExprCall):
return self._collect_required_thunks(env, expr.func)
elif isinstance(expr, ExprName):
if expr.name not in names and expr.name not in self.builtins:
raise Exception(f"unbound name '{expr.name}'")
return dict()
raise Exception(f"unhandled {expr.__class__}")
def _resolve_production(self, sym: Sym) -> str:
""" resolving a production just means checking to make sure it's a type that exists or it's a regex"""
if isinstance(sym, SymRename):
if isinstance(sym.ty, NodeRefByName):
if sym.ty.name in self.node_map:
return self.node_map[sym.ty.name].nonterminal
else:
raise Exception(
f"unresolved name {sym.ty.name} in production"
)
elif isinstance(sym.ty, NodeRegex):
sym_name = gensym("sym")
self.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
return sym_name
elif isinstance(sym, SymLit):
sym_name = gensym("lit")
# hack to make repr have double quotes
self.literals[sym_name] = json.dumps(sym.lit)
return sym_name
raise Exception(f"unhandled {sym.__class__}")
def _build_node_map(self) -> None:
for _node in filter(lambda c: isinstance(c, Node), self.program):
nd = NodeDesc(cast(Node, _node))
self.node_map[_node.name] = nd
self.nonterminal_map[nd.name] = nd.nonterminal
def build(self) -> None:
def v(name: str) -> str:
return f"__ag_{name}"
self._collect_ifaces()
self._create_iface_mappings()
# a high-level dictionary of productions; this has sub-productions that # a high-level dictionary of productions; this has sub-productions that
# should be further expanded at a later step before converting into lark # should be further expanded at a later step before converting into lark
# code # code
productions_hi: Dict[str, Union[str, List[str]]] = dict() productions_hi: Dict[str, Union[str, List[str]]] = dict()
# TODO: this should probably not be inlined here, but i'll move it out once for node_desc in self.node_map.values():
# i get more info into the 'env'
def collect_required_thunks(
env: List[Tuple[str, NodeRef]], expr: Expr
) -> Dict[str, str]:
names = dict(env)
if isinstance(expr, ExprDot):
return collect_required_thunks(env, expr.left)
elif isinstance(expr, ExprMul):
a = collect_required_thunks(env, expr.left)
b = collect_required_thunks(env, expr.right)
a.update(b)
return a
elif isinstance(expr, ExprAdd):
a = collect_required_thunks(env, expr.left)
b = collect_required_thunks(env, expr.right)
a.update(b)
return a
elif isinstance(expr, ExprCall):
return collect_required_thunks(env, expr.func)
elif isinstance(expr, ExprName):
if expr.name not in names and expr.name not in builtins:
raise Exception(f"unbound name '{expr.name}'")
return dict()
raise Exception(f"unhandled {expr.__class__}")
node_map: Dict[str, NodeDesc] = dict()
for _node in filter(lambda c: isinstance(c, Node), self.program):
nd = NodeDesc(cast(Node, _node))
node_map[_node.name] = nd
self.nonterminal_map[nd.name] = nd.nonterminal
for node_desc in node_map.values():
assert isinstance(node_desc, NodeDesc) assert isinstance(node_desc, NodeDesc)
self.starts.add(node_desc.nonterminal) self.starts.add(node_desc.nonterminal)
class_fields = [] class_fields = []
for field_name, field_ty in what_fields[node_desc.name].items(): for field_name, field_ty in self.what_fields[node_desc.name].items():
class_fields.append(f"{field_name}: Thunk[{field_ty}]") class_fields.append(f"{field_name}: Thunk[{field_ty}]")
g = textwrap.indent("\n".join(class_fields), " ") g = textwrap.indent("\n".join(class_fields), " ")
@ -181,31 +208,10 @@ class GenResult:
prod_name = gensym(node_desc.nonterminal + "_") prod_name = gensym(node_desc.nonterminal + "_")
# print("PRODUCTION", prod_name, variant.prod) # print("PRODUCTION", prod_name, variant.prod)
# resolving a production just means checking to make sure it's a
# type that exists or it's a regex
def resolve_production(sym: Sym) -> str:
if isinstance(sym, SymRename):
if isinstance(sym.ty, NodeRefByName):
if sym.ty.name in node_map:
return node_map[sym.ty.name].nonterminal
else:
raise Exception(
f"unresolved name {sym.ty.name} in production"
)
elif isinstance(sym.ty, NodeRegex):
sym_name = gensym("sym")
self.literals[sym_name] = f"/{sym.ty.pat.pattern}/"
return sym_name
elif isinstance(sym, SymLit):
sym_name = gensym("lit")
# hack to make repr have double quotes
self.literals[sym_name] = json.dumps(sym.lit)
return sym_name
raise Exception(f"unhandled {sym.__class__}")
seq = [] seq = []
for sym in variant.prod: for sym in variant.prod:
n = resolve_production(sym) n = self._resolve_production(sym)
seq.append(n) seq.append(n)
self.parse_rules[node_desc.nonterminal].append( self.parse_rules[node_desc.nonterminal].append(
ParseEquation(prod_name, seq, v_class_name) ParseEquation(prod_name, seq, v_class_name)
@ -217,14 +223,10 @@ class GenResult:
for sym in variant.prod: for sym in variant.prod:
if isinstance(sym, SymRename): if isinstance(sym, SymRename):
env.append((sym.name, sym.ty)) env.append((sym.name, sym.ty))
print(env)
# for each of the equations, find out what the equation is trying # for each of the equations, find out what the equation is
# to compute, and generate a thunk corresponding to that value. # trying to compute, and generate a thunk corresponding to that
# value.
for eq in variant.equations: for eq in variant.equations:
eq_name = gensym(f"eq_{node_desc.name}") self._collect_required_thunks(copy.deepcopy(env), eq.rhs)
thunk_name = gensym(f"thunk_{node_desc.name}")
print("RHS", eq.rhs, eq.rhs.id)
collect_required_thunks(copy.deepcopy(env), eq.rhs)