diff --git a/src/frontends/lean/coercion_elaborator.cpp b/src/frontends/lean/coercion_elaborator.cpp index e35f376f3..205507235 100644 --- a/src/frontends/lean/coercion_elaborator.cpp +++ b/src/frontends/lean/coercion_elaborator.cpp @@ -21,6 +21,23 @@ coercion_elaborator::coercion_elaborator(coercion_info_manager & info, expr cons lean_assert(use_id || length(m_coercions) == length(m_choices)); } +list get_coercions_from_to(type_checker & tc, expr const & from_type, expr const & to_type, constraint_seq & cs) { + constraint_seq new_cs; + expr whnf_to_type = tc.whnf(to_type, new_cs); + expr const & fn = get_app_fn(whnf_to_type); + list r; + if (is_constant(fn)) { + r = get_coercions(tc.env(), from_type, const_name(fn)); + } else if (is_pi(whnf_to_type)) { + r = get_coercions_to_fun(tc.env(), from_type); + } else if (is_sort(whnf_to_type)) { + r = get_coercions_to_sort(tc.env(), from_type); + } + if (r) + cs += new_cs; + return r; +} + optional coercion_elaborator::next() { if (!m_choices) return optional(); @@ -65,8 +82,8 @@ constraint mk_coercion_cnstr(type_checker & tc, coercion_info_manager & infom, new_a_type = tc.whnf(new_a_type, cs); if (is_meta(d_type)) { // case-split - buffer> alts; - get_user_coercions(tc.env(), new_a_type, alts); + buffer> alts; + get_coercions_from(tc.env(), new_a_type, alts); buffer choices; buffer coes; // first alternative: no coercion @@ -86,33 +103,24 @@ constraint mk_coercion_cnstr(type_checker & tc, coercion_info_manager & infom, to_list(choices.begin(), choices.end()), to_list(coes.begin(), coes.end()))); } else { - expr new_d_type = tc.whnf(d_type, cs); - expr const & d_cls = get_app_fn(new_d_type); - if (is_constant(d_cls)) { - list coes = get_coercions(tc.env(), new_a_type, const_name(d_cls)); - if (is_nil(coes)) { - expr new_a = a; - infom.erase_coercion_info(a); - cs += mk_eq_cnstr(meta, new_a, new_a_type_jst, relax); - return lazy_list(cs.to_list()); - } else if (is_nil(tail(coes))) { - expr new_a = copy_tag(a, mk_app(head(coes), a)); - infom.save_coercion_info(a, new_a); - cs += mk_eq_cnstr(meta, new_a, new_a_type_jst, relax); - return lazy_list(cs.to_list()); - } else { - list choices = map2(coes, [&](expr const & coe) { - expr new_a = copy_tag(a, mk_app(coe, a)); - constraint c = mk_eq_cnstr(meta, new_a, new_a_type_jst, relax); - return (cs + c).to_list(); - }); - return choose(std::make_shared(infom, meta, choices, coes, false)); - } - } else { + list coes = get_coercions_from_to(tc, new_a_type, d_type, cs); + if (is_nil(coes)) { expr new_a = a; infom.erase_coercion_info(a); cs += mk_eq_cnstr(meta, new_a, new_a_type_jst, relax); return lazy_list(cs.to_list()); + } else if (is_nil(tail(coes))) { + expr new_a = copy_tag(a, mk_app(head(coes), a)); + infom.save_coercion_info(a, new_a); + cs += mk_eq_cnstr(meta, new_a, new_a_type_jst, relax); + return lazy_list(cs.to_list()); + } else { + list choices = map2(coes, [&](expr const & coe) { + expr new_a = copy_tag(a, mk_app(coe, a)); + constraint c = mk_eq_cnstr(meta, new_a, new_a_type_jst, relax); + return (cs + c).to_list(); + }); + return choose(std::make_shared(infom, meta, choices, coes, false)); } } }; diff --git a/src/frontends/lean/coercion_elaborator.h b/src/frontends/lean/coercion_elaborator.h index f59759410..3cd419d21 100644 --- a/src/frontends/lean/coercion_elaborator.h +++ b/src/frontends/lean/coercion_elaborator.h @@ -51,4 +51,6 @@ pair mk_coercion_elaborator( pair coercions_to_choice(coercion_info_manager & infom, local_context & ctx, list const & coes, expr const & a, justification const & j, bool relax); + +list get_coercions_from_to(type_checker & tc, expr const & from_type, expr const & to_type, constraint_seq & cs); } diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index 014da4042..b87307aae 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -434,37 +434,40 @@ public: return is_constant(a_cls) && ::lean::has_coercions_from(env(), const_name(a_cls)); } - bool has_coercions_to(expr const & d_type) { - expr const & d_cls = get_app_fn(whnf(d_type).first); - return is_constant(d_cls) && ::lean::has_coercions_to(env(), const_name(d_cls)); + bool has_coercions_to(expr d_type) { + d_type = whnf(d_type).first; + expr const & fn = get_app_fn(d_type); + if (is_constant(fn)) + return ::lean::has_coercions_to(env(), const_name(fn)); + else if (is_pi(d_type)) + return ::lean::has_coercions_to_fun(env()); + else if (is_sort(d_type)) + return ::lean::has_coercions_to_sort(env()); + else + return false; } expr apply_coercion(expr const & a, expr a_type, expr d_type) { a_type = whnf(a_type).first; d_type = whnf(d_type).first; - expr const & d_cls = get_app_fn(d_type); - if (is_constant(d_cls)) { - list coes = get_coercions(env(), a_type, const_name(d_cls)); - if (is_nil(coes)) { - erase_coercion_info(a); - return a; - } else if (is_nil(tail(coes))) { - expr r = mk_app(head(coes), a, a.get_tag()); - save_coercion_info(a, r); - return r; - } else { - for (expr const & coe : coes) { - expr r = mk_app(coe, a, a.get_tag()); - expr r_type = infer_type(r).first; - if (m_tc[m_relax_main_opaque]->is_def_eq(r_type, d_type).first) { - save_coercion_info(a, r); - return r; - } - } - erase_coercion_info(a); - return a; - } + constraint_seq aux_cs; + list coes = get_coercions_from_to(*m_tc[m_relax_main_opaque], a_type, d_type, aux_cs); + if (is_nil(coes)) { + erase_coercion_info(a); + return a; + } else if (is_nil(tail(coes))) { + expr r = mk_app(head(coes), a, a.get_tag()); + save_coercion_info(a, r); + return r; } else { + for (expr const & coe : coes) { + expr r = mk_app(coe, a, a.get_tag()); + expr r_type = infer_type(r).first; + if (m_tc[m_relax_main_opaque]->is_def_eq(r_type, d_type).first) { + save_coercion_info(a, r); + return r; + } + } erase_coercion_info(a); return a; } diff --git a/src/library/coercion.cpp b/src/library/coercion.cpp index 211d54836..74c2a965b 100644 --- a/src/library/coercion.cpp +++ b/src/library/coercion.cpp @@ -16,28 +16,15 @@ Author: Leonardo de Moura #include "library/scoped_ext.h" namespace lean { -enum class coercion_class_kind { User, Sort, Fun }; - -/** - \brief A coercion is a mapping between classes. - We support three kinds of classes: User, Sort, Function. -*/ -class coercion_class { - coercion_class_kind m_kind; - name m_name; // relevant only if m_kind == User - coercion_class(coercion_class_kind k, name const & n = name()):m_kind(k), m_name(n) {} -public: - coercion_class():m_kind(coercion_class_kind::Sort) {} - static coercion_class mk_user(name n) { return coercion_class(coercion_class_kind::User, n); } - static coercion_class mk_sort() { return coercion_class(coercion_class_kind::Sort); } - static coercion_class mk_fun() { return coercion_class(coercion_class_kind::Fun); } - friend bool operator==(coercion_class const & c1, coercion_class const & c2) { - return c1.m_kind == c2.m_kind && c1.m_name == c2.m_name; - } - friend bool operator!=(coercion_class const & c1, coercion_class const & c2) { return !(c1 == c2); } - coercion_class_kind kind() const { return m_kind; } - name get_name() const { return m_name; } -}; +coercion_class coercion_class::mk_user(name n) { return coercion_class(coercion_class_kind::User, n); } +coercion_class coercion_class::mk_sort() { return coercion_class(coercion_class_kind::Sort); } +coercion_class coercion_class::mk_fun() { return coercion_class(coercion_class_kind::Fun); } +bool operator==(coercion_class const & c1, coercion_class const & c2) { + return c1.m_kind == c2.m_kind && c1.m_name == c2.m_name; +} +bool operator!=(coercion_class const & c1, coercion_class const & c2) { + return !(c1 == c2); +} std::ostream & operator<<(std::ostream & out, coercion_class const & cls) { switch (cls.kind()) { @@ -415,6 +402,18 @@ bool has_coercions_to(environment const & env, name const & D) { return it && !is_nil(*it); } +bool has_coercions_to_sort(environment const & env) { + coercion_state const & ext = coercion_ext::get_state(env); + auto it = ext.m_to.find(coercion_class::mk_sort()); + return it && !is_nil(*it); +} + +bool has_coercions_to_fun(environment const & env) { + coercion_state const & ext = coercion_ext::get_state(env); + auto it = ext.m_to.find(coercion_class::mk_fun()); + return it && !is_nil(*it); +} + bool has_coercions_from(environment const & env, name const & C) { coercion_state const & ext = coercion_ext::get_state(env); return ext.m_coercion_info.contains(C); @@ -465,7 +464,7 @@ list get_coercions_to_fun(environment const & env, expr const & C) { return get_coercions(env, C, coercion_class::mk_fun()); } -bool get_user_coercions(environment const & env, expr const & C, buffer> & result) { +bool get_coercions_from(environment const & env, expr const & C, buffer> & result) { buffer args; expr const & C_fn = get_app_rev_args(C, args); if (!is_constant(C_fn)) @@ -476,15 +475,14 @@ bool get_user_coercions(environment const & env, expr const & C, buffer> r; - get_user_coercions(to_environment(L, 1), to_expr(L, 2), r); +static int get_coercions_from(lua_State * L) { + buffer> r; + get_coercions_from(to_environment(L, 1), to_expr(L, 2), r); lua_newtable(L); int i = 1; for (auto p : r) { lua_newtable(L); - push_name(L, std::get<0>(p)); + coercion_class c = std::get<0>(p); + push_integer(L, static_cast(c.kind())); lua_rawseti(L, -2, 1); - push_expr(L, std::get<1>(p)); + if (c.kind() == coercion_class_kind::User) { + push_name(L, c.get_name()); + } else { + push_nil(L); + } lua_rawseti(L, -2, 2); - push_expr(L, std::get<2>(p)); + push_expr(L, std::get<1>(p)); lua_rawseti(L, -2, 3); + push_expr(L, std::get<2>(p)); + lua_rawseti(L, -2, 4); lua_rawseti(L, -2, i); i = i + 1; } @@ -616,10 +621,10 @@ void open_coercion(lua_State * L) { SET_GLOBAL_FUN(add_coercion, "add_coercion"); SET_GLOBAL_FUN(is_coercion, "is_coercion"); SET_GLOBAL_FUN(has_coercions_from, "has_coercions_from"); - SET_GLOBAL_FUN(get_coercions, "get_coercions"); - SET_GLOBAL_FUN(get_coercions_to_sort, "get_coercions_to_sort"); - SET_GLOBAL_FUN(get_coercions_to_fun, "get_coercions_to_fun"); - SET_GLOBAL_FUN(get_user_coercions, "get_user_coercions"); + SET_GLOBAL_FUN(get_coercions, "get_coercions"); + SET_GLOBAL_FUN(get_coercions_to_sort, "get_coercions_to_sort"); + SET_GLOBAL_FUN(get_coercions_to_fun, "get_coercions_to_fun"); + SET_GLOBAL_FUN(get_coercions_from, "get_coercions_from"); SET_GLOBAL_FUN(for_each_coercion_user, "for_each_coercion_user"); SET_GLOBAL_FUN(for_each_coercion_sort, "for_each_coercion_sort"); SET_GLOBAL_FUN(for_each_coercion_fun, "for_each_coercion_fun"); diff --git a/src/library/coercion.h b/src/library/coercion.h index 5ebe58d21..0de70ffc8 100644 --- a/src/library/coercion.h +++ b/src/library/coercion.h @@ -12,6 +12,26 @@ Author: Leonardo de Moura #include "library/io_state.h" namespace lean { +enum class coercion_class_kind { User, Sort, Fun }; +/** + \brief A coercion is a mapping between classes. + We support three kinds of classes: User, Sort, Function. +*/ +class coercion_class { + coercion_class_kind m_kind; + name m_name; // relevant only if m_kind == User + coercion_class(coercion_class_kind k, name const & n = name()):m_kind(k), m_name(n) {} +public: + coercion_class():m_kind(coercion_class_kind::Sort) {} + static coercion_class mk_user(name n); + static coercion_class mk_sort(); + static coercion_class mk_fun(); + friend bool operator==(coercion_class const & c1, coercion_class const & c2); + friend bool operator!=(coercion_class const & c1, coercion_class const & c2); + coercion_class_kind kind() const { return m_kind; } + name get_name() const { return m_name; } +}; + /** \brief Add an new coercion in the given environment. @@ -51,6 +71,8 @@ bool has_coercions_from(environment const & env, name const & C); bool has_coercions_from(environment const & env, expr const & C); /** \brief Return true iff the given environment has coercions to a user-class named \c D. */ bool has_coercions_to(environment const & env, name const & D); +bool has_coercions_to_sort(environment const & env); +bool has_coercions_to_fun(environment const & env); /** \brief Return a coercion (if it exists) from (C_name.{l1 lk} t_1 ... t_n) to the class named D. The coercion is a unary function that takes a term of type (C_name.{l1 lk} t_1 ... t_n) and returns @@ -60,13 +82,13 @@ list get_coercions(environment const & env, expr const & C, name const & D list get_coercions_to_sort(environment const & env, expr const & C); list get_coercions_to_fun(environment const & env, expr const & C); /** - \brief Return all user coercions C >-> D for the type C of the form (C_name.{l1 ... lk} t_1 ... t_n) - The result is a pair (user-class D, coercion, coercion type), and is stored in the result buffer \c result. + \brief Return all coercions C >-> D for the type C of the form (C_name.{l1 ... lk} t_1 ... t_n) + The result is a tuple (class D, coercion, coercion type), and is stored in the result buffer \c result. The Boolean result is true if at least one pair is added to \c result. \remark The most recent coercions occur first. */ -bool get_user_coercions(environment const & env, expr const & C, buffer> & result); +bool get_coercions_from(environment const & env, expr const & C, buffer> & result); typedef std::function coercion_user_fn; typedef std::function coercion_sort_fn; diff --git a/tests/lean/run/coe13.lean b/tests/lean/run/coe13.lean new file mode 100644 index 000000000..98c159623 --- /dev/null +++ b/tests/lean/run/coe13.lean @@ -0,0 +1,30 @@ +import data.nat +open nat + +inductive functor (A B : Type) := +mk : (A → B) → functor A B + +definition functor.to_fun [coercion] {A B : Type} (f : functor A B) : A → B := +functor.rec (λ f, f) f + +inductive struct := +mk : Π (A : Type), (A → A → Prop) → struct + +definition struct.to_sort [coercion] (s : struct) : Type := +struct.rec (λA r, A) s + +definition g (f : nat → nat) (a : nat) := f a + +variable f : functor nat nat + +check g (functor.to_fun f) 0 + +check g f 0 + +definition id (A : Type) (a : A) := a + +variable S : struct +variable a : S + +check id (struct.to_sort S) a +check id S a diff --git a/tests/lean/run/coe14.lean b/tests/lean/run/coe14.lean new file mode 100644 index 000000000..98c159623 --- /dev/null +++ b/tests/lean/run/coe14.lean @@ -0,0 +1,30 @@ +import data.nat +open nat + +inductive functor (A B : Type) := +mk : (A → B) → functor A B + +definition functor.to_fun [coercion] {A B : Type} (f : functor A B) : A → B := +functor.rec (λ f, f) f + +inductive struct := +mk : Π (A : Type), (A → A → Prop) → struct + +definition struct.to_sort [coercion] (s : struct) : Type := +struct.rec (λA r, A) s + +definition g (f : nat → nat) (a : nat) := f a + +variable f : functor nat nat + +check g (functor.to_fun f) 0 + +check g f 0 + +definition id (A : Type) (a : A) := a + +variable S : struct +variable a : S + +check id (struct.to_sort S) a +check id S a diff --git a/tests/lean/run/coe15.lean b/tests/lean/run/coe15.lean new file mode 100644 index 000000000..0b76fe21b --- /dev/null +++ b/tests/lean/run/coe15.lean @@ -0,0 +1,20 @@ +import data.nat +open nat + +inductive functor (A B : Type) := +mk : (A → B) → functor A B + +definition functor.to_fun [coercion] {A B : Type} (f : functor A B) : A → B := +functor.rec (λ f, f) f + +inductive struct := +mk : Π (A : Type), (A → A → Prop) → struct + +definition struct.to_sort [coercion] (s : struct) : Type := +struct.rec (λA r, A) s + +definition g (f : nat → nat) (a : nat) := f a + +check + λ f, + (g f 0) = 0 ∧ (functor.to_fun f) 0 = 0 diff --git a/tests/lua/coe1.lua b/tests/lua/coe1.lua index 6f64eee7a..ce8f34621 100644 --- a/tests/lua/coe1.lua +++ b/tests/lua/coe1.lua @@ -84,7 +84,7 @@ assert(not has_coercions_from(env2, Const("vec", {1})(nat))) assert(not has_coercions_from(env2, Const("vec")(nat, one))) print("Coercions (vec nat one): ") -cs = get_user_coercions(env2, Const("vec", {1})(nat, one)) +cs = get_coercions_from(env2, Const("vec", {1})(nat, one)) for i = 1, #cs do - print(tostring(cs[i][1]) .. " : " .. tostring(cs[i][3]) .. " : " .. tostring(cs[i][2])) + print(tostring(cs[i][2]) .. " : " .. tostring(cs[i][4]) .. " : " .. tostring(cs[i][3])) end diff --git a/tests/lua/coe5.lua b/tests/lua/coe5.lua index e281c06f7..3cd0ff9e8 100644 --- a/tests/lua/coe5.lua +++ b/tests/lua/coe5.lua @@ -25,7 +25,7 @@ for_each_coercion_user(env, function(C, D, f) print(tostring(C) .. " >-> " .. to print(get_coercions_to_sort(env, Const("abelian_ring", {1})):head()) assert(env:type_check(get_coercions_to_sort(env, Const("abelian_ring", {1})):head())) print("Coercions (abelian ring): ") -cs = get_user_coercions(env, ab_ring) +cs = get_coercions_from(env, ab_ring) for i = 1, #cs do - print(tostring(cs[i][1]) .. " : " .. tostring(cs[i][3]) .. " : " .. tostring(cs[i][2])) + print(tostring(cs[i][2]) .. " : " .. tostring(cs[i][4]) .. " : " .. tostring(cs[i][3])) end