"""Lazily evaluated arithmetic expression trees.

Every node exposes a ``value`` thunk that computes the node's integer value
the first time it is requested and caches it for subsequent requests.
"""

from typing import Callable, Generic, Optional, TypeVar, cast

T = TypeVar("T")


class Thunk(Generic[T]):
    """A deferred computation whose result is computed at most once.

    ``func`` is not called until the first :meth:`get`. Its result is cached
    and returned by every later call to :meth:`get`.
    """

    value: Optional[T]  # cached result; stays None until the first successful get()
    func: Callable[[], T]

    def __init__(self, func: Callable[[], T]):
        self.func = func
        self.value = None
        # Track evaluation explicitly. Using ``value is None`` as the sentinel
        # would re-run ``func`` on every get() whenever it legitimately
        # returns None, defeating the memoization.
        self._evaluated = False

    def get(self) -> T:
        """Return the result of ``func``, calling it only on the first request.

        If ``func`` raises, nothing is cached and the next call retries it.
        """
        if not self._evaluated:
            self.value = self.func()
            # Set only after func() returns so a raised exception is not cached.
            self._evaluated = True
        return cast(T, self.value)


class Node:
    """Base class for expression nodes; subclasses assign ``value`` in ``__init__``."""

    value: Thunk[int]  # lazily computed integer value of this subtree


class Add(Node):
    """Node whose value is the sum of two sub-expressions."""

    def __init__(self, left: Node, right: Node):
        self.value = Thunk(lambda: left.value.get() + right.value.get())


class Mul(Node):
    """Node whose value is the product of two sub-expressions."""

    def __init__(self, left: Node, right: Node):
        self.value = Thunk(lambda: left.value.get() * right.value.get())


class Lit(Node):
    """Leaf node holding an integer literal."""

    def __init__(self, num: int):
        self.num = num
        self.value = Thunk(lambda: num)


if __name__ == "__main__":
    tree = Add(Mul(Lit(3), Lit(4)), Lit(5))
    print(tree.value.get())  # (3 * 4) + 5 == 17