more reorganizing
This commit is contained in:
parent
65eeb7d25d
commit
9dbb557d00
2 changed files with 27 additions and 25 deletions
46
aggen.py
46
aggen.py
|
@ -24,14 +24,14 @@ class NodeDesc:
|
|||
def __init__(self, node: Node):
|
||||
self.node = node
|
||||
self.name = node.name
|
||||
self.nonterminal = node.name.lower()
|
||||
self.nonterminal = gensym(node.name.lower())
|
||||
|
||||
|
||||
class ParseEquation:
|
||||
def __init__(self, name: str, syms: List[str], ty: str):
|
||||
def __init__(self, name: str, syms: List[str], pyty: str):
|
||||
self.name = name
|
||||
self.syms = syms
|
||||
self.ty = ty
|
||||
self.ty = pyty
|
||||
|
||||
|
||||
class GenResult:
|
||||
|
@ -40,23 +40,25 @@ class GenResult:
|
|||
self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list)
|
||||
self.starts: Set[str] = set()
|
||||
self.extra = ex
|
||||
self.trans: List[str] = list()
|
||||
self.nonterminal_map: Dict[str, str] = dict()
|
||||
|
||||
@property
|
||||
def transdef(self) -> str:
|
||||
s = self.trans
|
||||
s = []
|
||||
for name, rules in self.parse_rules.items():
|
||||
n = name.lstrip("?")
|
||||
for equation in rules:
|
||||
code = f"""
|
||||
code = textwrap.dedent(f"""
|
||||
def {equation.name}(self, items: Any) -> Thunk[{equation.ty}]:
|
||||
return Thunk(lambda: {equation.ty}())
|
||||
""".strip().replace("\n", "")
|
||||
code = re.sub(r"\s+", " ", code)
|
||||
def inner() -> {equation.ty}:
|
||||
res = {equation.ty}()
|
||||
return res
|
||||
return Thunk(inner)
|
||||
""")
|
||||
s.append(code)
|
||||
if not s:
|
||||
s = ["pass"]
|
||||
return "\n" + "\n".join(map(lambda c: f" {c}", s))
|
||||
return textwrap.indent("\n".join(s), " ")
|
||||
|
||||
@property
|
||||
def parser_data(self) -> str:
|
||||
|
@ -143,22 +145,16 @@ def gen(program: List[Decl]) -> GenResult:
|
|||
return dict()
|
||||
raise Exception(f"unhandled {expr.__class__}")
|
||||
|
||||
node_map: Dict[str, NodeDesc] = dict(
|
||||
map(
|
||||
lambda n: (n.name, NodeDesc(cast(Node, n))),
|
||||
filter(lambda c: isinstance(c, Node), program),
|
||||
)
|
||||
)
|
||||
node_map: Dict[str, NodeDesc] = dict()
|
||||
|
||||
for _node in filter(lambda c: isinstance(c, Node), program):
|
||||
nd = NodeDesc(cast(Node, _node))
|
||||
node_map[_node.name] = nd
|
||||
res.nonterminal_map[nd.name] = nd.nonterminal
|
||||
|
||||
for node_desc in node_map.values():
|
||||
assert isinstance(node_desc, NodeDesc)
|
||||
|
||||
res.starts.add(node_desc.name.lower())
|
||||
# res.parse_rules[f"?{node_desc.name.lower()}"].append(
|
||||
# ParseEquation(
|
||||
# node_desc.name.lower(), [node_desc.nonterminal], node_desc.nonterminal
|
||||
# )
|
||||
# )
|
||||
res.starts.add(node_desc.nonterminal)
|
||||
|
||||
class_decl = textwrap.dedent(
|
||||
f"""
|
||||
|
@ -210,7 +206,9 @@ def gen(program: List[Decl]) -> GenResult:
|
|||
for sym in variant.prod:
|
||||
n = resolve_production(sym)
|
||||
seq.append(n)
|
||||
res.parse_rules[node_desc.nonterminal].append(ParseEquation(prod_name, seq, v_class_name))
|
||||
res.parse_rules[node_desc.nonterminal].append(
|
||||
ParseEquation(prod_name, seq, v_class_name)
|
||||
)
|
||||
|
||||
# create an environment for checking the equations based on
|
||||
# the production
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import textwrap
|
||||
import os
|
||||
import json
|
||||
import importlib
|
||||
from lark import Lark
|
||||
|
||||
|
@ -55,7 +56,9 @@ if __name__ == "__main__":
|
|||
|
||||
class Trans(Transformer[None]): {transdef}
|
||||
|
||||
__agNonterminals = {ntmap}
|
||||
def parse(input: str, start: Optional[str] = None) -> Any:
|
||||
if start is not None: start = __agNonterminals[start]
|
||||
tree = parser.parse(input, start)
|
||||
trans = Trans()
|
||||
res = trans.transform(tree)
|
||||
|
@ -68,8 +71,9 @@ if __name__ == "__main__":
|
|||
ex=res.extra,
|
||||
starts=list(res.starts),
|
||||
transdef=res.transdef,
|
||||
ntmap=json.dumps(res.nonterminal_map),
|
||||
)
|
||||
)
|
||||
|
||||
mod = importlib.import_module("gen.arith")
|
||||
print(mod.parse("1 + 2 * 3", start="expr")) # type: ignore
|
||||
print(mod.parse("1 + 2 * 3", start="Expr")) # type: ignore
|
||||
|
|
Loading…
Reference in a new issue