diff --git a/src/library/blast/blast.cpp b/src/library/blast/blast.cpp index bc4cdd427..33bcfe6fd 100644 --- a/src/library/blast/blast.cpp +++ b/src/library/blast/blast.cpp @@ -247,6 +247,9 @@ public: projection_info const * get_projection_info(name const & n) const { return m_projection_info.find(n); } + + name mk_fresh_local_name() { + } }; LEAN_THREAD_PTR(context, g_context); @@ -334,6 +337,11 @@ extension_context & ext_ctx() { lean_assert(g_ext_context); return *g_ext_context; } + +name mk_fresh_local_name() { + lean_assert(g_ext_context); + return g_ext_context->mk_fresh_name(); +} } optional blast_goal(environment const & env, io_state const & ios, list const & ls, list const & ds, goal const & g) { diff --git a/src/library/blast/blast_context.h b/src/library/blast/blast_context.h index 2abbdc28c..e7e241b91 100644 --- a/src/library/blast/blast_context.h +++ b/src/library/blast/blast_context.h @@ -24,6 +24,8 @@ io_state const & ios(); state & curr_state(); /** \brief Return the thead local extension context associated with the blast tactic. */ extension_context & ext_ctx(); +/** \brief Return a thread local fresh name meant to be used to name local constants. */ +name mk_fresh_local_name(); /** \brief Return true iff the given constant name is marked as reducible in env() */ bool is_reducible(name const & n); /** \brief Return a nonnull projection_info object if \c n is the name of a projection in env() */ diff --git a/src/library/blast/expr.cpp b/src/library/blast/expr.cpp index 4db9ed6a9..2c6ecf6e4 100644 --- a/src/library/blast/expr.cpp +++ b/src/library/blast/expr.cpp @@ -160,9 +160,9 @@ bool has_mref(expr const & e) { return lean::has_expr_metavar(e); } -expr mk_local(unsigned idx, expr const & t) { +expr mk_local(name const & n, name const & pp_n, expr const & t, binder_info const & bi) { lean_assert(is_cached(t)); - return lean::mk_local(name(*g_prefix, idx), t); + return lean::mk_local(n, pp_n, t, bi); } bool is_local(expr const & e) { diff --git a/src/library/blast/expr.h b/src/library/blast/expr.h index 01bb10275..a3ae1fc58 100644 --- a/src/library/blast/expr.h +++ b/src/library/blast/expr.h @@ -42,7 +42,7 @@ expr mk_href(unsigned idx); expr mk_mref(unsigned idx); expr mk_sort(level const & l); expr mk_constant(name const & n, levels const & ls); -expr mk_local(unsigned idx, expr const & t); +expr mk_local(name const & n, name const & pp_n, expr const & t, binder_info const & bi); expr mk_app(expr const & f, expr const & a); expr mk_app(expr const & f, unsigned num_args, expr const * args); expr mk_app(unsigned num_args, expr const * args); diff --git a/src/library/blast/infer_type.cpp b/src/library/blast/infer_type.cpp index e5cb2a8a8..4bc85bc14 100644 --- a/src/library/blast/infer_type.cpp +++ b/src/library/blast/infer_type.cpp @@ -6,8 +6,10 @@ Author: Leonardo de Moura */ #include "util/interrupt.h" #include "kernel/instantiate.h" +#include "kernel/abstract.h" #include "library/blast/infer_type.h" #include "library/blast/blast_context.h" +#include "library/blast/blast_exception.h" namespace lean { namespace blast { @@ -108,8 +110,144 @@ bool is_def_eq(expr const & e1, expr const & e2) { return e1 == e2; } +static expr infer_constant(expr const & e) { + declaration d = env().get(const_name(e)); + auto const & ps = d.get_univ_params(); + auto const & ls = const_levels(e); + if (length(ps) != length(ls)) + throw blast_exception("infer type failed, incorrect number of universe levels", e); + return instantiate_type_univ_params(d, ls); +} + +static expr infer_macro(expr const & e) { + auto def = macro_def(e); + bool infer_only = true; + // Remark: we are ignoring constraints generated by the macro definition. + return def.check_type(e, ext_ctx(), infer_only).first; +} + +static expr infer_lambda(expr e) { + buffer es, ds, ls; + while (is_lambda(e)) { + es.push_back(e); + ds.push_back(binding_domain(e)); + expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data()); + expr l = blast::mk_local(mk_fresh_local_name(), binding_name(e), d, binding_info(e)); + ls.push_back(l); + e = binding_body(e); + } + expr t = infer_type(instantiate_rev(e, ls.size(), ls.data())); + expr r = abstract_locals(t, ls.size(), ls.data()); + unsigned i = es.size(); + while (i > 0) { + --i; + r = blast::mk_pi(binding_name(es[i]), ds[i], r, binding_info(es[i])); + } + return r; +} + +/** \brief Make sure \c e is a sort, if it is not throw an exception using \c ref as a reference */ +static void ensure_sort(expr const & e, expr const & ref) { + // Remark: for simplicity reasons, we just fail if \c e is not a sort. + if (!is_sort(e)) + throw blast_exception("infer type failed, sort expected", ref); +} + +static expr infer_pi(expr const & e0) { + buffer ls; + buffer us; + expr e = e0; + while (is_pi(e)) { + expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data()); + expr d_type = whnf(infer_type(d)); + ensure_sort(d_type, e0); + us.push_back(sort_level(d_type)); + expr l = blast::mk_local(mk_fresh_local_name(), binding_name(e), d, binding_info(e)); + ls.push_back(l); + e = binding_body(e); + } + e = instantiate_rev(e, ls.size(), ls.data()); + expr e_type = whnf(infer_type(e)); + ensure_sort(e_type, e0); + level r = sort_level(e_type); + unsigned i = ls.size(); + bool imp = env().impredicative(); + while (i > 0) { + --i; + r = imp ? blast::mk_imax(us[i], r) : blast::mk_max(us[i], r); + } + return blast::mk_sort(r); +} + +/** \brief Make sure \c e is a Pi-expression, if it is not throw an exception using \c ref as a reference */ +static void ensure_pi(expr const & e, expr const & ref) { + // Remark: for simplicity reasons, we just fail if \c e is not a Pi. + if (!is_pi(e)) + throw blast_exception("infer type failed, Pi expected", ref); +} + +static expr infer_app(expr const & e) { + buffer args; + expr const & f = get_app_args(e, args); + expr f_type = infer_type(f); + unsigned j = 0; + unsigned nargs = args.size(); + for (unsigned i = 0; i < nargs; i++) { + if (is_pi(f_type)) { + f_type = binding_body(f_type); + } else { + f_type = whnf(instantiate_rev(f_type, i-j, args.data()+j)); + ensure_pi(f_type, e); + f_type = binding_body(f_type); + j = i; + } + } + return instantiate_rev(f_type, nargs-j, args.data()+j); +} + expr infer_type(expr const & e) { - // TODO(Leo) - return e; + lean_assert(!is_var(e)); + lean_assert(closed(e)); + check_system("infer_type"); + + expr r; + switch (e.kind()) { + case expr_kind::Local: + if (is_href(e)) { + if (hypothesis const * h = state().get_main_branch().get(e)) { + r = h->get_type(); + } else { + throw blast_exception("infer type failed, unknown hypothesis", e); + } + } else { + r = mlocal_type(e); + } + break; + case expr_kind::Meta: + r = mlocal_type(e); + break; + case expr_kind::Var: + lean_unreachable(); // LCOV_EXCL_LINE + case expr_kind::Sort: + r = blast::mk_sort(blast::mk_succ(sort_level(e))); + break; + case expr_kind::Constant: + r = infer_constant(e); + break; + case expr_kind::Macro: + r = infer_macro(e); + break; + case expr_kind::Lambda: + r = infer_lambda(e); + break; + case expr_kind::Pi: + r = infer_pi(e); + break; + case expr_kind::App: + r = infer_app(e); + break; + } + // TODO(Leo): cache results if we have performance problems + return r; } }} diff --git a/src/library/blast/state.cpp b/src/library/blast/state.cpp index 753f441ff..da509da7c 100644 --- a/src/library/blast/state.cpp +++ b/src/library/blast/state.cpp @@ -11,7 +11,7 @@ Author: Leonardo de Moura namespace lean { namespace blast { -state::state():m_next_mref_index(0) {} +state::state():m_next_uvar_index(0), m_next_mref_index(0) {} /** \brief Mark that hypothesis h with index hidx is fixed by the meta-variable midx. That is, `h` occurs in the type of `midx`. */ diff --git a/src/library/blast/state.h b/src/library/blast/state.h index 0a47f0827..0f852b5cf 100644 --- a/src/library/blast/state.h +++ b/src/library/blast/state.h @@ -28,12 +28,15 @@ public: class state { friend class context; typedef metavar_idx_map metavar_decls; - typedef metavar_idx_map assignment; + typedef metavar_idx_map eassignment; + typedef metavar_idx_map uassignment; typedef hypothesis_idx_map fixed_by; - unsigned m_next_mref_index; - metavar_decls m_metavar_decls; - assignment m_assignment; - branch m_main; + unsigned m_next_uvar_index; // index of the next universe metavariable + uassignment m_uassignment; + unsigned m_next_mref_index; // index of the next metavariable + metavar_decls m_metavar_decls; + eassignment m_eassignment; + branch m_main; // In the following mapping, each entry (h -> {m_1 ... m_n}) means that hypothesis `h` cannot be cleared // in any branch where the metavariables m_1 ... m_n have not been replaced with the values assigned to them. // That is, to be able to clear `h` in a branch `B`, we first need to check whether it @@ -68,6 +71,9 @@ public: return m_main.add_hypothesis(n, type, value, jst); } + branch & get_main_branch() { return m_main; } + branch const & get_main_branch() const { return m_main; } + /** \brief Add a new hypothesis to the main branch */ expr add_hypothesis(expr const & type, optional const & value, optional const & jst) { return m_main.add_hypothesis(type, value, jst);