diff --git a/src/frontends/lean/frontend_elaborator.cpp b/src/frontends/lean/frontend_elaborator.cpp index 2ad73794c..87d7a9323 100644 --- a/src/frontends/lean/frontend_elaborator.cpp +++ b/src/frontends/lean/frontend_elaborator.cpp @@ -48,23 +48,21 @@ expr const & get_choice(expr const & e, unsigned i) { class coercion_justification_cell : public justification_cell { context m_ctx; - expr m_app; - expr m_arg; + expr m_src; public: - coercion_justification_cell(context const & c, expr const & app, expr const & arg):m_ctx(c), m_app(app), m_arg(arg) {} + coercion_justification_cell(context const & c, expr const & src):m_ctx(c), m_src(src) {} virtual ~coercion_justification_cell() {} virtual format pp_header(formatter const & fmt, options const & opts) const { unsigned indent = get_pp_indent(opts); - format expr_fmt = fmt(m_ctx, m_arg, false, opts); + format expr_fmt = fmt(m_ctx, m_src, false, opts); format r; r += format("Coercion for"); r += nest(indent, compose(line(), expr_fmt)); return r; } virtual void get_children(buffer &) const {} - virtual expr const & get_main_expr() const { return m_arg; } + virtual expr const & get_main_expr() const { return m_src; } context const & get_context() const { return m_ctx; } - expr const & get_app() const { return m_app; } }; class overload_justification_cell : public justification_cell { @@ -88,8 +86,8 @@ public: }; -inline justification mk_coercion_justification(context const & ctx, expr const & app, expr const & arg) { - return justification(new coercion_justification_cell(ctx, app, arg)); +inline justification mk_coercion_justification(context const & ctx, expr const & e) { + return justification(new coercion_justification_cell(ctx, e)); } inline justification mk_overload_justification(context const & ctx, expr const & app) { @@ -163,7 +161,7 @@ class frontend_elaborator::imp { } expr add_coercion_mvar_app(list const & l, expr const & a, expr const & a_t, - context const & ctx, expr const & original_app, expr const & original_a) { + context const & ctx, expr const & original_a) { buffer choices; expr mvar = m_ref.m_menv.mk_metavar(ctx); for (auto p : l) { @@ -172,7 +170,7 @@ class frontend_elaborator::imp { choices.push_back(mk_lambda(g_x_name, a_t, mk_var(0))); // add indentity function std::reverse(choices.begin(), choices.end()); m_ref.m_ucs.push_back(mk_choice_constraint(ctx, mvar, choices.size(), choices.data(), - mk_coercion_justification(ctx, original_app, original_a))); + mk_coercion_justification(ctx, original_a))); return mk_app(mvar, a); } @@ -187,9 +185,13 @@ class frontend_elaborator::imp { /** \brief Try to solve overload at preprocessing time. */ +#if 0 bool choose(buffer const & f_choices, buffer const & f_choice_types, buffer & args, buffer & arg_types, context const & ctx, expr const & src) { +#else + bool choose(buffer const &, buffer const &, buffer &, buffer &, context const &, expr const &) { +#endif // TODO(Leo) return false; } @@ -235,7 +237,7 @@ class frontend_elaborator::imp { if (arg_types[i]) { list coercions = m_ref.m_frontend.get_coercions(arg_types[i]); if (coercions) - args[i] = add_coercion_mvar_app(coercions, args[i], arg_types[i], ctx, e, arg(e, i)); + args[i] = add_coercion_mvar_app(coercions, args[i], arg_types[i], ctx, arg(e, i)); } } } @@ -254,7 +256,7 @@ class frontend_elaborator::imp { list coercions = m_ref.m_frontend.get_coercions(new_a_t); if (coercions) { if (!new_f_t) { - new_a = add_coercion_mvar_app(coercions, new_a, new_a_t, ctx, e, a); + new_a = add_coercion_mvar_app(coercions, new_a, new_a_t, ctx, a); } else { expr expected = abst_domain(new_f_t); if (expected != new_a_t) { @@ -262,7 +264,7 @@ class frontend_elaborator::imp { if (c) new_a = mk_app(c, new_a); // apply coercion else - new_a = add_coercion_mvar_app(coercions, new_a, new_a_t, ctx, e, a); + new_a = add_coercion_mvar_app(coercions, new_a, new_a_t, ctx, a); } } } @@ -275,6 +277,29 @@ class frontend_elaborator::imp { } } + virtual expr visit_let(expr const & e, context const & ctx) { + lean_assert(is_let(e)); + return update_let(e, [&](expr const & t, expr const & v, expr const & b) { + expr new_t = t ? visit(t, ctx) : expr(); + expr new_v = visit(v, ctx); + if (new_t) { + expr new_v_t = get_type(new_v, ctx); + if (new_v_t && new_t != new_v_t) { + list coercions = m_ref.m_frontend.get_coercions(new_v_t); + if (coercions) { + new_v = add_coercion_mvar_app(coercions, new_v, new_v_t, ctx, v); + } + } + } + expr new_b; + { + cache::mk_scope sc(m_cache); + new_b = visit(b, extend(ctx, let_name(e), new_t, new_v)); + } + return std::make_tuple(new_t, new_v, new_b); + }); + } + virtual expr visit(expr const & e, context const & ctx) { check_interrupted(m_ref.m_interrupted); expr r = replace_visitor::visit(e, ctx);