more reorganizing

This commit is contained in:
Michael Zhang 2021-09-30 17:49:26 -05:00
parent 65eeb7d25d
commit 9dbb557d00
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B
2 changed files with 27 additions and 25 deletions

View file

@ -24,14 +24,14 @@ class NodeDesc:
def __init__(self, node: Node):
self.node = node
self.name = node.name
self.nonterminal = node.name.lower()
self.nonterminal = gensym(node.name.lower())
class ParseEquation:
def __init__(self, name: str, syms: List[str], ty: str):
def __init__(self, name: str, syms: List[str], pyty: str):
self.name = name
self.syms = syms
self.ty = ty
self.ty = pyty
class GenResult:
@ -40,23 +40,25 @@ class GenResult:
self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list)
self.starts: Set[str] = set()
self.extra = ex
self.trans: List[str] = list()
self.nonterminal_map: Dict[str, str] = dict()
@property
def transdef(self) -> str:
s = self.trans
s = []
for name, rules in self.parse_rules.items():
n = name.lstrip("?")
for equation in rules:
code = f"""
code = textwrap.dedent(f"""
def {equation.name}(self, items: Any) -> Thunk[{equation.ty}]:
return Thunk(lambda: {equation.ty}())
""".strip().replace("\n", "")
code = re.sub(r"\s+", " ", code)
def inner() -> {equation.ty}:
res = {equation.ty}()
return res
return Thunk(inner)
""")
s.append(code)
if not s:
s = ["pass"]
return "\n" + "\n".join(map(lambda c: f" {c}", s))
return textwrap.indent("\n".join(s), " ")
@property
def parser_data(self) -> str:
@ -143,22 +145,16 @@ def gen(program: List[Decl]) -> GenResult:
return dict()
raise Exception(f"unhandled {expr.__class__}")
node_map: Dict[str, NodeDesc] = dict(
map(
lambda n: (n.name, NodeDesc(cast(Node, n))),
filter(lambda c: isinstance(c, Node), program),
)
)
node_map: Dict[str, NodeDesc] = dict()
for _node in filter(lambda c: isinstance(c, Node), program):
nd = NodeDesc(cast(Node, _node))
node_map[_node.name] = nd
res.nonterminal_map[nd.name] = nd.nonterminal
for node_desc in node_map.values():
assert isinstance(node_desc, NodeDesc)
res.starts.add(node_desc.name.lower())
# res.parse_rules[f"?{node_desc.name.lower()}"].append(
# ParseEquation(
# node_desc.name.lower(), [node_desc.nonterminal], node_desc.nonterminal
# )
# )
res.starts.add(node_desc.nonterminal)
class_decl = textwrap.dedent(
f"""
@ -210,7 +206,9 @@ def gen(program: List[Decl]) -> GenResult:
for sym in variant.prod:
n = resolve_production(sym)
seq.append(n)
res.parse_rules[node_desc.nonterminal].append(ParseEquation(prod_name, seq, v_class_name))
res.parse_rules[node_desc.nonterminal].append(
ParseEquation(prod_name, seq, v_class_name)
)
# create an environment for checking the equations based on
# the production

View file

@ -1,5 +1,6 @@
import textwrap
import os
import json
import importlib
from lark import Lark
@ -55,7 +56,9 @@ if __name__ == "__main__":
class Trans(Transformer[None]): {transdef}
__agNonterminals = {ntmap}
def parse(input: str, start: Optional[str] = None) -> Any:
if start is not None: start = __agNonterminals[start]
tree = parser.parse(input, start)
trans = Trans()
res = trans.transform(tree)
@ -68,8 +71,9 @@ if __name__ == "__main__":
ex=res.extra,
starts=list(res.starts),
transdef=res.transdef,
ntmap=json.dumps(res.nonterminal_map),
)
)
mod = importlib.import_module("gen.arith")
print(mod.parse("1 + 2 * 3", start="expr")) # type: ignore
print(mod.parse("1 + 2 * 3", start="Expr")) # type: ignore