diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 8b0db49e2..dfad99901 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -196,6 +196,71 @@ justification app_delayed_justification::get() { return *m_jst; } +expr type_checker::infer_constant(expr const & e, bool infer_only) { + declaration d = m_env.get(const_name(e)); + auto const & ps = d.get_univ_params(); + auto const & ls = const_levels(e); + if (length(ps) != length(ls)) + throw_kernel_exception(m_env, sstream() << "incorrect number of universe levels parameters for '" << const_name(e) << "', #" + << length(ps) << " expected, #" << length(ls) << " provided"); + if (!infer_only) { + for (level const & l : ls) + check_level(l, e); + } + return instantiate_univ_params(d.get_type(), ps, ls); +} + +expr type_checker::infer_macro(expr const & e, bool infer_only) { + buffer arg_types; + for (unsigned i = 0; i < macro_num_args(e); i++) + arg_types.push_back(infer_type_core(macro_arg(e, i), infer_only)); + expr r = macro_def(e).get_type(e, arg_types.data(), m_tc_ctx); + if (!infer_only && macro_def(e).trust_level() >= m_env.trust_lvl()) { + optional m = expand_macro(e); + if (!m) + throw_kernel_exception(m_env, "failed to expand macro", e); + expr t = infer_type_core(*m, infer_only); + simple_delayed_justification jst([=]() { return mk_macro_jst(e); }); + if (!is_def_eq(r, t, jst)) + throw_kernel_exception(m_env, g_macro_error_msg, e); + } + return r; +} + +expr type_checker::infer_lambda(expr const & e, bool infer_only) { + if (!infer_only) { + expr t = infer_type_core(binding_domain(e), infer_only); + ensure_sort_core(t, binding_domain(e)); + } + auto b = open_binding_body(e); + return mk_pi(binding_name(e), binding_domain(e), abstract_local(infer_type_core(b.first, infer_only), b.second), binding_info(e)); +} + +expr type_checker::infer_pi(expr const & e, bool infer_only) { + expr t1 = ensure_sort_core(infer_type_core(binding_domain(e), infer_only), binding_domain(e)); + auto b = open_binding_body(e); + expr t2 = ensure_sort_core(infer_type_core(b.first, infer_only), binding_body(e)); + if (m_env.impredicative()) + return mk_sort(mk_imax(sort_level(t1), sort_level(t2))); + else + return mk_sort(mk_max(sort_level(t1), sort_level(t2))); +} + +expr type_checker::infer_app(expr const & e, bool infer_only) { + expr f_type = ensure_pi_core(infer_type_core(app_fn(e), infer_only), app_fn(e)); + if (!infer_only) { + expr a_type = infer_type_core(app_arg(e), infer_only); + app_delayed_justification jst(e, f_type, a_type); + if (!is_def_eq(a_type, binding_domain(f_type), jst)) { + throw_kernel_exception(m_env, app_arg(e), + [=](formatter const & fmt) { + return pp_app_type_mismatch(fmt, e, binding_domain(f_type), a_type); + }); + } + } + return instantiate(binding_body(f_type), app_arg(e)); +} + /** \brief Return type of expression \c e, if \c infer_only is false, then it also check whether \c e is type correct or not. @@ -216,80 +281,19 @@ expr type_checker::infer_type_core(expr const & e, bool infer_only) { expr r; switch (e.kind()) { - case expr_kind::Local: case expr_kind::Meta: - r = mlocal_type(e); - break; + case expr_kind::Local: case expr_kind::Meta: r = mlocal_type(e); break; case expr_kind::Var: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Sort: - if (!infer_only) - check_level(sort_level(e), e); + if (!infer_only) check_level(sort_level(e), e); r = mk_sort(mk_succ(sort_level(e))); break; - case expr_kind::Constant: { - declaration d = m_env.get(const_name(e)); - auto const & ps = d.get_univ_params(); - auto const & ls = const_levels(e); - if (length(ps) != length(ls)) - throw_kernel_exception(m_env, sstream() << "incorrect number of universe levels parameters for '" << const_name(e) << "', #" - << length(ps) << " expected, #" << length(ls) << " provided"); - if (!infer_only) { - for (level const & l : ls) - check_level(l, e); - } - r = instantiate_univ_params(d.get_type(), ps, ls); - break; + case expr_kind::Constant: r = infer_constant(e, infer_only); break; + case expr_kind::Macro: r = infer_macro(e, infer_only); break; + case expr_kind::Lambda: r = infer_lambda(e, infer_only); break; + case expr_kind::Pi: r = infer_pi(e, infer_only); break; + case expr_kind::App: r = infer_app(e, infer_only); break; } - case expr_kind::Macro: { - buffer arg_types; - for (unsigned i = 0; i < macro_num_args(e); i++) - arg_types.push_back(infer_type_core(macro_arg(e, i), infer_only)); - r = macro_def(e).get_type(e, arg_types.data(), m_tc_ctx); - if (!infer_only && macro_def(e).trust_level() >= m_env.trust_lvl()) { - optional m = expand_macro(e); - if (!m) - throw_kernel_exception(m_env, "failed to expand macro", e); - expr t = infer_type_core(*m, infer_only); - simple_delayed_justification jst([=]() { return mk_macro_jst(e); }); - if (!is_def_eq(r, t, jst)) - throw_kernel_exception(m_env, g_macro_error_msg, e); - } - break; - } - case expr_kind::Lambda: { - if (!infer_only) { - expr t = infer_type_core(binding_domain(e), infer_only); - ensure_sort_core(t, binding_domain(e)); - } - auto b = open_binding_body(e); - r = mk_pi(binding_name(e), binding_domain(e), abstract_local(infer_type_core(b.first, infer_only), b.second), binding_info(e)); - break; - } - case expr_kind::Pi: { - expr t1 = ensure_sort_core(infer_type_core(binding_domain(e), infer_only), binding_domain(e)); - auto b = open_binding_body(e); - expr t2 = ensure_sort_core(infer_type_core(b.first, infer_only), binding_body(e)); - if (m_env.impredicative()) - r = mk_sort(mk_imax(sort_level(t1), sort_level(t2))); - else - r = mk_sort(mk_max(sort_level(t1), sort_level(t2))); - break; - } - case expr_kind::App: { - expr f_type = ensure_pi_core(infer_type_core(app_fn(e), infer_only), app_fn(e)); - if (!infer_only) { - expr a_type = infer_type_core(app_arg(e), infer_only); - app_delayed_justification jst(e, f_type, a_type); - if (!is_def_eq(a_type, binding_domain(f_type), jst)) { - throw_kernel_exception(m_env, app_arg(e), - [=](formatter const & fmt) { - return pp_app_type_mismatch(fmt, e, binding_domain(f_type), a_type); - }); - } - } - r = instantiate(binding_body(f_type), app_arg(e)); - break; - }} if (m_memoize) m_infer_type_cache[infer_only].insert(mk_pair(e, r)); diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index cd555cdc4..d0860a153 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -96,6 +96,11 @@ class type_checker { expr ensure_pi_core(expr e, expr const & s); justification mk_macro_jst(expr const & e); void check_level(level const & l, expr const & s); + expr infer_constant(expr const & e, bool infer_only); + expr infer_macro(expr const & e, bool infer_only); + expr infer_lambda(expr const & e, bool infer_only); + expr infer_pi(expr const & e, bool infer_only); + expr infer_app(expr const & e, bool infer_only); expr infer_type_core(expr const & e, bool infer_only); expr infer_type(expr const & e); extension_context & get_extension() { return m_tc_ctx; }