diff --git a/src/kernel/abstract.h b/src/kernel/abstract.h index 036ce91a2..b02cdd750 100644 --- a/src/kernel/abstract.h +++ b/src/kernel/abstract.h @@ -27,4 +27,15 @@ inline expr abstract(expr const & s, expr const & e) { return abstract(1, &s, e) */ expr abstract_p(unsigned n, expr const * s, expr const & e); inline expr abstract_p(expr const & s, expr const & e) { return abstract_p(1, &s, e); } + +/** + \brief Create a lambda expression (lambda (x : t) b), the term b is abstracted using abstract(constant(x), b). +*/ +inline expr fun(name const & n, expr const & t, expr const & b) { return lambda(n, t, abstract(constant(n), b)); } +inline expr fun(char const * n, expr const & t, expr const & b) { return fun(name(n), t, b); } +/** + \brief Create a Pi expression (pi (x : t) b), the term b is abstracted using abstract(constant(x), b). +*/ +inline expr Fun(name const & n, expr const & t, expr const & b) { return pi(n, t, abstract(constant(n), b)); } +inline expr Fun(char const * n, expr const & t, expr const & b) { return Fun(name(n), t, b); } } diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 3a167719c..c4ff8c221 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -151,6 +151,10 @@ bool operator==(expr const & a, expr const & b) { return eq_fn()(a, b); } +bool is_arrow(expr const & t) { + return is_pi(t) && !has_free_var(abst_body(t), 0); +} + // Low-level pretty printer std::ostream & operator<<(std::ostream & out, expr const & a) { switch (a.kind()) { @@ -166,10 +170,12 @@ std::ostream & operator<<(std::ostream & out, expr const & a) { break; case expr_kind::Lambda: out << "(fun (" << abst_name(a) << " : " << abst_type(a) << ") " << abst_body(a) << ")"; break; case expr_kind::Pi: - if (has_free_var(abst_body(a), 0)) + if (!is_arrow(a)) out << "(pi (" << abst_name(a) << " : " << abst_type(a) << ") " << abst_body(a) << ")"; - else + else if (!is_arrow(abst_type(a))) out << abst_type(a) << " -> " << abst_body(a); + else + out << "(" << abst_type(a) << ") -> " << abst_body(a); break; case expr_kind::Type: { level const & l = ty_level(a); diff --git a/src/kernel/normalize.cpp b/src/kernel/normalize.cpp index 6c7e8e0b4..56d8a60cb 100644 --- a/src/kernel/normalize.cpp +++ b/src/kernel/normalize.cpp @@ -160,9 +160,18 @@ public: unsigned k = length(m_ctx); return reify(normalize(e, stack(), k), k); } + + expr operator()(expr const & e, expr const & v) { + unsigned k = length(m_ctx); + stack s = extend(stack(), normalize(v, stack(), k)); + return reify(normalize(e, s, k+1), k+1); + } }; expr normalize(expr const & e, environment const & env, context const & ctx) { return normalize_fn(env, ctx)(e); } +expr normalize(expr const & e, environment const & env, context const & ctx, expr const & v) { + return normalize_fn(env, ctx)(e, v); +} } diff --git a/src/kernel/normalize.h b/src/kernel/normalize.h index f2e9397b3..597b60f3b 100644 --- a/src/kernel/normalize.h +++ b/src/kernel/normalize.h @@ -11,5 +11,8 @@ Author: Leonardo de Moura namespace lean { class environment; +/** \brief Normalize e using the environment env and context ctx */ expr normalize(expr const & e, environment const & env, context const & ctx = context()); +/** \brief Normalize e using the environment env, context ctx, and add v to "normalization stack" */ +expr normalize(expr const & e, environment const & env, context const & ctx, expr const & v); } diff --git a/src/kernel/type_check.cpp b/src/kernel/type_check.cpp index d890d3937..49774224b 100644 --- a/src/kernel/type_check.cpp +++ b/src/kernel/type_check.cpp @@ -19,9 +19,15 @@ class infer_type_fn { return ::lean::normalize(e, m_env, ctx); } + expr normalize(expr const & e, context const & ctx, expr const & v) { + return ::lean::normalize(e, m_env, ctx, v); + } + expr lookup(context const & c, unsigned i) { context const & def_c = ::lean::lookup(c, i); - return lift_free_vars(head(def_c).get_type(), length(c) - length(def_c)); + lean_assert(length(c) >= length(def_c)); + lean_assert(length(def_c) > 0); + return lift_free_vars(head(def_c).get_type(), length(c) - (length(def_c) - 1)); } level infer_universe(expr const & t, context const & ctx) { @@ -85,7 +91,10 @@ class infer_type_fn { expr const & c = arg(e, i); expr c_t = infer_type(c, ctx); check_type(e, i, abst_type(f_t), c_t, ctx); - f_t = normalize(abst_body(f_t), extend(ctx, abst_name(f_t), c_t)); + // Remark: if f_t is an arrow, we don't have to call normalize and + // lower_free_vars + f_t = normalize(abst_body(f_t), ctx, c); + f_t = lower_free_vars(f_t, 1); i++; if (i == num) return f_t; diff --git a/src/tests/kernel/level.cpp b/src/tests/kernel/level.cpp index d4def39dd..3e6fe2a06 100644 --- a/src/tests/kernel/level.cpp +++ b/src/tests/kernel/level.cpp @@ -90,7 +90,7 @@ static void tst5() { } int main() { - // continue_on_violation(true); + continue_on_violation(true); tst1(); tst2(); tst3(); diff --git a/src/tests/kernel/type_check.cpp b/src/tests/kernel/type_check.cpp index 06cb7ae7e..397d3c88f 100644 --- a/src/tests/kernel/type_check.cpp +++ b/src/tests/kernel/type_check.cpp @@ -6,18 +6,40 @@ Author: Leonardo de Moura */ #include #include "type_check.h" +#include "abstract.h" #include "exception.h" #include "trace.h" #include "test.h" using namespace lean; +expr c(char const * n) { return constant(n); } + static void tst1() { environment env; expr t0 = type(level()); std::cout << infer_type(t0, env) << "\n"; lean_assert(infer_type(t0, env) == type(level()+1)); - expr t1 = pi("_", t0, t0); - std::cout << infer_type(t1, env) << "\n"; + expr f = pi("_", t0, t0); + std::cout << infer_type(f, env) << "\n"; + lean_assert(infer_type(f, env) == type(level()+1)); + level u = env.define_uvar("u", level() + 1); + level v = env.define_uvar("v", level() + 1); + expr g = pi("_", type(u), type(v)); + std::cout << infer_type(g, env) << "\n"; + lean_assert(infer_type(g, env) == type(max(u+1, v+1))); + std::cout << infer_type(type(u), env) << "\n"; + lean_assert(infer_type(type(u), env) == type(u+1)); + std::cout << infer_type(lambda("x", type(u), var(0)), env) << "\n"; + lean_assert(infer_type(lambda("x", type(u), var(0)), env) == pi("_", type(u), type(u))); + std::cout << infer_type(lambda("Nat", type(level()), lambda("n", var(0), var(0))), env) << "\n"; + expr nat = c("nat"); + expr T = fun("nat", type(level()), + fun("+", arrow(nat, arrow(nat, nat)), + fun("m", nat, app(c("+"), c("m"), c("m"))))); + std::cout << T << "\n"; + std::cout << infer_type(T, env) << "\n"; + std::cout << Fun("nat", type(level()), arrow(arrow(nat, arrow(nat, nat)), arrow(nat, nat))) << "\n"; + lean_assert(infer_type(T, env) == Fun("nat", type(level()), arrow(arrow(nat, arrow(nat, nat)), arrow(nat, nat)))); } static void tst2() { @@ -27,12 +49,12 @@ static void tst2() { expr t0 = type(level()); expr t1 = type(l1); expr F = - lambda("Nat", t0, - lambda("Vec", arrow(var(0), t0), - lambda("n", var(1), - lambda("len", arrow(app(var(1), var(0)), var(2)), - lambda("v", app(var(2), var(1)), - app(var(1), var(0))))))); + fun("Nat", t0, + fun("Vec", arrow(c("Nat"), t0), + fun("n", c("Nat"), + fun("len", arrow(app(c("Vec"), c("n")), c("Nat")), + fun("v", app(c("Vec"), c("n")), + app(c("len"), c("v"))))))); std::cout << F << "\n"; std::cout << infer_type(F, env) << "\n"; } @@ -42,10 +64,9 @@ static void tst2() { } int main() { - // continue_on_violation(true); + continue_on_violation(true); enable_trace("type_check"); tst1(); - return 0; tst2(); return has_violations() ? 1 : 0; }