diff --git a/Pipfile b/Pipfile index 28551c9..8255492 100644 --- a/Pipfile +++ b/Pipfile @@ -4,10 +4,11 @@ verify_ssl = true name = "pypi" [packages] -lark-parser = "==0.11.3" +lark-parser = "*" [dev-packages] -mypy = "==0.901" +mypy = "*" +black = "*" [requires] python_version = "3.9" diff --git a/Pipfile.lock b/Pipfile.lock index 2d28c65..6432c39 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "225e01979b7cd88f078f6c030af46ff11f7af0bf8c34a7deeb50be649492cd4d" + "sha256": "98b6f8b73b8482a118dfcbb2dca537038551c3a1b7c8cb2d1762b9ae9a71dfdc" }, "pipfile-spec": 6, "requires": { @@ -18,64 +18,12 @@ "default": { "lark-parser": { "hashes": [ - "sha256:e29ca814a98bb0f81674617d878e5f611cb993c19ea47f22c80da3569425f9bd" + "sha256:0eaf30cb5ba787fe404d73a7d6e61df97b21d5a63ac26c5008c78a494373c675", + "sha256:15967db1f1214013dca65b1180745047b9be457d73da224fcda3d9dd4e96a138" ], "index": "pypi", - "version": "==0.11.3" + "version": "==0.12.0" } }, - "develop": { - "mypy": { - "hashes": [ - "sha256:053b92ebae901fc7954677949049f70133f2f63e3e83dc100225c26d6a46fe95", - "sha256:08cf1f31029612e1008a9432337ca4b1fbac989ff7c8200e2c9ec42705cd4c7b", - "sha256:18753a8bb9bcf031ff10009852bd48d781798ecbccf45be5449892e6af4e3f9f", - "sha256:1cd241966a35036f936d4739bd71a1c64e15f02bf7d12bb2815cccfb2993a9de", - "sha256:307a6c047596d768c3d689734307e47a91596eb9dbb67cfdf7d1fd9117b27f13", - "sha256:4a622faa3be76114cdce009f8ec173401494cf9e8f22713e7ae75fee9d906ab3", - "sha256:4b54518e399c3f4dc53380d4252c83276b2e60623cfc5274076eb8aae57572ac", - "sha256:5ddd8f4096d5fc2e7d7bb3924ac22758862163ad2c1cdc902c4b85568160e90a", - "sha256:61b10ba18a01d05fc46adbf4f18b0e92178f6b5fd0f45926ffc2a408b5419728", - "sha256:7845ad3a31407bfbd64c76d032c16ab546d282930f747023bf07c17b054bebc5", - "sha256:79beb6741df15395908ecc706b3a593a98804c1d5b5b6bd0c5b03b67c7ac03a0", - "sha256:8183561bfd950e93eeab8379ae5ec65873c856f5b58498d23aa8691f74c86030", - "sha256:91211acf1485a1db0b1261bc5f9ed450cba3c0dfd8da0a6680e94827591e34d7", - "sha256:97be0e8ed116f7f79472a49cf06dd45dd806771142401f684d4f13ee652a63c0", - "sha256:9941b685807b60c58020bb67b3217c9df47820dcd00425f55cdf71f31d3c42d9", - "sha256:a85c6759dcc6a9884131fa06a037bd34352aa3947e7f5d9d5a35652cc3a44bcd", - "sha256:bc61153eb4df769538bb4a6e1045f59c2e6119339690ec719feeacbfc3809e89", - "sha256:bf347c327c48d963bdef5bf365215d3e98b5fddbe5069fc796cec330e8235a20", - "sha256:c86e3f015bfe7958646825d41c0691c6e5a5cd4015e3409b5c29c18a3c712534", - "sha256:c8bc628961cca4335ac7d1f2ed59b7125d9252fe4c78c3d66d30b50162359c99", - "sha256:da914faaa80c25f463913da6db12adba703822a768f452f29f75b40bb4357139", - "sha256:e8577d30daf1b7b6582020f539f76e78ee1ed64a0323b28c8e0333c45db9369f", - "sha256:f208cc967e566698c4e30a1f65843fc88d8da05a8693bac8b975417e0aee9ced" - ], - "index": "pypi", - "version": "==0.901" - }, - "mypy-extensions": { - "hashes": [ - "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d", - "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8" - ], - "version": "==0.4.3" - }, - "toml": { - "hashes": [ - "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", - "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f" - ], - "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", - "version": "==0.10.2" - }, - "typing-extensions": { - "hashes": [ - "sha256:0ac0f89795dd19de6b97debb0c6af1c70987fd80a2d62d1958f7e56fcc31b497", - "sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342", - "sha256:779383f6086d90c99ae41cf0ff39aac8a7937a9283ce0a414e5dd782f4c94a84" - ], - "version": "==3.10.0.0" - } - } + "develop": {} } diff --git a/agast.py b/agast.py index c92fa51..b7bc619 100644 --- a/agast.py +++ b/agast.py @@ -3,6 +3,7 @@ from lark import Transformer, Tree, Token import re from re import Pattern + def unescape(s: str) -> str: q = s[0] t = "" @@ -11,11 +12,15 @@ def unescape(s: str) -> str: while i < len(s): c = s[i] if escaped: - if c == q: t += q - elif c == 'n': t += '\n' - elif c == 't': t += '\t' - if c == q: break - if c == '\\': + if c == q: + t += q + elif c == "n": + t += "\n" + elif c == "t": + t += "\t" + if c == q: + break + if c == "\\": escaped = True i += 1 continue @@ -23,8 +28,10 @@ def unescape(s: str) -> str: i += 1 return t + T = TypeVar("T") + class Ast: # def __new__(cls: Type[Ast], name: str, bases: Tuple[type], namespace: Dict[str, Any]) -> Ast: # x = super().__new__(cls, name, bases, namespace) @@ -32,62 +39,100 @@ class Ast: # return x id: str n = 0 + @classmethod def __gen(cls, name: str = "") -> str: newid = cls.n cls.n += 1 return f"_a{newid}{name}" + def __init__(self) -> None: self.id = self.__gen() + class Decl: name: str -class IfaceRef(str): pass + +class IfaceRef(str): + pass + + class IfaceField: def __init__(self, name: str, ty: str): self.name = name self.ty = ty + + class Iface(Decl): def __init__(self, name: str, fields: List[IfaceField]): self.name = name self.fields = fields + class Expr(Ast): def __init__(self) -> None: super().__init__() -class NodeRef: pass + +class NodeRef: + pass + + class NodeRefByName(NodeRef, str): def __init__(self, name: str): self.name = name - def __repr__(self) -> str: return f"NodeRefByName({self.name})" + + def __repr__(self) -> str: + return f"NodeRefByName({self.name})" + + class NodeRegex(NodeRef): def __init__(self, pat: str): self.pat = re.compile(unescape(pat)) - def __repr__(self) -> str: return f"NodeRegex({self.pat.pattern})" -class Sym: pass + def __repr__(self) -> str: + return f"NodeRegex({self.pat.pattern})" + + +class Sym: + pass + + class SymLit(Sym): def __init__(self, s: str): self.lit = unescape(s) - def __repr__(self) -> str: return f"SymLit({repr(self.lit)})" + + def __repr__(self) -> str: + return f"SymLit({repr(self.lit)})" + + class SymRename(Sym): def __init__(self, name: str, ty: NodeRef): self.name = name self.ty = ty - def __repr__(self) -> str: return f"SymRename({self.name} : {self.ty})" + + def __repr__(self) -> str: + return f"SymRename({self.name} : {self.ty})" + + class Equation: def __init__(self, lhs: Expr, rhs: Expr): self.lhs = lhs self.rhs = rhs - def __repr__(self) -> str: return f"{self.lhs} = {self.rhs}" + + def __repr__(self) -> str: + return f"{self.lhs} = {self.rhs}" + class Variant: def __init__(self, prod: List[Sym], equations: List[Equation]): self.prod = prod self.equations = equations - def __repr__(self) -> str: return f"Variant({self.prod}, {self.equations})" + + def __repr__(self) -> str: + return f"Variant({self.prod}, {self.equations})" + class Node(Decl): def __init__(self, name: str, ifaces: List[IfaceRef], variants: List[Variant]): @@ -95,90 +140,137 @@ class Node(Decl): self.ifaces = ifaces self.variants = variants + class ExprDot(Expr): def __init__(self, left: Expr, right: str): super().__init__() self.left = left self.right = right - def __repr__(self) -> str: return f"{self.left}.{self.right}" + + def __repr__(self) -> str: + return f"{self.left}.{self.right}" + + class ExprAdd(Expr): def __init__(self, left: Expr, right: Expr): super().__init__() self.left = left self.right = right - def __repr__(self) -> str: return f"{self.left} + {self.right}" + + def __repr__(self) -> str: + return f"{self.left} + {self.right}" + + class ExprMul(Expr): def __init__(self, left: Expr, right: Expr): super().__init__() self.left = left self.right = right - def __repr__(self) -> str: return f"{self.left} * {self.right}" + + def __repr__(self) -> str: + return f"{self.left} * {self.right}" + + class ExprCall(Expr): def __init__(self, func: Expr, args: List[Expr]): super().__init__() self.func = func self.args = args - def __repr__(self) -> str: return f"{self.func}({self.args})" + + def __repr__(self) -> str: + return f"{self.func}({self.args})" + + class ExprName(Expr): def __init__(self, name: str): super().__init__() self.name = name - def __repr__(self) -> str: return f"{self.name}" + + def __repr__(self) -> str: + return f"{self.name}" + class Parser(Transformer[List[Decl]]): - def program(self, items: List[Decl]) -> List[Decl]: return items + def program(self, items: List[Decl]) -> List[Decl]: + return items # interfaces def iface(self, items: List[Any]) -> Iface: [name, fields] = items return Iface(name, fields) + def iface_field(self, items: List[str]) -> IfaceField: [name, ty] = items return IfaceField(name, ty) - def iface_ref(self, items: List[str]) -> str: return items[0] - def iface_refs(self, items: List[IfaceRef]) -> List[IfaceRef]: return items + + def iface_ref(self, items: List[str]) -> str: + return items[0] + + def iface_refs(self, items: List[IfaceRef]) -> List[IfaceRef]: + return items # nodes def node(self, items: List[Any]) -> Node: [name, ifaces, variants] = items return Node(name, ifaces, variants) - def node_ref_name(self, items: List[str]) -> NodeRefByName: return NodeRefByName(items[0]) - def node_regex(self, items: List[str]) -> NodeRegex: return NodeRegex(items[0]) + + def node_ref_name(self, items: List[str]) -> NodeRefByName: + return NodeRefByName(items[0]) + + def node_regex(self, items: List[str]) -> NodeRegex: + return NodeRegex(items[0]) # variants - def variants(self, items: List[Variant]) -> List[Variant]: return items + def variants(self, items: List[Variant]) -> List[Variant]: + return items + def variant(self, items: List[Any]) -> Variant: [prod, equations] = items return Variant(prod, equations) - def prod(self, items: List[Sym]) -> List[Sym]: return items + + def prod(self, items: List[Sym]) -> List[Sym]: + return items # symbols in productions - def sym_lit(self, items: List[str]) -> Sym: return SymLit(items[0]) - def sym_rename(self, items: List[Any]) -> Sym: return SymRename(items[0], items[1]) + def sym_lit(self, items: List[str]) -> Sym: + return SymLit(items[0]) + + def sym_rename(self, items: List[Any]) -> Sym: + return SymRename(items[0], items[1]) # equations - def equations(self, items: List[Equation]) -> List[Equation]: return items - def equation_semi(self, items: List[Equation]) -> Equation: return items[0] - def equation(self, items: List[Expr]) -> Equation: return Equation(items[0], items[1]) + def equations(self, items: List[Equation]) -> List[Equation]: + return items + + def equation_semi(self, items: List[Equation]) -> Equation: + return items[0] + + def equation(self, items: List[Expr]) -> Equation: + return Equation(items[0], items[1]) # expr def expr_dot(self, items: List[Any]) -> Expr: [left, right] = items return ExprDot(left, right) + def expr_add(self, items: List[Expr]) -> Expr: [left, right] = items return ExprAdd(left, right) + def expr_mul(self, items: List[Expr]) -> Expr: [left, right] = items return ExprMul(left, right) + def expr_call(self, items: List[Expr]) -> Expr: [func, args] = items # TODO: args should be a list of exprs -_ - return ExprCall(func, [args]) + def expr_name(self, items: List[str]) -> Expr: return ExprName(items[0]) def sep_trail(self, items: List[Tree]) -> List[T]: return list(map(lambda it: cast(T, it), items)) - def ident(self, items: List[Token]) -> str: return cast(str, items[0].value) + def ident(self, items: List[Token]) -> str: + return cast(str, items[0].value) diff --git a/aggen.py b/aggen.py index 8b33d7a..924661d 100644 --- a/aggen.py +++ b/aggen.py @@ -1,194 +1,212 @@ -from typing import * -import textwrap -import re -import copy -import json -from collections import defaultdict -from re import Pattern - -from agast import * - -global i -i = 0 - -class GenResult: - def __init__(self, pd: str = "", ex: str = ""): - self.literals: Dict[str, str] = dict() - self.parse_rules: defaultdict[str, List[str]] = defaultdict(list) - self.starts: Set[str] = set() - self.extra = ex - - @property - def parser_data(self) -> str: - s = [] - for sym, pat in self.literals.items(): - s.append(f"{sym}: {pat}") - for name, rules in self.parse_rules.items(): - srules = " | ".join(rules) - s.append(f"{name}: {srules}") - s.append("%import common.WS") - s.append("%ignore WS") - return "\n".join(s) - -def gen(program: List[Decl]) -> GenResult: - res = GenResult() - - def gen(prefix: str = "", suffix: str = "") -> str: - global i - presan = re.sub("[^0-9a-zA-Z]+", "_", prefix) - sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix) - i += 1 - return f"{presan}{i}{sufsan}" - - def v(name: str) -> str: - return f"__ag_{name}" - - # builtins - builtins: Dict[str, str] = { - "parseInt": "", - } - - # collect a list of name -> iface declarations - ifaces: Dict[str, Iface] = dict( - map(lambda c: (c.name, cast(Iface, c)), - filter(lambda c: isinstance(c, Iface), - program))) - - # list of node -> iface mappings - what_ifaces: Dict[str, Set[str]] = dict() - what_fields: Dict[str, Dict[str, str]] = dict() - for node in filter(lambda c: isinstance(c, Node), program): - node = cast(Node, node) - # all_fields = dict() - what_ifaces[node.name] = set(node.ifaces) - this_fields = dict() - for iface in node.ifaces: - fields = ifaces[iface].fields - for field in fields: - if field.name in this_fields: - raise Exception("duplicate field name") - this_fields[field.name] = field.ty - what_fields[node.name] = this_fields - print("what_ifaces:", what_ifaces) - print("what_fields:", what_fields) - - # a high-level dictionary of productions; this has sub-productions - # that should be further expanded at a later step before converting - # into lark code - productions_hi: Dict[str, Union[str, List[str]]] = dict() - - # TODO: this should probably not be inlined here, but i'll move it - # out once i get more info into the 'env' - def collect_required_thunks(env: List[Tuple[str, NodeRef]], expr: Expr) -> Dict[str, str]: - names = dict(env) - print(f"collect_required_thunks({expr})", expr.__class__) - if isinstance(expr, ExprDot): - return collect_required_thunks(env, expr.left) - elif isinstance(expr, ExprMul): - a = collect_required_thunks(env, expr.left) - b = collect_required_thunks(env, expr.right) - a.update(b) - return a - elif isinstance(expr, ExprAdd): - a = collect_required_thunks(env, expr.left) - b = collect_required_thunks(env, expr.right) - a.update(b) - return a - elif isinstance(expr, ExprCall): - return collect_required_thunks(env, expr.func) - elif isinstance(expr, ExprName): - if expr.name not in names and expr.name not in builtins: - raise Exception(f"unbound name '{expr.name}'") - return dict() - raise Exception(f"unhandled {expr.__class__}") - - node_map = dict(map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program))) - node_name_map = dict(map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items())) - - for node in node_map.values(): - node = cast(Node, node) - node_name_lower = node.name.lower() - node_name = node_name_map[node.name] - res.parse_rules[f"?{node_name_lower}"].append(node_name) - res.starts.add(node_name_lower) - - class_decl = textwrap.dedent(f""" - class {v(node_name)}: pass - """) - res.extra += class_decl - - print(node.name, node.ifaces) - - for variant in node.variants: - v_class_name = gen(f"{node_name}_var") - class_decl = textwrap.dedent(f""" - class {v(v_class_name)}({v(node_name)}): - ''' ''' - pass - """) - res.extra += class_decl - - prod_name = gen(node_name + "_") - res.parse_rules[node_name].append(prod_name) - print("PRODUCTION", prod_name, variant.prod) - - # resolving a production just means checking to make sure it's a - # type that exists or it's a regex - def resolve_production(sym: Sym) -> str: - print(f"resolve_production({sym})") - if isinstance(sym, SymRename): - if isinstance(sym.ty, NodeRefByName): - if sym.ty.name in node_name_map: - return node_name_map[sym.ty.name] - else: - raise Exception(f"unresolved name {sym.ty.name} in production") - elif isinstance(sym.ty, NodeRegex): - sym_name = gen("sym") - res.literals[sym_name] = f"/{sym.ty.pat.pattern}/" - return sym_name - elif isinstance(sym, SymLit): - sym_name = gen("lit") - # hack to make repr have double quotes - res.literals[sym_name] = json.dumps(sym.lit) - return sym_name - raise Exception(f"unhandled {sym.__class__}") - - seq = [] - for sym in variant.prod: - n = resolve_production(sym) - seq.append(n) - res.parse_rules[prod_name].append(" ".join(seq)) - - # create an environment for checking the equations based on - # the production - env: List[Tuple[str, NodeRef]] = list() - for sym in variant.prod: - if isinstance(sym, SymRename): - env.append((sym.name, sym.ty)) - print(env) - - # for each of the equations, find out what the equation is - # trying to compute, and generate a thunk corresponding to - # that value. - for eq in variant.equations: - eq_name = gen(f"eq_{node.name}") - thunk_name = gen(f"thunk_{node.name}") - - print("RHS", eq.rhs, eq.rhs.id) - collect_required_thunks(copy.deepcopy(env), eq.rhs) - - func_impl = textwrap.dedent(f""" - def {eq_name}() -> None: - ''' {repr(eq)} ''' - pass - def {thunk_name}() -> Thunk[None]: - return Thunk({eq_name}) - """) - print(f"```py\n{func_impl}\n```") - res.extra += func_impl - - # this is a "type alias" that connects it to one of the generated - # names above - res.extra += f"{node.name} = {v(node_name)}" - - return res +from typing import * +import textwrap +import re +import copy +import json +from collections import defaultdict +from re import Pattern + +from agast import * + +global i +i = 0 + + +class GenResult: + def __init__(self, pd: str = "", ex: str = ""): + self.literals: Dict[str, str] = dict() + self.parse_rules: defaultdict[str, List[str]] = defaultdict(list) + self.starts: Set[str] = set() + self.extra = ex + + @property + def parser_data(self) -> str: + s = [] + for sym, pat in self.literals.items(): + s.append(f"{sym}: {pat}") + for name, rules in self.parse_rules.items(): + srules = " | ".join(rules) + s.append(f"{name}: {srules}") + s.append("%import common.WS") + s.append("%ignore WS") + return "\n".join(s) + + +def gen(program: List[Decl]) -> GenResult: + res = GenResult() + + def gen(prefix: str = "", suffix: str = "") -> str: + global i + presan = re.sub("[^0-9a-zA-Z]+", "_", prefix) + sufsan = re.sub("[^0-9a-zA-Z]+", "_", suffix) + i += 1 + return f"{presan}{i}{sufsan}" + + def v(name: str) -> str: + return f"__ag_{name}" + + # builtins + builtins: Dict[str, str] = { + "parseInt": "", + } + + # collect a list of name -> iface declarations + ifaces: Dict[str, Iface] = dict( + map( + lambda c: (c.name, cast(Iface, c)), + filter(lambda c: isinstance(c, Iface), program), + ) + ) + + # list of node -> iface mappings + what_ifaces: Dict[str, Set[str]] = dict() + what_fields: Dict[str, Dict[str, str]] = dict() + for node in filter(lambda c: isinstance(c, Node), program): + node = cast(Node, node) + # all_fields = dict() + what_ifaces[node.name] = set(node.ifaces) + this_fields = dict() + for iface in node.ifaces: + fields = ifaces[iface].fields + for field in fields: + if field.name in this_fields: + raise Exception("duplicate field name") + this_fields[field.name] = field.ty + what_fields[node.name] = this_fields + print("what_ifaces:", what_ifaces) + print("what_fields:", what_fields) + + # a high-level dictionary of productions; this has sub-productions + # that should be further expanded at a later step before converting + # into lark code + productions_hi: Dict[str, Union[str, List[str]]] = dict() + + # TODO: this should probably not be inlined here, but i'll move it + # out once i get more info into the 'env' + def collect_required_thunks( + env: List[Tuple[str, NodeRef]], expr: Expr + ) -> Dict[str, str]: + names = dict(env) + print(f"collect_required_thunks({expr})", expr.__class__) + if isinstance(expr, ExprDot): + return collect_required_thunks(env, expr.left) + elif isinstance(expr, ExprMul): + a = collect_required_thunks(env, expr.left) + b = collect_required_thunks(env, expr.right) + a.update(b) + return a + elif isinstance(expr, ExprAdd): + a = collect_required_thunks(env, expr.left) + b = collect_required_thunks(env, expr.right) + a.update(b) + return a + elif isinstance(expr, ExprCall): + return collect_required_thunks(env, expr.func) + elif isinstance(expr, ExprName): + if expr.name not in names and expr.name not in builtins: + raise Exception(f"unbound name '{expr.name}'") + return dict() + raise Exception(f"unhandled {expr.__class__}") + + node_map = dict( + map(lambda n: (n.name, n), filter(lambda c: isinstance(c, Node), program)) + ) + node_name_map = dict( + map(lambda n: (n[0], gen(n[1].name.lower())), node_map.items()) + ) + + for node in node_map.values(): + node = cast(Node, node) + node_name_lower = node.name.lower() + node_name = node_name_map[node.name] + res.parse_rules[f"?{node_name_lower}"].append(node_name) + res.starts.add(node_name_lower) + + class_decl = textwrap.dedent( + f""" + class {v(node_name)}: pass + """ + ) + res.extra += class_decl + + print(node.name, node.ifaces) + + for variant in node.variants: + v_class_name = gen(f"{node_name}_var") + class_decl = textwrap.dedent( + f""" + class {v(v_class_name)}({v(node_name)}): + ''' ''' + pass + """ + ) + res.extra += class_decl + + prod_name = gen(node_name + "_") + res.parse_rules[node_name].append(prod_name) + print("PRODUCTION", prod_name, variant.prod) + + # resolving a production just means checking to make sure it's a + # type that exists or it's a regex + def resolve_production(sym: Sym) -> str: + print(f"resolve_production({sym})") + if isinstance(sym, SymRename): + if isinstance(sym.ty, NodeRefByName): + if sym.ty.name in node_name_map: + return node_name_map[sym.ty.name] + else: + raise Exception( + f"unresolved name {sym.ty.name} in production" + ) + elif isinstance(sym.ty, NodeRegex): + sym_name = gen("sym") + res.literals[sym_name] = f"/{sym.ty.pat.pattern}/" + return sym_name + elif isinstance(sym, SymLit): + sym_name = gen("lit") + # hack to make repr have double quotes + res.literals[sym_name] = json.dumps(sym.lit) + return sym_name + raise Exception(f"unhandled {sym.__class__}") + + seq = [] + for sym in variant.prod: + n = resolve_production(sym) + seq.append(n) + res.parse_rules[prod_name].append(" ".join(seq)) + + # create an environment for checking the equations based on + # the production + env: List[Tuple[str, NodeRef]] = list() + for sym in variant.prod: + if isinstance(sym, SymRename): + env.append((sym.name, sym.ty)) + print(env) + + # for each of the equations, find out what the equation is + # trying to compute, and generate a thunk corresponding to + # that value. + for eq in variant.equations: + eq_name = gen(f"eq_{node.name}") + thunk_name = gen(f"thunk_{node.name}") + + print("RHS", eq.rhs, eq.rhs.id) + collect_required_thunks(copy.deepcopy(env), eq.rhs) + + func_impl = textwrap.dedent( + f""" + def {eq_name}() -> None: + ''' {repr(eq)} ''' + pass + def {thunk_name}() -> Thunk[None]: + return Thunk({eq_name}) + """ + ) + print(f"```py\n{func_impl}\n```") + res.extra += func_impl + + # this is a "type alias" that connects it to one of the generated + # names above + res.extra += f"{node.name} = {v(node_name)}" + + return res diff --git a/agmain.py b/agmain.py index 19056d9..644e68b 100644 --- a/agmain.py +++ b/agmain.py @@ -23,7 +23,8 @@ if __name__ == "__main__": if not os.path.exists("gen"): os.makedirs("gen") with open("gen/arith.py", "w") as f: - fmt_str = textwrap.dedent(""" + fmt_str = textwrap.dedent( + """ __all__ = ["parse"] from typing import Generic, TypeVar, Optional, Callable, Dict, Any from lark import Lark, Transformer @@ -47,9 +48,11 @@ if __name__ == "__main__": {ex} def parse(input: str, start: Optional[str] = None) -> Any: return parser.parse(input, start) - """) - f.write(fmt_str.format(pd=res.parser_data, ex=res.extra, starts=list(res.starts))) + """ + ) + f.write( + fmt_str.format(pd=res.parser_data, ex=res.extra, starts=list(res.starts)) + ) mod = importlib.import_module("gen.arith") - print(mod.parse("1 + 2 * 3", start="expr")) # type: ignore - + print(mod.parse("1 + 2 * 3", start="expr")) # type: ignore diff --git a/let.ag b/let.ag index ad02dc0..f4f7a54 100644 --- a/let.ag +++ b/let.ag @@ -1,25 +1,25 @@ -iface HasEnv { - env: Map, -} - -iface HasVal { - val: str, -} - -alias Ident = /([a-zA-Z][a-zA-Z0-9_]*)|(_[a-zA-Z0-9_]+)/ - -node Expr : HasEnv + HasVal { - "let" "=" "in" => { - body.env = self.env.with(name, val); - self.val = body.val; - } - => { - // TODO: does env need to be referenced here? - // TODO: how to check for unbound names ahead of time - // (for self-implementation) - self.val = self.env.lookup(name); - } - => { - self.val = string; - } -} +iface HasEnv { + env: Map, +} + +iface HasVal { + val: str, +} + +alias Ident = /([a-zA-Z][a-zA-Z0-9_]*)|(_[a-zA-Z0-9_]+)/ + +node Expr : HasEnv + HasVal { + "let" "=" "in" => { + body.env = self.env.with(name, val); + self.val = body.val; + } + => { + // TODO: does env need to be referenced here? + // TODO: how to check for unbound names ahead of time + // (for self-implementation) + self.val = self.env.lookup(name); + } + => { + self.val = string; + } +} diff --git a/old.py b/old.py index cc7a5e3..df9c9f1 100644 --- a/old.py +++ b/old.py @@ -2,6 +2,7 @@ from typing import Generic, TypeVar, Optional, Callable T = TypeVar("T") + class Thunk(Generic[T]): value: Optional[T] @@ -14,22 +15,27 @@ class Thunk(Generic[T]): self.value = self.func() return self.value + class Node: value: Thunk[int] + class Add(Node): def __init__(self, left: Node, right: Node): self.value = Thunk(lambda: left.value.get() + right.value.get()) + class Mul(Node): def __init__(self, left: Node, right: Node): self.value = Thunk(lambda: left.value.get() * right.value.get()) + class Lit(Node): def __init__(self, num: int): self.num = num self.value = Thunk(lambda: num) + if __name__ == "__main__": tree = Add(Mul(Lit(3), Lit(4)), Lit(5)) print(tree.value.get()) diff --git a/run.ps1 b/run.ps1 index 4ceda14..df70aaa 100644 --- a/run.ps1 +++ b/run.ps1 @@ -1,16 +1,16 @@ -#blessed https://stackoverflow.com/a/52784160 - -function Invoke-Call { - param ( - [scriptblock]$ScriptBlock, - [string]$ErrorAction = $ErrorActionPreference - ) - & @ScriptBlock - if (($lastexitcode -ne 0) -and $ErrorAction -eq "Stop") { - exit $lastexitcode - } -} - -Invoke-Call -ScriptBlock {mypy (get-item *.py) } -ErrorAction Stop -Invoke-Call -ScriptBlock {python agmain.py } -ErrorAction Stop +#blessed https://stackoverflow.com/a/52784160 + +function Invoke-Call { + param ( + [scriptblock]$ScriptBlock, + [string]$ErrorAction = $ErrorActionPreference + ) + & @ScriptBlock + if (($lastexitcode -ne 0) -and $ErrorAction -eq "Stop") { + exit $lastexitcode + } +} + +Invoke-Call -ScriptBlock {mypy (get-item *.py) } -ErrorAction Stop +Invoke-Call -ScriptBlock {python agmain.py } -ErrorAction Stop Invoke-Call -ScriptBlock {mypy (get-item gen/*.py) } -ErrorAction Stop \ No newline at end of file diff --git a/watch.ps1 b/watch.ps1 index 60aac37..a73dcb6 100644 --- a/watch.ps1 +++ b/watch.ps1 @@ -1 +1 @@ -watchexec --shell=powershell -ce py,lark,ag -i gen './run.ps1' +watchexec --shell=powershell -ce py,lark,ag -i gen './run.ps1'