more reorganizing
This commit is contained in:
parent
65eeb7d25d
commit
9dbb557d00
2 changed files with 27 additions and 25 deletions
46
aggen.py
46
aggen.py
|
@ -24,14 +24,14 @@ class NodeDesc:
|
||||||
def __init__(self, node: Node):
|
def __init__(self, node: Node):
|
||||||
self.node = node
|
self.node = node
|
||||||
self.name = node.name
|
self.name = node.name
|
||||||
self.nonterminal = node.name.lower()
|
self.nonterminal = gensym(node.name.lower())
|
||||||
|
|
||||||
|
|
||||||
class ParseEquation:
|
class ParseEquation:
|
||||||
def __init__(self, name: str, syms: List[str], ty: str):
|
def __init__(self, name: str, syms: List[str], pyty: str):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.syms = syms
|
self.syms = syms
|
||||||
self.ty = ty
|
self.ty = pyty
|
||||||
|
|
||||||
|
|
||||||
class GenResult:
|
class GenResult:
|
||||||
|
@ -40,23 +40,25 @@ class GenResult:
|
||||||
self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list)
|
self.parse_rules: defaultdict[str, List[ParseEquation]] = defaultdict(list)
|
||||||
self.starts: Set[str] = set()
|
self.starts: Set[str] = set()
|
||||||
self.extra = ex
|
self.extra = ex
|
||||||
self.trans: List[str] = list()
|
self.nonterminal_map: Dict[str, str] = dict()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def transdef(self) -> str:
|
def transdef(self) -> str:
|
||||||
s = self.trans
|
s = []
|
||||||
for name, rules in self.parse_rules.items():
|
for name, rules in self.parse_rules.items():
|
||||||
n = name.lstrip("?")
|
n = name.lstrip("?")
|
||||||
for equation in rules:
|
for equation in rules:
|
||||||
code = f"""
|
code = textwrap.dedent(f"""
|
||||||
def {equation.name}(self, items: Any) -> Thunk[{equation.ty}]:
|
def {equation.name}(self, items: Any) -> Thunk[{equation.ty}]:
|
||||||
return Thunk(lambda: {equation.ty}())
|
def inner() -> {equation.ty}:
|
||||||
""".strip().replace("\n", "")
|
res = {equation.ty}()
|
||||||
code = re.sub(r"\s+", " ", code)
|
return res
|
||||||
|
return Thunk(inner)
|
||||||
|
""")
|
||||||
s.append(code)
|
s.append(code)
|
||||||
if not s:
|
if not s:
|
||||||
s = ["pass"]
|
s = ["pass"]
|
||||||
return "\n" + "\n".join(map(lambda c: f" {c}", s))
|
return textwrap.indent("\n".join(s), " ")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parser_data(self) -> str:
|
def parser_data(self) -> str:
|
||||||
|
@ -143,22 +145,16 @@ def gen(program: List[Decl]) -> GenResult:
|
||||||
return dict()
|
return dict()
|
||||||
raise Exception(f"unhandled {expr.__class__}")
|
raise Exception(f"unhandled {expr.__class__}")
|
||||||
|
|
||||||
node_map: Dict[str, NodeDesc] = dict(
|
node_map: Dict[str, NodeDesc] = dict()
|
||||||
map(
|
|
||||||
lambda n: (n.name, NodeDesc(cast(Node, n))),
|
for _node in filter(lambda c: isinstance(c, Node), program):
|
||||||
filter(lambda c: isinstance(c, Node), program),
|
nd = NodeDesc(cast(Node, _node))
|
||||||
)
|
node_map[_node.name] = nd
|
||||||
)
|
res.nonterminal_map[nd.name] = nd.nonterminal
|
||||||
|
|
||||||
for node_desc in node_map.values():
|
for node_desc in node_map.values():
|
||||||
assert isinstance(node_desc, NodeDesc)
|
assert isinstance(node_desc, NodeDesc)
|
||||||
|
res.starts.add(node_desc.nonterminal)
|
||||||
res.starts.add(node_desc.name.lower())
|
|
||||||
# res.parse_rules[f"?{node_desc.name.lower()}"].append(
|
|
||||||
# ParseEquation(
|
|
||||||
# node_desc.name.lower(), [node_desc.nonterminal], node_desc.nonterminal
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
|
|
||||||
class_decl = textwrap.dedent(
|
class_decl = textwrap.dedent(
|
||||||
f"""
|
f"""
|
||||||
|
@ -210,7 +206,9 @@ def gen(program: List[Decl]) -> GenResult:
|
||||||
for sym in variant.prod:
|
for sym in variant.prod:
|
||||||
n = resolve_production(sym)
|
n = resolve_production(sym)
|
||||||
seq.append(n)
|
seq.append(n)
|
||||||
res.parse_rules[node_desc.nonterminal].append(ParseEquation(prod_name, seq, v_class_name))
|
res.parse_rules[node_desc.nonterminal].append(
|
||||||
|
ParseEquation(prod_name, seq, v_class_name)
|
||||||
|
)
|
||||||
|
|
||||||
# create an environment for checking the equations based on
|
# create an environment for checking the equations based on
|
||||||
# the production
|
# the production
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import textwrap
|
import textwrap
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import importlib
|
import importlib
|
||||||
from lark import Lark
|
from lark import Lark
|
||||||
|
|
||||||
|
@ -55,7 +56,9 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
class Trans(Transformer[None]): {transdef}
|
class Trans(Transformer[None]): {transdef}
|
||||||
|
|
||||||
|
__agNonterminals = {ntmap}
|
||||||
def parse(input: str, start: Optional[str] = None) -> Any:
|
def parse(input: str, start: Optional[str] = None) -> Any:
|
||||||
|
if start is not None: start = __agNonterminals[start]
|
||||||
tree = parser.parse(input, start)
|
tree = parser.parse(input, start)
|
||||||
trans = Trans()
|
trans = Trans()
|
||||||
res = trans.transform(tree)
|
res = trans.transform(tree)
|
||||||
|
@ -68,8 +71,9 @@ if __name__ == "__main__":
|
||||||
ex=res.extra,
|
ex=res.extra,
|
||||||
starts=list(res.starts),
|
starts=list(res.starts),
|
||||||
transdef=res.transdef,
|
transdef=res.transdef,
|
||||||
|
ntmap=json.dumps(res.nonterminal_map),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
mod = importlib.import_module("gen.arith")
|
mod = importlib.import_module("gen.arith")
|
||||||
print(mod.parse("1 + 2 * 3", start="expr")) # type: ignore
|
print(mod.parse("1 + 2 * 3", start="Expr")) # type: ignore
|
||||||
|
|
Loading…
Reference in a new issue