// TypeScript source — 164 lines, 4 KiB.

/**
 * Types of the simply typed lambda calculus.
 * - `arrow`: the function type `left -> right`.
 * - `var`: a named type such as `Nat`; two named types are equal iff their
 *   names match.
 */
type Type =
  | { kind: "arrow"; left: Type; right: Type }
  | { kind: "var"; name: string };

/**
 * Terms of the simply typed lambda calculus.
 * - `abs`: the abstraction `λ name : ty. body`, with an annotated parameter type.
 * - `app`: application of `func` to `arg`.
 * - `var`: a reference to a variable by name.
 */
type Term =
  | { kind: "abs"; name: string; ty: Type; body: Term }
  | { kind: "app"; func: Term; arg: Term }
  | { kind: "var"; name: string };

/** Typing context: maps each in-scope variable name to its declared type. */
type TypeEnv = Map<string, Type>;

/** Runtime environment: maps variable names to terms. */
type TermEnv = Map<string, Term>;

/**
 * Capture-avoiding substitution: replaces every free occurrence of the
 * variable `name` in `term` with `repl`.
 *
 * When the substitution has to pass under a binder whose name occurs free in
 * `repl`, that binder is first renamed to a fresh name. This keeps `repl`'s
 * free variables from being captured. For example, `substitute(λy. x y, "x", y)`
 * gives `λy1. y y1`, not `λy. y y`. Binders are renamed only when capture would
 * actually happen, so all other inputs produce exactly the same structure as a
 * naive substitution.
 *
 * @param term - the term to substitute into
 * @param name - the variable being replaced
 * @param repl - the replacement term
 * @returns a new term; unchanged subterms may be shared with `term`
 */
function substitute(term: Term, name: string, repl: Term): Term {
  // `repl` is the same at every depth, so compute its free variables once.
  const replFree = freeVarsOfTerm(repl);

  const go = (t: Term): Term => {
    switch (t.kind) {
      case "var":
        return t.name === name ? repl : t;

      case "app":
        return { kind: "app", func: go(t.func), arg: go(t.arg) };

      case "abs": {
        // The binder shadows `name`, so there are no free occurrences below.
        if (t.name === name) return t;

        if (replFree.has(t.name)) {
          const bodyFree = freeVarsOfTerm(t.body);
          // Capture is only possible if `name` actually occurs in the body.
          if (bodyFree.has(name)) {
            const fresh = freshVarName(
              t.name,
              new Set([...replFree, ...bodyFree, name]),
            );
            const renamedBody = substitute(t.body, t.name, {
              kind: "var",
              name: fresh,
            });
            return { kind: "abs", ty: t.ty, name: fresh, body: go(renamedBody) };
          }
        }

        return { kind: "abs", ty: t.ty, name: t.name, body: go(t.body) };
      }
    }
  };

  return go(term);
}

/** Returns the set of variable names that occur free in `term`. */
function freeVarsOfTerm(term: Term): Set<string> {
  switch (term.kind) {
    case "var":
      return new Set([term.name]);
    case "app": {
      const free = freeVarsOfTerm(term.func);
      for (const v of freeVarsOfTerm(term.arg)) free.add(v);
      return free;
    }
    case "abs": {
      const free = freeVarsOfTerm(term.body);
      free.delete(term.name);
      return free;
    }
  }
}

/** Returns `base` suffixed with the smallest positive integer not in `taken`. */
function freshVarName(base: string, taken: ReadonlySet<string>): string {
  let i = 1;
  while (taken.has(`${base}${i}`)) i++;
  return `${base}${i}`;
}

/**
 * Structural equality on types.
 *
 * Named types are equal when their names match. Arrow types are equal when
 * both their domains and their codomains are equal. A named type never equals
 * an arrow type.
 */
export function tyEqual(t1: Type, t2: Type): boolean {
  switch (t1.kind) {
    case "var":
      return t2.kind === "var" && t1.name === t2.name;
    case "arrow":
      return (
        t2.kind === "arrow" &&
        tyEqual(t1.left, t2.left) &&
        tyEqual(t1.right, t2.right)
      );
  }
}

/**
 * Type-checks a term of the simply typed lambda calculus and returns its type.
 * Checking starts from an empty context, so the term must be closed.
 *
 * @throws Error("unknown name") if a variable is not bound by an enclosing abstraction
 * @throws Error("not a function") if an application's head is not of arrow type
 * @throws Error("arg type mismatch") if an argument's type differs from the parameter type
 */
export function inferType(term: Term): Type {
  const check = (ctx: TypeEnv, t: Term): Type => {
    if (t.kind === "var") {
      const bound = ctx.get(t.name);
      if (!bound) throw new Error("unknown name");
      return bound;
    }

    if (t.kind === "abs") {
      // Copy so the new binding is visible only inside this abstraction's
      // body and shadows any outer binding of the same name.
      const inner: TypeEnv = new Map(ctx);
      inner.set(t.name, t.ty);
      return { kind: "arrow", left: t.ty, right: check(inner, t.body) };
    }

    // Application: both sides are inferred before either is validated, so an
    // unbound name anywhere is reported ahead of a shape or type mismatch.
    const headTy = check(ctx, t.func);
    const argTy = check(ctx, t.arg);
    if (headTy.kind !== "arrow") throw new Error("not a function");
    if (!tyEqual(headTy.left, argTy)) throw new Error("arg type mismatch");
    return headTy.right;
  };

  return check(new Map(), term);
}

/**
 * Evaluates a closed term to a value (an abstraction) using call-by-value,
 * substitution-based reduction. Abstractions are values, so evaluation never
 * goes under a lambda.
 *
 * An application evaluates its head, then its argument, and beta-reduces. It
 * then continues evaluating the substituted body, because that body may itself
 * be a redex: `(λf. f id) (λg. g)` steps to `(λg. g) id` and then to `id`.
 *
 * The term is not type-checked here. Ill-typed terms may fail at runtime, and
 * non-terminating ones (e.g. `(λx. x x)(λx. x x)`) recurse until the stack
 * overflows.
 *
 * @throws Error("applying a non-function") if an application's head does not
 *   evaluate to an abstraction
 * @throws Error("unknown name") when a variable is evaluated. `env` is never
 *   extended in this function, so every variable reached is unbound.
 */
function evalTerm(term: Term): Term {
  function evalTermRec(env: TermEnv, term: Term): Term {
    switch (term.kind) {
      case "app": {
        const func = evalTermRec(env, term.func);

        // Require func is a function.
        if (func.kind !== "abs") throw new Error("applying a non-function");

        const arg = evalTermRec(env, term.arg);

        // Beta-reduce, then keep going: the result may not yet be a value.
        return evalTermRec(env, substitute(func.body, func.name, arg));
      }

      case "var": {
        const entry = env.get(term.name);
        if (entry === undefined) throw new Error("unknown name");
        return entry;
      }

      case "abs":
        return term;
    }
  }

  return evalTermRec(new Map(), term);
}

// Printing

/**
 * Renders a term as fully parenthesized text.
 *
 * Abstractions print as `(λ x : T -> body)`, applications as `(f a)`, and
 * variables as their bare name.
 */
export function prettyTerm(term: Term): string {
  if (term.kind === "var") {
    return term.name;
  }

  if (term.kind === "app") {
    const head = prettyTerm(term.func);
    const argument = prettyTerm(term.arg);
    return `(${head} ${argument})`;
  }

  const binderType = prettyType(term.ty);
  const body = prettyTerm(term.body);
  return `(λ ${term.name} : ${binderType} -> ${body})`;
}

/**
 * Renders a type as text. Every arrow type is wrapped in parentheses,
 * e.g. `(A -> (B -> C))`; named types print as their name.
 */
export function prettyType(ty: Type): string {
  if (ty.kind === "var") return ty.name;

  const lhs = prettyType(ty.left);
  const rhs = prettyType(ty.right);
  return `(${lhs} -> ${rhs})`;
}

// Convenience constructors

/** Builds a named type such as `Nat`. */
function tyvar(name: string): Type {
  const named: Type = { kind: "var", name: name };
  return named;
}

/** Builds the function type `left -> right`. */
function tyarr(left: Type, right: Type): Type {
  const arrow: Type = { kind: "arrow", left: left, right: right };
  return arrow;
}

/** Builds a variable-reference term. */
function tvar(name: string): Term {
  const reference: Term = { kind: "var", name: name };
  return reference;
}

/** Builds the abstraction `λ name : ty. body`. */
function tabs(name: string, ty: Type, body: Term): Term {
  const lambda: Term = { kind: "abs", name: name, ty: ty, body: body };
  return lambda;
}

/** Builds the application of `func` to `arg`. */
function tapp(func: Term, arg: Term): Term {
  const application: Term = { kind: "app", func: func, arg: arg };
  return application;
}

// Examples

/** The named base type used by the Church-numeral examples below. */
const Nat: Type = { kind: "var", name: "Nat" };

/**
 * The type of a Church numeral over `Nat`: `(Nat -> Nat) -> (Nat -> Nat)`.
 * A fresh object is built on every call.
 */
export function churchNumeralTy(): Type {
  const successorTy = tyarr(Nat, Nat);
  const resultTy = tyarr(Nat, Nat);
  return tyarr(successorTy, resultTy);
}

/**
 * Encodes a non-negative integer as a Church numeral.
 *
 * The result is `λ s : Nat -> Nat. λ z : Nat. s (s (... (s z)))`, with exactly
 * `num` applications of `s`. In particular, 0 encodes as `λ s. λ z. z`. The
 * result has type `churchNumeralTy()`.
 *
 * @param num - the natural number to encode
 * @throws Error("not an integer") if `num` is not an integer (including NaN and ±Infinity)
 * @throws Error("not a natural number") if `num` is negative
 */
export function churchEncode(num: number): Term {
  if (!Number.isInteger(num)) throw new Error("not an integer");
  // Church numerals only encode naturals; a negative count has no meaning.
  if (num < 0) throw new Error("not a natural number");

  let body = tvar("z");
  // Wrap `z` in exactly `num` applications of `s`.
  for (let i = 0; i < num; i++) {
    body = tapp(tvar("s"), body);
  }
  return tabs("s", tyarr(Nat, Nat), tabs("z", Nat, body));
}
