diff --git a/aggen.py b/aggen.py index 49c8160..c6aa791 100644 --- a/aggen.py +++ b/aggen.py @@ -4,7 +4,6 @@ import re import copy import json from collections import defaultdict -from re import Pattern from agast import * @@ -113,13 +112,13 @@ def gen(program: List[Decl]) -> GenResult: print("what_ifaces:", what_ifaces) print("what_fields:", what_fields) - # a high-level dictionary of productions; this has sub-productions - # that should be further expanded at a later step before converting - # into lark code + # a high-level dictionary of productions; this has sub-productions that + # should be further expanded at a later step before converting into lark + # code productions_hi: Dict[str, Union[str, List[str]]] = dict() - # TODO: this should probably not be inlined here, but i'll move it - # out once i get more info into the 'env' + # TODO: this should probably not be inlined here, but i'll move it out once + # i get more info into the 'env' def collect_required_thunks( env: List[Tuple[str, NodeRef]], expr: Expr ) -> Dict[str, str]: @@ -210,17 +209,16 @@ def gen(program: List[Decl]) -> GenResult: ParseEquation(prod_name, seq, v_class_name) ) - # create an environment for checking the equations based on - # the production + # create an environment for checking the equations based on the + # production env: List[Tuple[str, NodeRef]] = list() for sym in variant.prod: if isinstance(sym, SymRename): env.append((sym.name, sym.ty)) print(env) - # for each of the equations, find out what the equation is - # trying to compute, and generate a thunk corresponding to - # that value. + # for each of the equations, find out what the equation is trying + # to compute, and generate a thunk corresponding to that value. for eq in variant.equations: eq_name = gensym(f"eq_{node_desc.name}") thunk_name = gensym(f"thunk_{node_desc.name}") diff --git a/agmain.py b/agmain.py index cd49b2e..3d0a453 100644 --- a/agmain.py +++ b/agmain.py @@ -25,46 +25,7 @@ if __name__ == "__main__": if not os.path.exists("gen"): os.makedirs("gen") with open("gen/arith.py", "w") as f: - fmt_str = textwrap.dedent( - """ - # This document is generated by agtest. - - __all__ = ["parse"] - from typing import Generic, TypeVar, Optional, Callable, Dict, Any - from lark import Lark, Transformer - T = TypeVar('T') - builtins: Dict[str, Any] = {{ - "parseInt": lambda s: int(s) - }} - - class Thunk(Generic[T]): - ''' A thunk represents a value that may be computed lazily. ''' - value: Optional[T] - def __init__(self, func: Callable[[], T]): - self.func = func - self.value = None - def get(self) -> T: - if self.value is None: - self.value = self.func() - return self.value - - parser = Lark(''' - {pd} - ''', parser='lalr', start={starts}, debug=True) - - {ex} - - class Trans(Transformer[None]): {transdef} - - __agNonterminals = {ntmap} - def parse(input: str, start: Optional[str] = None) -> Any: - if start is not None: start = __agNonterminals[start] - tree = parser.parse(input, start) - trans = Trans() - res = trans.transform(tree) - return res - """ - ) + fmt_str = open("agruntime.tmpl.py", "r").read() f.write( fmt_str.format( pd=res.parser_data, diff --git a/agruntime.tmpl.py b/agruntime.tmpl.py new file mode 100644 index 0000000..f2c72f2 --- /dev/null +++ b/agruntime.tmpl.py @@ -0,0 +1,39 @@ +# This document is generated by agtest. +# type: ignore + +__all__ = ["parse"] +import re +from typing import Generic, TypeVar, Optional, Callable, Dict, Any +from lark import Lark, Transformer + +T = TypeVar('T') +builtins: Dict[str, Any] = {{ + "parseInt": lambda s: int(s) +}} + +class Thunk(Generic[T]): + ''' A thunk represents a value that may be computed lazily. ''' + value: Optional[T] + def __init__(self, func: Callable[[], T]): + self.func = func + self.value = None + def get(self) -> T: + if self.value is None: + self.value = self.func() + return self.value + +parser = Lark(''' +{pd} +''', parser='lalr', start={starts}, debug=True) + +{ex} + +class Trans(Transformer[None]): {transdef} + +__agNonterminals = {ntmap} +def parse(input: str, start: Optional[str] = None) -> Any: + if start is not None: start = __agNonterminals[start] + tree = parser.parse(input, start) + trans = Trans() + res = trans.transform(tree) + return res