diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index c740d2824..dc1d465aa 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -582,6 +582,24 @@ public: } } + /** \brief Auxiliary function for saving information about which coercion was used by the elaborator. + It marks that coercion c was used on e. + */ + void save_coercion_info(expr const & e, expr const & c) { + if (!m_noinfo && infom() && pip()) { + if (auto p = pip()->get_pos_info(e)) + m_pre_info_data.add_coercion_info(p->first, p->second, c); + } + } + + /** \brief Remove coercion information associated with \c e */ + void erase_coercion_info(expr const & e) { + if (!m_noinfo && infom() && pip()) { + if (auto p = pip()->get_pos_info(e)) + m_pre_info_data.erase_coercion_info(p->first, p->second); + } + } + void copy_info_to_manager(substitution s) { if (!infom()) return; @@ -744,12 +762,15 @@ public: // try coercion to function-class optional c = get_coercion_to_fun(env(), f_type); if (c) { + save_coercion_info(f, *c); f = mk_app(*c, f, f.get_tag()); f_type = infer_type(f, cs); lean_assert(is_pi(f_type)); } else { throw_kernel_exception(env(), f, [=](formatter const & fmt) { return pp_function_expected(fmt, f); }); } + } else { + erase_coercion_info(f); } lean_assert(is_pi(f_type)); return mk_pair(f, f_type); @@ -770,12 +791,56 @@ public: d_type = whnf(d_type).first; expr const & d_cls = get_app_fn(d_type); if (is_constant(d_cls)) { - if (auto c = get_coercion(env(), a_type, const_name(d_cls))) + if (auto c = get_coercion(env(), a_type, const_name(d_cls))) { + save_coercion_info(a, *c); return mk_app(*c, a, a.get_tag()); + } else { + erase_coercion_info(a); + } } return a; } + struct coercion_case_split { + elaborator & m_elab; + expr m_arg; + bool m_id; // true if identity was not tried yet + list m_choices; + list m_coercions; + + coercion_case_split(elaborator & elab, expr const & arg, list const & choices, list const & coes): + m_elab(elab), m_arg(arg), m_id(true), m_choices(choices), m_coercions(coes) { + lean_assert(length(m_coercions) + 1 == length(m_choices)); + } + + optional next() { + if (!m_choices) + return optional(); + if (m_id) { + m_id = false; + m_elab.erase_coercion_info(m_arg); + } else if (m_coercions) { + expr c = head(m_coercions); + m_coercions = tail(m_coercions); + m_elab.save_coercion_info(m_arg, c); + } + auto r = head(m_choices); + m_choices = tail(m_choices); + return optional(r); + } + }; + + lazy_list choose(std::shared_ptr c) { + return mk_lazy_list([=]() { + auto s = c->next(); + if (s) { + return some(mk_pair(*s, choose(c))); + } else { + return lazy_list::maybe_pair(); + } + }); + } + constraint mk_delayed_coercion_cnstr(expr const & m, expr const & a, expr const & a_type, justification const & j, unsigned delay_factor) { bool relax = m_relax_main_opaque; @@ -805,26 +870,35 @@ public: // case-split buffer> alts; get_user_coercions(env(), new_a_type, alts); - buffer r; + buffer choices; + buffer coes; // first alternative: no coercion constraint_seq cs1 = cs + mk_eq_cnstr(mvar, a, justification(), relax); - r.push_back(cs1.to_list()); + choices.push_back(cs1.to_list()); unsigned i = alts.size(); while (i > 0) { --i; auto const & t = alts[i]; - expr new_a = mk_app(std::get<1>(t), a, a.get_tag()); + expr coe = std::get<1>(t); + expr new_a = mk_app(coe, a, a.get_tag()); + coes.push_back(coe); constraint_seq csi = cs + mk_eq_cnstr(mvar, new_a, new_a_type_jst, relax); - r.push_back(csi.to_list()); + choices.push_back(csi.to_list()); } - return to_lazy(to_list(r.begin(), r.end())); + return choose(std::make_shared(*this, mvar, + to_list(choices.begin(), choices.end()), + to_list(coes.begin(), coes.end()))); } else { expr new_a = a; expr new_d_type = tc.whnf(d_type, cs); expr const & d_cls = get_app_fn(new_d_type); if (is_constant(d_cls)) { - if (auto c = get_coercion(env(), new_a_type, const_name(d_cls))) + if (auto c = get_coercion(env(), new_a_type, const_name(d_cls))) { + save_coercion_info(a, *c); new_a = mk_app(*c, a, a.get_tag()); + } else { + erase_coercion_info(a); + } } cs += mk_eq_cnstr(mvar, new_a, new_a_type_jst, relax); return lazy_list(cs.to_list()); @@ -991,6 +1065,7 @@ public: /** \brief Make sure \c e is a type. If it is not, then try to apply coercions. */ expr ensure_type(expr const & e, constraint_seq & cs) { expr t = infer_type(e, cs); + erase_coercion_info(e); if (is_sort(t)) return e; t = whnf(t, cs); @@ -1007,8 +1082,10 @@ public: } } optional c = get_coercion_to_sort(env(), t); - if (c) + if (c) { + save_coercion_info(e, *c); return mk_app(*c, e, e.get_tag()); + } throw_kernel_exception(env(), e, [=](formatter const & fmt) { return pp_type_expected(fmt, e); }); } diff --git a/src/frontends/lean/info_manager.cpp b/src/frontends/lean/info_manager.cpp index 3a89b8547..6ce628698 100644 --- a/src/frontends/lean/info_manager.cpp +++ b/src/frontends/lean/info_manager.cpp @@ -283,6 +283,14 @@ struct info_manager::imp { m_line_data[l].insert(mk_coercion_info(c, e)); } + void erase_coercion_info(unsigned l, unsigned c) { + lock_guard lc(m_mutex); + if (m_block_new_info) + return; + synch_line(l); + m_line_data[l].erase(mk_coercion_info(c, expr())); + } + void add_symbol_info(unsigned l, unsigned c, name const & s) { lock_guard lc(m_mutex); if (m_block_new_info) @@ -510,6 +518,7 @@ void info_manager::add_type_info(unsigned l, unsigned c, expr const & e) { m_ptr void info_manager::add_synth_info(unsigned l, unsigned c, expr const & e) { m_ptr->add_synth_info(l, c, e); } void info_manager::add_overload_info(unsigned l, unsigned c, expr const & e) { m_ptr->add_overload_info(l, c, e); } void info_manager::add_coercion_info(unsigned l, unsigned c, expr const & e) { m_ptr->add_coercion_info(l, c, e); } +void info_manager::erase_coercion_info(unsigned l, unsigned c) { m_ptr->erase_coercion_info(l, c); } void info_manager::add_symbol_info(unsigned l, unsigned c, name const & s) { m_ptr->add_symbol_info(l, c, s); } void info_manager::add_identifier_info(unsigned l, unsigned c, name const & full_id) { m_ptr->add_identifier_info(l, c, full_id); diff --git a/src/frontends/lean/info_manager.h b/src/frontends/lean/info_manager.h index 3b0ea5977..2033efb83 100644 --- a/src/frontends/lean/info_manager.h +++ b/src/frontends/lean/info_manager.h @@ -22,6 +22,7 @@ public: void add_synth_info(unsigned l, unsigned c, expr const & e); void add_overload_info(unsigned l, unsigned c, expr const & e); void add_coercion_info(unsigned l, unsigned c, expr const & e); + void erase_coercion_info(unsigned l, unsigned c); void add_symbol_info(unsigned l, unsigned c, name const & n); void add_identifier_info(unsigned l, unsigned c, name const & full_id); void instantiate(substitution const & s);