from typing import *
import re
from re import Pattern

from lark import Transformer, Tree, Token


def unescape(s: str) -> str:
    """Strip the delimiters from a quoted literal and resolve its escapes.

    The first character of ``s`` is taken as the delimiter. Scanning starts
    right after it and stops at the first *unescaped* occurrence of that
    delimiter (or at the end of the string if there is none).

    Recognised escapes: backslash + delimiter -> delimiter, ``\\n`` -> newline,
    ``\\t`` -> tab. Any other escaped character is emitted as-is with the
    backslash dropped (so an escaped backslash yields a single backslash).
    A lone trailing backslash is dropped.

    Args:
        s: The raw literal, including its opening delimiter.

    Returns:
        The unescaped content; ``""`` when ``s`` is empty.
    """
    if not s:
        return ""

    quote = s[0]
    simple_escapes = {"n": "\n", "t": "\t"}
    out: List[str] = []
    escaped = False

    for c in s[1:]:
        if escaped:
            # Consume exactly one character per backslash, then go back to
            # normal scanning. The delimiter check comes first so an escaped
            # delimiter always maps to itself.
            escaped = False
            out.append(c if c == quote else simple_escapes.get(c, c))
            continue
        if c == quote:
            break
        if c == "\\":
            escaped = True
            continue
        out.append(c)

    return "".join(out)


T = TypeVar("T")


class Ast:
    """Base class for AST objects that carry a generated ``id``."""

    # def __new__(cls: Type[Ast], name: str, bases: Tuple[type], namespace: Dict[str, Any]) -> Ast:
    #     x = super().__new__(cls, name, bases, namespace)
    #     x.id = cls.__gen()
    #     return x

    id: str
    # Counter backing the generated ids; shared by Ast and all subclasses.
    n = 0

    @classmethod
    def __gen(cls, name: str = "") -> str:
        """Return a fresh id of the form ``_a<counter><name>``."""
        # Read and bump the counter on Ast itself. `cls.n += 1` would create
        # a separate `n` on whichever subclass is being instantiated, making
        # every subclass restart from 0 and hand out colliding ids.
        newid = Ast.n
        Ast.n += 1
        return f"_a{newid}{name}"

    def __init__(self) -> None:
        self.id = self.__gen()


class Decl:
    """Base class for named top-level declarations."""

    name: str


class IfaceRef(str):
    """String subclass used in annotations for references to an interface."""

    pass


class IfaceField:
    """A field of an interface: a name paired with a type name."""

    def __init__(self, name: str, ty: str):
        self.name = name
        self.ty = ty


class Iface(Decl):
    """An interface declaration: a name and its fields."""

    def __init__(self, name: str, fields: List[IfaceField]):
        self.name = name
        self.fields = fields


class Expr(Ast):
    """Base class for expression nodes."""

    def __init__(self) -> None:
        super().__init__()


class NodeRef:
    """Base class for the ways a production symbol can refer to a node."""

    pass


class NodeRefByName(NodeRef, str):
    """Node reference by name; also a ``str`` holding that name."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return f"NodeRefByName({self.name})"


class NodeRegex(NodeRef):
    """Node reference given as a quoted regex literal, compiled on creation."""

    def __init__(self, pat: str):
        # `pat` is the raw delimited literal; unescape() strips the delimiters.
        self.pat = re.compile(unescape(pat))

    def __repr__(self) -> str:
        return f"NodeRegex({self.pat.pattern})"


class Sym:
    """Base class for symbols appearing in a production."""

    pass


class SymLit(Sym):
    """Literal symbol; stores the unescaped text of the quoted literal."""

    def __init__(self, s: str):
        self.lit = unescape(s)

    def __repr__(self) -> str:
        return f"SymLit({repr(self.lit)})"


class SymRename(Sym):
    """Symbol binding a local name to a node reference."""

    def __init__(self, name: str, ty: NodeRef):
        self.name = name
        self.ty = ty

    def __repr__(self) -> str:
        return f"SymRename({self.name} : {self.ty})"


class Equation:
    """An equation ``lhs = rhs`` between two expressions."""

    def __init__(self, lhs: Expr, rhs: Expr):
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self) -> str:
        return f"{self.lhs} = {self.rhs}"


class Variant:
    """One variant of a node: a production plus its equations."""

    def __init__(self, prod: List[Sym], equations: List[Equation]):
        self.prod = prod
        self.equations = equations

    def __repr__(self) -> str:
        return f"Variant({self.prod}, {self.equations})"


class Node(Decl):
    """A node declaration: name, referenced interfaces, and variants."""

    def __init__(self, name: str, ifaces: List[IfaceRef], variants: List[Variant]):
        self.name = name
        self.ifaces = ifaces
        self.variants = variants


class ExprDot(Expr):
    """Attribute access ``left.right``."""

    def __init__(self, left: Expr, right: str):
        super().__init__()
        self.left = left
        self.right = right

    def __repr__(self) -> str:
        return f"{self.left}.{self.right}"


class ExprAdd(Expr):
    """Addition ``left + right``."""

    def __init__(self, left: Expr, right: Expr):
        super().__init__()
        self.left = left
        self.right = right

    def __repr__(self) -> str:
        return f"{self.left} + {self.right}"


class ExprMul(Expr):
    """Multiplication ``left * right``."""

    def __init__(self, left: Expr, right: Expr):
        super().__init__()
        self.left = left
        self.right = right

    def __repr__(self) -> str:
        return f"{self.left} * {self.right}"


class ExprCall(Expr):
    """Call expression ``func(args)``."""

    def __init__(self, func: Expr, args: List[Expr]):
        super().__init__()
        self.func = func
        self.args = args

    def __repr__(self) -> str:
        return f"{self.func}({self.args})"


class ExprName(Expr):
    """A bare name used as an expression."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self) -> str:
        return f"{self.name}"


class Parser(Transformer[List[Decl]]):
    """lark ``Transformer`` that builds the AST classes above.

    Each method receives the already-transformed children of one tree node
    as ``items``. The method names presumably match rule names/aliases in
    the grammar, which is not part of this file.
    """

    def program(self, items: List[Decl]) -> List[Decl]:
        """Return the list of declarations unchanged."""
        return items

    # interfaces

    def iface(self, items: List[Any]) -> Iface:
        """Build an ``Iface`` from ``[name, fields]``."""
        [name, fields] = items
        return Iface(name, fields)

    def iface_field(self, items: List[str]) -> IfaceField:
        """Build an ``IfaceField`` from ``[name, ty]``."""
        [name, ty] = items
        return IfaceField(name, ty)

    def iface_ref(self, items: List[str]) -> str:
        """Return the single child (the referenced interface name)."""
        return items[0]

    def iface_refs(self, items: List[IfaceRef]) -> List[IfaceRef]:
        """Return the list of interface references unchanged."""
        return items

    # nodes

    def node(self, items: List[Any]) -> Node:
        """Build a ``Node`` from ``[name, ifaces, variants]``."""
        [name, ifaces, variants] = items
        return Node(name, ifaces, variants)

    def node_ref_name(self, items: List[str]) -> NodeRefByName:
        """Wrap the first child in a ``NodeRefByName``."""
        return NodeRefByName(items[0])

    def node_regex(self, items: List[str]) -> NodeRegex:
        """Wrap the first child (a quoted regex literal) in a ``NodeRegex``."""
        return NodeRegex(items[0])

    # variants

    def variants(self, items: List[Variant]) -> List[Variant]:
        """Return the list of variants unchanged."""
        return items

    def variant(self, items: List[Any]) -> Variant:
        """Build a ``Variant`` from ``[prod, equations]``."""
        [prod, equations] = items
        return Variant(prod, equations)

    def prod(self, items: List[Sym]) -> List[Sym]:
        """Return the production's symbols unchanged."""
        return items

    # symbols in productions

    def sym_lit(self, items: List[str]) -> Sym:
        """Wrap the first child (a quoted literal) in a ``SymLit``."""
        return SymLit(items[0])

    def sym_rename(self, items: List[Any]) -> Sym:
        """Build a ``SymRename`` from the first two children (name, ty)."""
        return SymRename(items[0], items[1])

    # equations

    def equations(self, items: List[Equation]) -> List[Equation]:
        """Return the list of equations unchanged."""
        return items

    def equation_semi(self, items: List[Equation]) -> Equation:
        """Return the first child (the wrapped equation)."""
        return items[0]

    def equation(self, items: List[Expr]) -> Equation:
        """Build an ``Equation`` from the first two children (lhs, rhs)."""
        return Equation(items[0], items[1])

    # expr

    def expr_dot(self, items: List[Any]) -> Expr:
        """Build an ``ExprDot`` from ``[left, right]``."""
        [left, right] = items
        return ExprDot(left, right)

    def expr_add(self, items: List[Expr]) -> Expr:
        """Build an ``ExprAdd`` from ``[left, right]``."""
        [left, right] = items
        return ExprAdd(left, right)

    def expr_mul(self, items: List[Expr]) -> Expr:
        """Build an ``ExprMul`` from ``[left, right]``."""
        [left, right] = items
        return ExprMul(left, right)

    def expr_call(self, items: List[Expr]) -> Expr:
        """Build an ``ExprCall`` from ``[func, args]``.

        The second child is wrapped in a one-element list as-is.
        """
        [func, args] = items
        # TODO: args should be a list of exprs -_ -
        return ExprCall(func, [args])

    def expr_name(self, items: List[str]) -> Expr:
        """Wrap the first child in an ``ExprName``."""
        return ExprName(items[0])

    def sep_trail(self, items: List[Tree]) -> List[T]:
        """Return the children as a new list (``cast`` only affects typing)."""
        return [cast(T, it) for it in items]

    def ident(self, items: List[Token]) -> str:
        """Return the ``value`` of the first token."""
        return cast(str, items[0].value)