feat(frontends/lean): use coercions to function-class and sort-class in

function arguments, closes #203
This commit is contained in:
Leonardo de Moura 2014-09-20 09:00:10 -07:00
parent 7262cce37a
commit fd5daa8fda
10 changed files with 213 additions and 93 deletions

View file

@ -21,6 +21,23 @@ coercion_elaborator::coercion_elaborator(coercion_info_manager & info, expr cons
lean_assert(use_id || length(m_coercions) == length(m_choices));
}
list<expr> get_coercions_from_to(type_checker & tc, expr const & from_type, expr const & to_type, constraint_seq & cs) {
constraint_seq new_cs;
expr whnf_to_type = tc.whnf(to_type, new_cs);
expr const & fn = get_app_fn(whnf_to_type);
list<expr> r;
if (is_constant(fn)) {
r = get_coercions(tc.env(), from_type, const_name(fn));
} else if (is_pi(whnf_to_type)) {
r = get_coercions_to_fun(tc.env(), from_type);
} else if (is_sort(whnf_to_type)) {
r = get_coercions_to_sort(tc.env(), from_type);
}
if (r)
cs += new_cs;
return r;
}
optional<constraints> coercion_elaborator::next() {
if (!m_choices)
return optional<constraints>();
@ -65,8 +82,8 @@ constraint mk_coercion_cnstr(type_checker & tc, coercion_info_manager & infom,
new_a_type = tc.whnf(new_a_type, cs);
if (is_meta(d_type)) {
// case-split
buffer<std::tuple<name, expr, expr>> alts;
get_user_coercions(tc.env(), new_a_type, alts);
buffer<std::tuple<coercion_class, expr, expr>> alts;
get_coercions_from(tc.env(), new_a_type, alts);
buffer<constraints> choices;
buffer<expr> coes;
// first alternative: no coercion
@ -86,33 +103,24 @@ constraint mk_coercion_cnstr(type_checker & tc, coercion_info_manager & infom,
to_list(choices.begin(), choices.end()),
to_list(coes.begin(), coes.end())));
} else {
expr new_d_type = tc.whnf(d_type, cs);
expr const & d_cls = get_app_fn(new_d_type);
if (is_constant(d_cls)) {
list<expr> coes = get_coercions(tc.env(), new_a_type, const_name(d_cls));
if (is_nil(coes)) {
expr new_a = a;
infom.erase_coercion_info(a);
cs += mk_eq_cnstr(meta, new_a, new_a_type_jst, relax);
return lazy_list<constraints>(cs.to_list());
} else if (is_nil(tail(coes))) {
expr new_a = copy_tag(a, mk_app(head(coes), a));
infom.save_coercion_info(a, new_a);
cs += mk_eq_cnstr(meta, new_a, new_a_type_jst, relax);
return lazy_list<constraints>(cs.to_list());
} else {
list<constraints> choices = map2<constraints>(coes, [&](expr const & coe) {
expr new_a = copy_tag(a, mk_app(coe, a));
constraint c = mk_eq_cnstr(meta, new_a, new_a_type_jst, relax);
return (cs + c).to_list();
});
return choose(std::make_shared<coercion_elaborator>(infom, meta, choices, coes, false));
}
} else {
list<expr> coes = get_coercions_from_to(tc, new_a_type, d_type, cs);
if (is_nil(coes)) {
expr new_a = a;
infom.erase_coercion_info(a);
cs += mk_eq_cnstr(meta, new_a, new_a_type_jst, relax);
return lazy_list<constraints>(cs.to_list());
} else if (is_nil(tail(coes))) {
expr new_a = copy_tag(a, mk_app(head(coes), a));
infom.save_coercion_info(a, new_a);
cs += mk_eq_cnstr(meta, new_a, new_a_type_jst, relax);
return lazy_list<constraints>(cs.to_list());
} else {
list<constraints> choices = map2<constraints>(coes, [&](expr const & coe) {
expr new_a = copy_tag(a, mk_app(coe, a));
constraint c = mk_eq_cnstr(meta, new_a, new_a_type_jst, relax);
return (cs + c).to_list();
});
return choose(std::make_shared<coercion_elaborator>(infom, meta, choices, coes, false));
}
}
};

View file

@ -51,4 +51,6 @@ pair<expr, constraint> mk_coercion_elaborator(
pair<expr, constraint> coercions_to_choice(coercion_info_manager & infom, local_context & ctx,
list<expr> const & coes, expr const & a,
justification const & j, bool relax);
list<expr> get_coercions_from_to(type_checker & tc, expr const & from_type, expr const & to_type, constraint_seq & cs);
}

View file

@ -434,37 +434,40 @@ public:
return is_constant(a_cls) && ::lean::has_coercions_from(env(), const_name(a_cls));
}
bool has_coercions_to(expr const & d_type) {
expr const & d_cls = get_app_fn(whnf(d_type).first);
return is_constant(d_cls) && ::lean::has_coercions_to(env(), const_name(d_cls));
bool has_coercions_to(expr d_type) {
d_type = whnf(d_type).first;
expr const & fn = get_app_fn(d_type);
if (is_constant(fn))
return ::lean::has_coercions_to(env(), const_name(fn));
else if (is_pi(d_type))
return ::lean::has_coercions_to_fun(env());
else if (is_sort(d_type))
return ::lean::has_coercions_to_sort(env());
else
return false;
}
expr apply_coercion(expr const & a, expr a_type, expr d_type) {
a_type = whnf(a_type).first;
d_type = whnf(d_type).first;
expr const & d_cls = get_app_fn(d_type);
if (is_constant(d_cls)) {
list<expr> coes = get_coercions(env(), a_type, const_name(d_cls));
if (is_nil(coes)) {
erase_coercion_info(a);
return a;
} else if (is_nil(tail(coes))) {
expr r = mk_app(head(coes), a, a.get_tag());
save_coercion_info(a, r);
return r;
} else {
for (expr const & coe : coes) {
expr r = mk_app(coe, a, a.get_tag());
expr r_type = infer_type(r).first;
if (m_tc[m_relax_main_opaque]->is_def_eq(r_type, d_type).first) {
save_coercion_info(a, r);
return r;
}
}
erase_coercion_info(a);
return a;
}
constraint_seq aux_cs;
list<expr> coes = get_coercions_from_to(*m_tc[m_relax_main_opaque], a_type, d_type, aux_cs);
if (is_nil(coes)) {
erase_coercion_info(a);
return a;
} else if (is_nil(tail(coes))) {
expr r = mk_app(head(coes), a, a.get_tag());
save_coercion_info(a, r);
return r;
} else {
for (expr const & coe : coes) {
expr r = mk_app(coe, a, a.get_tag());
expr r_type = infer_type(r).first;
if (m_tc[m_relax_main_opaque]->is_def_eq(r_type, d_type).first) {
save_coercion_info(a, r);
return r;
}
}
erase_coercion_info(a);
return a;
}

View file

@ -16,28 +16,15 @@ Author: Leonardo de Moura
#include "library/scoped_ext.h"
namespace lean {
enum class coercion_class_kind { User, Sort, Fun };
/**
\brief A coercion is a mapping between classes.
We support three kinds of classes: User, Sort, Function.
*/
class coercion_class {
coercion_class_kind m_kind;
name m_name; // relevant only if m_kind == User
coercion_class(coercion_class_kind k, name const & n = name()):m_kind(k), m_name(n) {}
public:
coercion_class():m_kind(coercion_class_kind::Sort) {}
static coercion_class mk_user(name n) { return coercion_class(coercion_class_kind::User, n); }
static coercion_class mk_sort() { return coercion_class(coercion_class_kind::Sort); }
static coercion_class mk_fun() { return coercion_class(coercion_class_kind::Fun); }
friend bool operator==(coercion_class const & c1, coercion_class const & c2) {
return c1.m_kind == c2.m_kind && c1.m_name == c2.m_name;
}
friend bool operator!=(coercion_class const & c1, coercion_class const & c2) { return !(c1 == c2); }
coercion_class_kind kind() const { return m_kind; }
name get_name() const { return m_name; }
};
coercion_class coercion_class::mk_user(name n) { return coercion_class(coercion_class_kind::User, n); }
coercion_class coercion_class::mk_sort() { return coercion_class(coercion_class_kind::Sort); }
coercion_class coercion_class::mk_fun() { return coercion_class(coercion_class_kind::Fun); }
bool operator==(coercion_class const & c1, coercion_class const & c2) {
return c1.m_kind == c2.m_kind && c1.m_name == c2.m_name;
}
bool operator!=(coercion_class const & c1, coercion_class const & c2) {
return !(c1 == c2);
}
std::ostream & operator<<(std::ostream & out, coercion_class const & cls) {
switch (cls.kind()) {
@ -415,6 +402,18 @@ bool has_coercions_to(environment const & env, name const & D) {
return it && !is_nil(*it);
}
bool has_coercions_to_sort(environment const & env) {
coercion_state const & ext = coercion_ext::get_state(env);
auto it = ext.m_to.find(coercion_class::mk_sort());
return it && !is_nil(*it);
}
bool has_coercions_to_fun(environment const & env) {
coercion_state const & ext = coercion_ext::get_state(env);
auto it = ext.m_to.find(coercion_class::mk_fun());
return it && !is_nil(*it);
}
bool has_coercions_from(environment const & env, name const & C) {
coercion_state const & ext = coercion_ext::get_state(env);
return ext.m_coercion_info.contains(C);
@ -465,7 +464,7 @@ list<expr> get_coercions_to_fun(environment const & env, expr const & C) {
return get_coercions(env, C, coercion_class::mk_fun());
}
bool get_user_coercions(environment const & env, expr const & C, buffer<std::tuple<name, expr, expr>> & result) {
bool get_coercions_from(environment const & env, expr const & C, buffer<std::tuple<coercion_class, expr, expr>> & result) {
buffer<expr> args;
expr const & C_fn = get_app_rev_args(C, args);
if (!is_constant(C_fn))
@ -476,15 +475,14 @@ bool get_user_coercions(environment const & env, expr const & C, buffer<std::tup
return false;
bool r = false;
for (coercion_info const & info : *it) {
if (info.m_to.kind() == coercion_class_kind::User &&
info.m_num_args == args.size() &&
if (info.m_num_args == args.size() &&
length(info.m_level_params) == length(const_levels(C_fn))) {
expr f = instantiate_univ_params(info.m_fun, info.m_level_params, const_levels(C_fn));
expr c = apply_beta(f, args.size(), args.data());
expr t = instantiate_univ_params(info.m_fun_type, info.m_level_params, const_levels(C_fn));
for (unsigned i = 0; i < args.size(); i++) t = binding_body(t);
t = instantiate(t, args.size(), args.data());
result.emplace_back(info.m_to.get_name(), c, t);
result.emplace_back(info.m_to, c, t);
r = true;
}
}
@ -568,19 +566,26 @@ static int get_coercions_to_fun(lua_State * L) {
return push_list_expr(L, get_coercions_to_fun(to_environment(L, 1), to_expr(L, 2)));
}
static int get_user_coercions(lua_State * L) {
buffer<std::tuple<name, expr, expr>> r;
get_user_coercions(to_environment(L, 1), to_expr(L, 2), r);
static int get_coercions_from(lua_State * L) {
buffer<std::tuple<coercion_class, expr, expr>> r;
get_coercions_from(to_environment(L, 1), to_expr(L, 2), r);
lua_newtable(L);
int i = 1;
for (auto p : r) {
lua_newtable(L);
push_name(L, std::get<0>(p));
coercion_class c = std::get<0>(p);
push_integer(L, static_cast<unsigned>(c.kind()));
lua_rawseti(L, -2, 1);
push_expr(L, std::get<1>(p));
if (c.kind() == coercion_class_kind::User) {
push_name(L, c.get_name());
} else {
push_nil(L);
}
lua_rawseti(L, -2, 2);
push_expr(L, std::get<2>(p));
push_expr(L, std::get<1>(p));
lua_rawseti(L, -2, 3);
push_expr(L, std::get<2>(p));
lua_rawseti(L, -2, 4);
lua_rawseti(L, -2, i);
i = i + 1;
}
@ -616,10 +621,10 @@ void open_coercion(lua_State * L) {
SET_GLOBAL_FUN(add_coercion, "add_coercion");
SET_GLOBAL_FUN(is_coercion, "is_coercion");
SET_GLOBAL_FUN(has_coercions_from, "has_coercions_from");
SET_GLOBAL_FUN(get_coercions, "get_coercions");
SET_GLOBAL_FUN(get_coercions_to_sort, "get_coercions_to_sort");
SET_GLOBAL_FUN(get_coercions_to_fun, "get_coercions_to_fun");
SET_GLOBAL_FUN(get_user_coercions, "get_user_coercions");
SET_GLOBAL_FUN(get_coercions, "get_coercions");
SET_GLOBAL_FUN(get_coercions_to_sort, "get_coercions_to_sort");
SET_GLOBAL_FUN(get_coercions_to_fun, "get_coercions_to_fun");
SET_GLOBAL_FUN(get_coercions_from, "get_coercions_from");
SET_GLOBAL_FUN(for_each_coercion_user, "for_each_coercion_user");
SET_GLOBAL_FUN(for_each_coercion_sort, "for_each_coercion_sort");
SET_GLOBAL_FUN(for_each_coercion_fun, "for_each_coercion_fun");

View file

@ -12,6 +12,26 @@ Author: Leonardo de Moura
#include "library/io_state.h"
namespace lean {
enum class coercion_class_kind { User, Sort, Fun };
/**
\brief A coercion is a mapping between classes.
We support three kinds of classes: User, Sort, Function.
*/
class coercion_class {
coercion_class_kind m_kind;
name m_name; // relevant only if m_kind == User
coercion_class(coercion_class_kind k, name const & n = name()):m_kind(k), m_name(n) {}
public:
coercion_class():m_kind(coercion_class_kind::Sort) {}
static coercion_class mk_user(name n);
static coercion_class mk_sort();
static coercion_class mk_fun();
friend bool operator==(coercion_class const & c1, coercion_class const & c2);
friend bool operator!=(coercion_class const & c1, coercion_class const & c2);
coercion_class_kind kind() const { return m_kind; }
name get_name() const { return m_name; }
};
/**
\brief Add an new coercion in the given environment.
@ -51,6 +71,8 @@ bool has_coercions_from(environment const & env, name const & C);
bool has_coercions_from(environment const & env, expr const & C);
/** \brief Return true iff the given environment has coercions to a user-class named \c D. */
bool has_coercions_to(environment const & env, name const & D);
bool has_coercions_to_sort(environment const & env);
bool has_coercions_to_fun(environment const & env);
/**
\brief Return a coercion (if it exists) from (C_name.{l1 lk} t_1 ... t_n) to the class named D.
The coercion is a unary function that takes a term of type (C_name.{l1 lk} t_1 ... t_n) and returns
@ -60,13 +82,13 @@ list<expr> get_coercions(environment const & env, expr const & C, name const & D
list<expr> get_coercions_to_sort(environment const & env, expr const & C);
list<expr> get_coercions_to_fun(environment const & env, expr const & C);
/**
\brief Return all user coercions C >-> D for the type C of the form (C_name.{l1 ... lk} t_1 ... t_n)
The result is a pair (user-class D, coercion, coercion type), and is stored in the result buffer \c result.
\brief Return all coercions C >-> D for the type C of the form (C_name.{l1 ... lk} t_1 ... t_n)
The result is a tuple (class D, coercion, coercion type), and is stored in the result buffer \c result.
The Boolean result is true if at least one pair is added to \c result.
\remark The most recent coercions occur first.
*/
bool get_user_coercions(environment const & env, expr const & C, buffer<std::tuple<name, expr, expr>> & result);
bool get_coercions_from(environment const & env, expr const & C, buffer<std::tuple<coercion_class, expr, expr>> & result);
typedef std::function<void(name const &, name const &, expr const &, level_param_names const &, unsigned)> coercion_user_fn;
typedef std::function<void(name const &, expr const &, level_param_names const &, unsigned)> coercion_sort_fn;

30
tests/lean/run/coe13.lean Normal file
View file

@ -0,0 +1,30 @@
import data.nat
open nat
inductive functor (A B : Type) :=
mk : (A → B) → functor A B
definition functor.to_fun [coercion] {A B : Type} (f : functor A B) : A → B :=
functor.rec (λ f, f) f
inductive struct :=
mk : Π (A : Type), (A → A → Prop) → struct
definition struct.to_sort [coercion] (s : struct) : Type :=
struct.rec (λA r, A) s
definition g (f : nat → nat) (a : nat) := f a
variable f : functor nat nat
check g (functor.to_fun f) 0
check g f 0
definition id (A : Type) (a : A) := a
variable S : struct
variable a : S
check id (struct.to_sort S) a
check id S a

30
tests/lean/run/coe14.lean Normal file
View file

@ -0,0 +1,30 @@
import data.nat
open nat
inductive functor (A B : Type) :=
mk : (A → B) → functor A B
definition functor.to_fun [coercion] {A B : Type} (f : functor A B) : A → B :=
functor.rec (λ f, f) f
inductive struct :=
mk : Π (A : Type), (A → A → Prop) → struct
definition struct.to_sort [coercion] (s : struct) : Type :=
struct.rec (λA r, A) s
definition g (f : nat → nat) (a : nat) := f a
variable f : functor nat nat
check g (functor.to_fun f) 0
check g f 0
definition id (A : Type) (a : A) := a
variable S : struct
variable a : S
check id (struct.to_sort S) a
check id S a

20
tests/lean/run/coe15.lean Normal file
View file

@ -0,0 +1,20 @@
import data.nat
open nat
inductive functor (A B : Type) :=
mk : (A → B) → functor A B
definition functor.to_fun [coercion] {A B : Type} (f : functor A B) : A → B :=
functor.rec (λ f, f) f
inductive struct :=
mk : Π (A : Type), (A → A → Prop) → struct
definition struct.to_sort [coercion] (s : struct) : Type :=
struct.rec (λA r, A) s
definition g (f : nat → nat) (a : nat) := f a
check
λ f,
(g f 0) = 0 ∧ (functor.to_fun f) 0 = 0

View file

@ -84,7 +84,7 @@ assert(not has_coercions_from(env2, Const("vec", {1})(nat)))
assert(not has_coercions_from(env2, Const("vec")(nat, one)))
print("Coercions (vec nat one): ")
cs = get_user_coercions(env2, Const("vec", {1})(nat, one))
cs = get_coercions_from(env2, Const("vec", {1})(nat, one))
for i = 1, #cs do
print(tostring(cs[i][1]) .. " : " .. tostring(cs[i][3]) .. " : " .. tostring(cs[i][2]))
print(tostring(cs[i][2]) .. " : " .. tostring(cs[i][4]) .. " : " .. tostring(cs[i][3]))
end

View file

@ -25,7 +25,7 @@ for_each_coercion_user(env, function(C, D, f) print(tostring(C) .. " >-> " .. to
print(get_coercions_to_sort(env, Const("abelian_ring", {1})):head())
assert(env:type_check(get_coercions_to_sort(env, Const("abelian_ring", {1})):head()))
print("Coercions (abelian ring): ")
cs = get_user_coercions(env, ab_ring)
cs = get_coercions_from(env, ab_ring)
for i = 1, #cs do
print(tostring(cs[i][1]) .. " : " .. tostring(cs[i][3]) .. " : " .. tostring(cs[i][2]))
print(tostring(cs[i][2]) .. " : " .. tostring(cs[i][4]) .. " : " .. tostring(cs[i][3]))
end