import("util.lua")
-- Constants and metavariables shared by every check in this file.
local f, g, a, b, c, x = Consts("f, g, a, b, c, x")
local m1 = mk_metavar("m1")
local m2 = mk_metavar("m2")
local m3 = mk_metavar("m3")

-- Solve f(?m1, g(?m2, c)) =?= f(g(?m2, a), g(?m3, ?m3)).
-- The checks below expect exactly three bindings, with ?m2 resolving to c,
-- and the resulting substitution must make both sides identical.
local s
do
   local lhs = f(m1, g(m2, c))
   local rhs = f(g(m2, a), g(m3, m3))
   s = fo_unify(lhs, rhs)
   assert(s)
   assert(#s == 3)
   assert(s:find(m2) == c)
   assert(s:apply(lhs) == s:apply(rhs))
end

-- Applications headed by different constants must not unify.
assert(not fo_unify(f(a), g(m2)))
--- Assert that two terms are first-order unifiable.
-- Calls `fo_unify(t1, t2)`, prints both terms and the instantiated `t1`,
-- and checks that applying the substitution to each side yields equal terms.
-- Kept global (as in the original) so the interface is unchanged.
-- @param t1 first term
-- @param t2 second term
-- @return the substitution produced by `fo_unify`
-- @raise error (reported at the caller's line) if unification fails or the
--        substitution does not make the two terms equal
function must_unify(t1, t2)
   local s = fo_unify(t1, t2)
   if not s then
      -- Level 2 blames the call site, so the failing check is identifiable.
      error(string.format("fo_unify failed to unify %s and %s",
                          tostring(t1), tostring(t2)), 2)
   end
   local r1 = s:apply(t1)
   local r2 = s:apply(t2)
   print(t1, t2, r1)
   if r1 ~= r2 then
      error(string.format("substitution did not equate %s and %s: got %s vs %s",
                          tostring(t1), tostring(t2),
                          tostring(r1), tostring(r2)), 2)
   end
   return s
end
-- File-local constant (was an accidental global).
local Bool = Const("Bool")

-- A metavariable on one side must unify with these terms.
must_unify(Type(), m1)
must_unify(Var(0), m1)

-- Binders (lambda and Pi) unify when their bodies unify.
must_unify(fun(x, Bool, x), fun(x, Bool, m1))
must_unify(Pi(x, Bool, x), Pi(x, Bool, m1))

-- Repeated metavariables across argument positions.
must_unify(f(m1, m2, m3), f(m1, m1, m2))

-- Let-expressions: the binder names ("x" vs "y") differ in every case below.
-- The 3-argument and 4-argument mk_let forms (the latter presumably carrying
-- an explicit type, Bool here) must not unify with each other, but two
-- 4-argument lets with the same extra argument must.
must_unify(mk_let("x", f(b), Var(0)), mk_let("y", f(m1), Var(0)))
assert(not fo_unify(mk_let("x", Bool, f(b), Var(0)), mk_let("y", f(m1), Var(0))))
assert(not fo_unify(mk_let("x", f(b), Var(0)), mk_let("y", Bool, f(m1), Var(0))))
must_unify(mk_let("x", Bool, f(b), Var(0)), mk_let("y", Bool, f(m1), Var(0)))

-- A let-expression never unifies with a lambda.
assert(not fo_unify(mk_let("x", Bool, f(b), Var(0)), fun(x, Bool, x)))

-- Literal values.
must_unify(iVal(10), m1)
must_unify(iVal(10), iVal(10))