feat(frontends/lean/elaborator): generate COERCION info

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-09-02 18:39:06 -07:00
parent 24fc89ff70
commit 974a0a4217
3 changed files with 95 additions and 8 deletions

View file

@ -582,6 +582,24 @@ public:
} }
} }
/** \brief Auxiliary function for saving information about which coercion was used by the elaborator.
It marks that coercion c was used on e.
*/
void save_coercion_info(expr const & e, expr const & c) {
if (!m_noinfo && infom() && pip()) {
if (auto p = pip()->get_pos_info(e))
m_pre_info_data.add_coercion_info(p->first, p->second, c);
}
}
/** \brief Remove coercion information associated with \c e */
void erase_coercion_info(expr const & e) {
if (!m_noinfo && infom() && pip()) {
if (auto p = pip()->get_pos_info(e))
m_pre_info_data.erase_coercion_info(p->first, p->second);
}
}
void copy_info_to_manager(substitution s) { void copy_info_to_manager(substitution s) {
if (!infom()) if (!infom())
return; return;
@ -744,12 +762,15 @@ public:
// try coercion to function-class // try coercion to function-class
optional<expr> c = get_coercion_to_fun(env(), f_type); optional<expr> c = get_coercion_to_fun(env(), f_type);
if (c) { if (c) {
save_coercion_info(f, *c);
f = mk_app(*c, f, f.get_tag()); f = mk_app(*c, f, f.get_tag());
f_type = infer_type(f, cs); f_type = infer_type(f, cs);
lean_assert(is_pi(f_type)); lean_assert(is_pi(f_type));
} else { } else {
throw_kernel_exception(env(), f, [=](formatter const & fmt) { return pp_function_expected(fmt, f); }); throw_kernel_exception(env(), f, [=](formatter const & fmt) { return pp_function_expected(fmt, f); });
} }
} else {
erase_coercion_info(f);
} }
lean_assert(is_pi(f_type)); lean_assert(is_pi(f_type));
return mk_pair(f, f_type); return mk_pair(f, f_type);
@ -770,12 +791,56 @@ public:
d_type = whnf(d_type).first; d_type = whnf(d_type).first;
expr const & d_cls = get_app_fn(d_type); expr const & d_cls = get_app_fn(d_type);
if (is_constant(d_cls)) { if (is_constant(d_cls)) {
if (auto c = get_coercion(env(), a_type, const_name(d_cls))) if (auto c = get_coercion(env(), a_type, const_name(d_cls))) {
save_coercion_info(a, *c);
return mk_app(*c, a, a.get_tag()); return mk_app(*c, a, a.get_tag());
} else {
erase_coercion_info(a);
}
} }
return a; return a;
} }
struct coercion_case_split {
elaborator & m_elab;
expr m_arg;
bool m_id; // true if identity was not tried yet
list<constraints> m_choices;
list<expr> m_coercions;
coercion_case_split(elaborator & elab, expr const & arg, list<constraints> const & choices, list<expr> const & coes):
m_elab(elab), m_arg(arg), m_id(true), m_choices(choices), m_coercions(coes) {
lean_assert(length(m_coercions) + 1 == length(m_choices));
}
optional<constraints> next() {
if (!m_choices)
return optional<constraints>();
if (m_id) {
m_id = false;
m_elab.erase_coercion_info(m_arg);
} else if (m_coercions) {
expr c = head(m_coercions);
m_coercions = tail(m_coercions);
m_elab.save_coercion_info(m_arg, c);
}
auto r = head(m_choices);
m_choices = tail(m_choices);
return optional<constraints>(r);
}
};
lazy_list<constraints> choose(std::shared_ptr<coercion_case_split> c) {
return mk_lazy_list<constraints>([=]() {
auto s = c->next();
if (s) {
return some(mk_pair(*s, choose(c)));
} else {
return lazy_list<constraints>::maybe_pair();
}
});
}
constraint mk_delayed_coercion_cnstr(expr const & m, expr const & a, expr const & a_type, constraint mk_delayed_coercion_cnstr(expr const & m, expr const & a, expr const & a_type,
justification const & j, unsigned delay_factor) { justification const & j, unsigned delay_factor) {
bool relax = m_relax_main_opaque; bool relax = m_relax_main_opaque;
@ -805,26 +870,35 @@ public:
// case-split // case-split
buffer<std::tuple<name, expr, expr>> alts; buffer<std::tuple<name, expr, expr>> alts;
get_user_coercions(env(), new_a_type, alts); get_user_coercions(env(), new_a_type, alts);
buffer<constraints> r; buffer<constraints> choices;
buffer<expr> coes;
// first alternative: no coercion // first alternative: no coercion
constraint_seq cs1 = cs + mk_eq_cnstr(mvar, a, justification(), relax); constraint_seq cs1 = cs + mk_eq_cnstr(mvar, a, justification(), relax);
r.push_back(cs1.to_list()); choices.push_back(cs1.to_list());
unsigned i = alts.size(); unsigned i = alts.size();
while (i > 0) { while (i > 0) {
--i; --i;
auto const & t = alts[i]; auto const & t = alts[i];
expr new_a = mk_app(std::get<1>(t), a, a.get_tag()); expr coe = std::get<1>(t);
expr new_a = mk_app(coe, a, a.get_tag());
coes.push_back(coe);
constraint_seq csi = cs + mk_eq_cnstr(mvar, new_a, new_a_type_jst, relax); constraint_seq csi = cs + mk_eq_cnstr(mvar, new_a, new_a_type_jst, relax);
r.push_back(csi.to_list()); choices.push_back(csi.to_list());
} }
return to_lazy(to_list(r.begin(), r.end())); return choose(std::make_shared<coercion_case_split>(*this, mvar,
to_list(choices.begin(), choices.end()),
to_list(coes.begin(), coes.end())));
} else { } else {
expr new_a = a; expr new_a = a;
expr new_d_type = tc.whnf(d_type, cs); expr new_d_type = tc.whnf(d_type, cs);
expr const & d_cls = get_app_fn(new_d_type); expr const & d_cls = get_app_fn(new_d_type);
if (is_constant(d_cls)) { if (is_constant(d_cls)) {
if (auto c = get_coercion(env(), new_a_type, const_name(d_cls))) if (auto c = get_coercion(env(), new_a_type, const_name(d_cls))) {
save_coercion_info(a, *c);
new_a = mk_app(*c, a, a.get_tag()); new_a = mk_app(*c, a, a.get_tag());
} else {
erase_coercion_info(a);
}
} }
cs += mk_eq_cnstr(mvar, new_a, new_a_type_jst, relax); cs += mk_eq_cnstr(mvar, new_a, new_a_type_jst, relax);
return lazy_list<constraints>(cs.to_list()); return lazy_list<constraints>(cs.to_list());
@ -991,6 +1065,7 @@ public:
/** \brief Make sure \c e is a type. If it is not, then try to apply coercions. */ /** \brief Make sure \c e is a type. If it is not, then try to apply coercions. */
expr ensure_type(expr const & e, constraint_seq & cs) { expr ensure_type(expr const & e, constraint_seq & cs) {
expr t = infer_type(e, cs); expr t = infer_type(e, cs);
erase_coercion_info(e);
if (is_sort(t)) if (is_sort(t))
return e; return e;
t = whnf(t, cs); t = whnf(t, cs);
@ -1007,8 +1082,10 @@ public:
} }
} }
optional<expr> c = get_coercion_to_sort(env(), t); optional<expr> c = get_coercion_to_sort(env(), t);
if (c) if (c) {
save_coercion_info(e, *c);
return mk_app(*c, e, e.get_tag()); return mk_app(*c, e, e.get_tag());
}
throw_kernel_exception(env(), e, [=](formatter const & fmt) { return pp_type_expected(fmt, e); }); throw_kernel_exception(env(), e, [=](formatter const & fmt) { return pp_type_expected(fmt, e); });
} }

View file

@ -283,6 +283,14 @@ struct info_manager::imp {
m_line_data[l].insert(mk_coercion_info(c, e)); m_line_data[l].insert(mk_coercion_info(c, e));
} }
void erase_coercion_info(unsigned l, unsigned c) {
lock_guard<mutex> lc(m_mutex);
if (m_block_new_info)
return;
synch_line(l);
m_line_data[l].erase(mk_coercion_info(c, expr()));
}
void add_symbol_info(unsigned l, unsigned c, name const & s) { void add_symbol_info(unsigned l, unsigned c, name const & s) {
lock_guard<mutex> lc(m_mutex); lock_guard<mutex> lc(m_mutex);
if (m_block_new_info) if (m_block_new_info)
@ -510,6 +518,7 @@ void info_manager::add_type_info(unsigned l, unsigned c, expr const & e) { m_ptr
void info_manager::add_synth_info(unsigned l, unsigned c, expr const & e) { m_ptr->add_synth_info(l, c, e); } void info_manager::add_synth_info(unsigned l, unsigned c, expr const & e) { m_ptr->add_synth_info(l, c, e); }
void info_manager::add_overload_info(unsigned l, unsigned c, expr const & e) { m_ptr->add_overload_info(l, c, e); } void info_manager::add_overload_info(unsigned l, unsigned c, expr const & e) { m_ptr->add_overload_info(l, c, e); }
void info_manager::add_coercion_info(unsigned l, unsigned c, expr const & e) { m_ptr->add_coercion_info(l, c, e); } void info_manager::add_coercion_info(unsigned l, unsigned c, expr const & e) { m_ptr->add_coercion_info(l, c, e); }
void info_manager::erase_coercion_info(unsigned l, unsigned c) { m_ptr->erase_coercion_info(l, c); }
void info_manager::add_symbol_info(unsigned l, unsigned c, name const & s) { m_ptr->add_symbol_info(l, c, s); } void info_manager::add_symbol_info(unsigned l, unsigned c, name const & s) { m_ptr->add_symbol_info(l, c, s); }
void info_manager::add_identifier_info(unsigned l, unsigned c, name const & full_id) { void info_manager::add_identifier_info(unsigned l, unsigned c, name const & full_id) {
m_ptr->add_identifier_info(l, c, full_id); m_ptr->add_identifier_info(l, c, full_id);

View file

@ -22,6 +22,7 @@ public:
void add_synth_info(unsigned l, unsigned c, expr const & e); void add_synth_info(unsigned l, unsigned c, expr const & e);
void add_overload_info(unsigned l, unsigned c, expr const & e); void add_overload_info(unsigned l, unsigned c, expr const & e);
void add_coercion_info(unsigned l, unsigned c, expr const & e); void add_coercion_info(unsigned l, unsigned c, expr const & e);
void erase_coercion_info(unsigned l, unsigned c);
void add_symbol_info(unsigned l, unsigned c, name const & n); void add_symbol_info(unsigned l, unsigned c, name const & n);
void add_identifier_info(unsigned l, unsigned c, name const & full_id); void add_identifier_info(unsigned l, unsigned c, name const & full_id);
void instantiate(substitution const & s); void instantiate(substitution const & s);