more reorganizing

This commit is contained in:
Michael Zhang 2021-09-30 17:49:26 -05:00
parent 65eeb7d25d
commit 9dbb557d00
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B
2 changed files with 27 additions and 25 deletions

View file

@ -24,14 +24,14 @@ class NodeDesc:
def __init__(self, node: Node): def __init__(self, node: Node):
self.node = node self.node = node
self.name = node.name self.name = node.name
self.nonterminal = node.name.lower() self.nonterminal = gensym(node.name.lower())
class ParseEquation: class ParseEquation:
def __init__(self, name: str, syms: List[str], ty: str): def __init__(self, name: str, syms: List[str], pyty: str):
self.name = name self.name = name
self.syms = syms self.syms = syms
self.ty = ty self.ty = pyty
class GenResult: class GenResult:
@ -40,23 +40,25 @@ class GenResult:
self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list) self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list)
self.starts: Set[str] = set() self.starts: Set[str] = set()
self.extra = ex self.extra = ex
self.trans: List[str] = list() self.nonterminal_map: Dict[str, str] = dict()
@property @property
def transdef(self) -> str: def transdef(self) -> str:
s = self.trans s = []
for name, rules in self.parse_rules.items(): for name, rules in self.parse_rules.items():
n = name.lstrip("?") n = name.lstrip("?")
for equation in rules: for equation in rules:
code = f""" code = textwrap.dedent(f"""
def {equation.name}(self, items: Any) -> Thunk[{equation.ty}]: def {equation.name}(self, items: Any) -> Thunk[{equation.ty}]:
return Thunk(lambda: {equation.ty}()) def inner() -> {equation.ty}:
""".strip().replace("\n", "") res = {equation.ty}()
code = re.sub(r"\s+", " ", code) return res
return Thunk(inner)
""")
s.append(code) s.append(code)
if not s: if not s:
s = ["pass"] s = ["pass"]
return "\n" + "\n".join(map(lambda c: f" {c}", s)) return textwrap.indent("\n".join(s), " ")
@property @property
def parser_data(self) -> str: def parser_data(self) -> str:
@ -143,22 +145,16 @@ def gen(program: List[Decl]) -> GenResult:
return dict() return dict()
raise Exception(f"unhandled {expr.__class__}") raise Exception(f"unhandled {expr.__class__}")
node_map: Dict[str, NodeDesc] = dict( node_map: Dict[str, NodeDesc] = dict()
map(
lambda n: (n.name, NodeDesc(cast(Node, n))), for _node in filter(lambda c: isinstance(c, Node), program):
filter(lambda c: isinstance(c, Node), program), nd = NodeDesc(cast(Node, _node))
) node_map[_node.name] = nd
) res.nonterminal_map[nd.name] = nd.nonterminal
for node_desc in node_map.values(): for node_desc in node_map.values():
assert isinstance(node_desc, NodeDesc) assert isinstance(node_desc, NodeDesc)
res.starts.add(node_desc.nonterminal)
res.starts.add(node_desc.name.lower())
# res.parse_rules[f"?{node_desc.name.lower()}"].append(
# ParseEquation(
# node_desc.name.lower(), [node_desc.nonterminal], node_desc.nonterminal
# )
# )
class_decl = textwrap.dedent( class_decl = textwrap.dedent(
f""" f"""
@ -210,7 +206,9 @@ def gen(program: List[Decl]) -> GenResult:
for sym in variant.prod: for sym in variant.prod:
n = resolve_production(sym) n = resolve_production(sym)
seq.append(n) seq.append(n)
res.parse_rules[node_desc.nonterminal].append(ParseEquation(prod_name, seq, v_class_name)) res.parse_rules[node_desc.nonterminal].append(
ParseEquation(prod_name, seq, v_class_name)
)
# create an environment for checking the equations based on # create an environment for checking the equations based on
# the production # the production

View file

@ -1,5 +1,6 @@
import textwrap import textwrap
import os import os
import json
import importlib import importlib
from lark import Lark from lark import Lark
@ -55,7 +56,9 @@ if __name__ == "__main__":
class Trans(Transformer[None]): {transdef} class Trans(Transformer[None]): {transdef}
__agNonterminals = {ntmap}
def parse(input: str, start: Optional[str] = None) -> Any: def parse(input: str, start: Optional[str] = None) -> Any:
if start is not None: start = __agNonterminals[start]
tree = parser.parse(input, start) tree = parser.parse(input, start)
trans = Trans() trans = Trans()
res = trans.transform(tree) res = trans.transform(tree)
@ -68,8 +71,9 @@ if __name__ == "__main__":
ex=res.extra, ex=res.extra,
starts=list(res.starts), starts=list(res.starts),
transdef=res.transdef, transdef=res.transdef,
ntmap=json.dumps(res.nonterminal_map),
) )
) )
mod = importlib.import_module("gen.arith") mod = importlib.import_module("gen.arith")
print(mod.parse("1 + 2 * 3", start="expr")) # type: ignore print(mod.parse("1 + 2 * 3", start="Expr")) # type: ignore