diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index f128646b6..48ed6bff3 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -349,6 +349,16 @@ expr const & get_app_fn(expr const & e) { return *it; } +unsigned get_app_num_args(expr const & e) { + expr const * it = &e; + unsigned n = 0; + while (is_app(*it)) { + it = &(app_fn(*it)); + n++; + } + return n; +} + static name g_default_var_name("a"); bool is_default_var_name(name const & n) { return n == g_default_var_name; } expr mk_arrow(expr const & t, expr const & e) { return mk_pi(g_default_var_name, t, e); } diff --git a/src/kernel/expr.h b/src/kernel/expr.h index bf0e496a5..4c93438dd 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -604,11 +604,10 @@ expr const & get_app_args(expr const & e, buffer & args); If e is of the form (...(f a1) ... an), then the procedure stores [an, ..., a1] in \c args. */ expr const & get_app_rev_args(expr const & e, buffer & args); -/** - \brief Given of the form (...(f a1) ... an), return \c f. If \c e is not an application, - then return \c e. -*/ +/** \brief Given \c e of the form (...(f a_1) ... a_n), return \c f. If \c e is not an application, then return \c e. */ expr const & get_app_fn(expr const & e); +/** \brief Given \c e of the form (...(f a_1) ... a_n), return \c n. If \c e is not an application, then return 0. */ +unsigned get_app_num_args(expr const & e); /** \brief Return the name of constant, local, metavar */ inline name const & named_expr_name(expr const & e) { return is_constant(e) ? const_name(e) : mlocal_name(e); } // ======================================= diff --git a/src/library/CMakeLists.txt b/src/library/CMakeLists.txt index c97a356f4..dac1f39df 100644 --- a/src/library/CMakeLists.txt +++ b/src/library/CMakeLists.txt @@ -1,7 +1,7 @@ add_library(library deep_copy.cpp expr_lt.cpp io_state.cpp occurs.cpp kernel_bindings.cpp io_state_stream.cpp bin_app.cpp resolve_macro.cpp kernel_serializer.cpp max_sharing.cpp - normalize.cpp shared_environment.cpp module.cpp) + normalize.cpp shared_environment.cpp module.cpp coercion.cpp) # placeholder.cpp fo_unify.cpp hop_match.cpp) target_link_libraries(library ${LEAN_LIBS}) diff --git a/src/library/coercion.cpp b/src/library/coercion.cpp new file mode 100644 index 000000000..f6c1f90de --- /dev/null +++ b/src/library/coercion.cpp @@ -0,0 +1,411 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include +#include +#include "util/rb_map.h" +#include "util/sstream.h" +#include "kernel/instantiate.h" +#include "library/coercion.h" +#include "library/module.h" +#include "library/kernel_serializer.h" + +namespace lean { +enum class coercion_class_kind { User, Sort, Fun }; + +/** + \brief A coercion is a mapping between classes. + We support three kinds of classes: User, Sort, Function. +*/ +class coercion_class { + coercion_class_kind m_kind; + name m_name; // relevant only if m_kind == User + coercion_class(coercion_class_kind k, name const & n = name()):m_kind(k), m_name(n) {} +public: + coercion_class():m_kind(coercion_class_kind::Sort) {} + static coercion_class mk_user(name n) { return coercion_class(coercion_class_kind::User, n); } + static coercion_class mk_sort() { return coercion_class(coercion_class_kind::Sort); } + static coercion_class mk_fun() { return coercion_class(coercion_class_kind::Fun); } + friend bool operator==(coercion_class const & c1, coercion_class const & c2) { return c1.m_kind == c2.m_kind && c1.m_name == c2.m_name; } + friend bool operator!=(coercion_class const & c1, coercion_class const & c2) { return !(c1 == c2); } + coercion_class_kind kind() const { return m_kind; } + name get_name() const { return m_name; } +}; + +std::ostream & operator<<(std::ostream & out, coercion_class const & cls) { + switch (cls.kind()) { + case coercion_class_kind::User: out << cls.get_name(); break; + case coercion_class_kind::Sort: out << "Sort-class"; break; + case coercion_class_kind::Fun: out << "Function-class"; break; + } + return out; +} + +struct coercion_class_cmp_fn { + int operator()(coercion_class const & c1, coercion_class const & c2) const { + if (c1.kind() != c2.kind()) + return c1.kind() < c2.kind() ? -1 : 1; + else + return quick_cmp(c1.get_name(), c2.get_name()); + } +}; + +struct coercion_info { + expr m_fun; + expr m_fun_type; + level_param_names m_level_params; + unsigned m_num_args; + coercion_class m_to; + coercion_info() {} + coercion_info(expr const & f, expr const & f_type, level_param_names const & ls, unsigned num, coercion_class const & cls): + m_fun(f), m_fun_type(f_type), m_level_params(ls), m_num_args(num), m_to(cls) {} +}; + +struct coercion_ext : public environment_extension { + rb_map, name_quick_cmp> m_from; + rb_map, coercion_class_cmp_fn> m_to; + rb_tree m_coercions; +}; + +struct coercion_ext_reg { + unsigned m_ext_id; + coercion_ext_reg() { m_ext_id = environment::register_extension(std::make_shared()); } +}; + +static coercion_ext_reg g_ext; +static coercion_ext const & get_extension(environment const & env) { + return static_cast(env.get_extension(g_ext.m_ext_id)); +} +static environment update(environment const & env, coercion_ext const & ext) { + return env.update(g_ext.m_ext_id, std::make_shared(ext)); +} + +// key used for module serialization +static std::string g_coercion_key = "coerce"; +static void coercion_reader(deserializer & d, module_idx, shared_environment &, + std::function &, + std::function & add_delayed_update) { + name f, C; + d >> f >> C; + add_delayed_update([=](environment const & env, io_state const & ios) -> environment { + return add_coercion(env, f, C, ios); + }); +} +register_module_object_reader_fn g_coercion_reader(g_coercion_key, coercion_reader); + +static void check_pi(name const & f, expr const & t) { + if (!is_pi(t)) + throw exception(sstream() << "invalid coercion, '" << f << "' is not function"); +} + +/** \brief Return true iff args contains Var(0), Var(1), ..., Var(args.size() - 1) */ +static bool check_var_args(buffer const & args) { + for (unsigned i = 0; i < args.size(); i++) { + if (!is_var(args[i]) || var_idx(args[i]) != i) + return false; + } + return true; +} + +/** \brief Return true iff param_id(levels[i]) == level_params[i] */ +static bool check_levels(levels ls, level_param_names ps) { + while (!is_nil(ls) && !is_nil(ps)) { + if (!is_param(head(ls))) + return false; + if (param_id(head(ls)) != head(ps)) + return false; + ls = tail(ls); + ps = tail(ps); + } + return is_nil(ls) && is_nil(ps); +} + +optional type_to_coercion_class(expr const & t) { + if (is_sort(t)) { + return optional(coercion_class::mk_sort()); + } else if (is_pi(t)) { + return optional(coercion_class::mk_fun()); + } else { + expr const & C = get_app_fn(t); + if (is_constant(C)) + return optional(coercion_class::mk_user(const_name(C))); + else + return optional(); + } +} + +static void add_coercion(coercion_ext & ext, name const & C, expr const & f, expr const & f_type, + level_param_names const & ls, unsigned num_args, + coercion_class const & cls, io_state const & ios); + +static void add_coercion_trans(coercion_ext & ext, io_state const & ios, name const & C, + level_param_names const & f_level_params, expr const & f, expr const & f_type, unsigned f_num_args, + level_param_names const & g_level_params, expr g, expr const & g_type, unsigned g_num_args) { + expr t = f_type; + buffer f_arg_names; + buffer f_arg_types; + buffer args; + unsigned i = f_num_args; + while (is_pi(t)) { + f_arg_names.push_back(binding_name(t)); + f_arg_types.push_back(binding_domain(t)); + t = binding_body(t); + i--; + args.push_back(mk_var(i)); + } + expr f_app = apply_beta(f, args.size(), args.data()); + expr D_app = t; + buffer g_args; + expr D_cnst = get_app_args(D_app, g_args); + if (g_args.size() != g_num_args) + return; + g_args.push_back(f_app); + if (length(const_levels(D_cnst)) != length(g_level_params)) + return; + // C >-> D >-> E + g = instantiate_params(g, g_level_params, const_levels(D_cnst)); + expr gf = apply_beta(g, g_args.size(), g_args.data()); + expr gf_type = g_type; + while (is_pi(gf_type)) + gf_type = binding_body(gf_type); + coercion_class new_cls = *type_to_coercion_class(gf_type); + gf_type = instantiate(instantiate_params(gf_type, g_level_params, const_levels(D_cnst)), g_args.size(), g_args.data()); + i = f_arg_types.size(); + while (i > 0) { + --i; + gf = mk_lambda(f_arg_names[i], f_arg_types[i], gf); + gf_type = mk_pi(f_arg_names[i], f_arg_types[i], gf_type); + } + add_coercion(ext, C, gf, gf_type, f_level_params, f_num_args, new_cls, ios); +} + +static void add_coercion_trans_to(coercion_ext & ext, name const & C, expr const & f, expr const & f_type, + level_param_names const & ls, unsigned num_args, + io_state const & ios) { + // apply transitivity using ext.m_to + coercion_class C_cls = coercion_class::mk_user(C); + auto it1 = ext.m_to.find(C_cls); + if (!it1) + return; + for (name const & B : *it1) { + auto it2 = ext.m_from.find(B); + lean_assert(*it2); + for (coercion_info const & info : *it2) { + if (info.m_to == C_cls) { + // B >-> C >-> D + add_coercion_trans(ext, ios, B, + info.m_level_params, info.m_fun, info.m_fun_type, info.m_num_args, + ls, f, f_type, num_args); + break; + } + } + } +} + +static void add_coercion_trans_from(coercion_ext & ext, name const & C, expr const & f, expr const & f_type, + level_param_names const & ls, unsigned num_args, + coercion_class const & cls, io_state const & ios) { + // apply transitivity using ext.m_from + if (cls.kind() != coercion_class_kind::User) + return; // nothing to do Sort and Fun classes are terminal + name const & D = cls.get_name(); + auto it = ext.m_from.find(D); + if (!it) + return; + for (coercion_info const & D_info : *it) { + // C >-> D >-> E + add_coercion_trans(ext, ios, C, + ls, f, f_type, num_args, + D_info.m_level_params, D_info.m_fun, D_info.m_fun_type, D_info.m_num_args); + } +} + +// Add entry (D, C) to ext.m_to +static void update_to(coercion_ext & ext, coercion_class const & D, name const & C) { + auto it = ext.m_to.find(D); + if (it) { + ext.m_to.insert(D, list(C, *it)); + } else { + ext.m_to.insert(D, list(C)); + } +} + +static void add_coercion(coercion_ext & ext, name const & C, expr const & f, expr const & f_type, + level_param_names const & ls, unsigned num_args, + coercion_class const & cls, io_state const & ios) { + if (cls.kind() == coercion_class_kind::User && C == cls.get_name()) + return; + auto it = ext.m_from.find(C); + if (!it) { + list infos(coercion_info(f, f_type, ls, num_args, cls)); + ext.m_from.insert(C, infos); + update_to(ext, cls, C); + } else { + list infos = *it; + bool found = false; + for_each(infos, [&](coercion_info const & info) { + if (info.m_to == cls) + ios.get_diagnostic_channel() << "replacing the coercion from '" << C << "' to '" << cls << "'"; + if (is_constant(info.m_fun)) + ext.m_coercions.erase(const_name(info.m_fun)); + found = true; + }); + if (found) + infos = filter(infos, [&](coercion_info const & info) { return info.m_to != cls; }); + infos = list(coercion_info(f, f_type, ls, num_args, cls), infos); + ext.m_from.insert(C, infos); + if (!found) + update_to(ext, cls, C); + } + if (is_constant(f)) + ext.m_coercions.insert(const_name(f)); +} + +static environment add_coercion(environment env, name const & C, expr const & f, expr const & f_type, + level_param_names const & ls, unsigned num_args, + coercion_class const & cls, io_state const & ios) { + coercion_ext ext = get_extension(env); + add_coercion_trans_to(ext, C, f, f_type, ls, num_args, ios); + add_coercion_trans_from(ext, C, f, f_type, ls, num_args, cls, ios); + add_coercion(ext, C, f, f_type, ls, num_args, cls, ios); + name const & f_name = const_name(f); + env = add(env, g_coercion_key, [=](serializer & s) { + s << f_name << C; + }); + return update(env, ext); +} + +environment add_coercion(environment const & env, name const & f, name const & C, io_state const & ios) { + declaration d = env.get(f); + unsigned num = 0; + buffer args; + expr t = d.get_type(); + check_pi(f, t); + while (true) { + args.clear(); + expr const & C_fn = get_app_rev_args(binding_domain(t), args); + if (is_constant(C_fn) && + const_name(C_fn) == C && + num == args.size() && + check_var_args(args) && + check_levels(const_levels(C_fn), d.get_params())) { + expr fn = mk_constant(f, const_levels(C_fn)); + optional cls = type_to_coercion_class(binding_body(t)); + if (!cls) + throw exception(sstream() << "invalid coercion, '" << f << "' does not have the expected type to be used as a coercion"); + else if (cls->kind() == coercion_class_kind::User && cls->get_name() == C) + throw exception(sstream() << "invalid coercion, '" << f << "' is a coercion from '" << C << "' on itself"); + return add_coercion(env, C, fn, d.get_type(), d.get_params(), num, *cls, ios); + } + t = binding_body(t); + num++; + check_pi(f, t); + } +} + +environment add_coercion(environment const & env, name const & f, io_state const & ios) { + declaration d = env.get(f); + expr t = d.get_type(); + check_pi(f, t); + while (is_pi(binding_body(t))) + t = binding_body(t); + expr C = get_app_fn(binding_domain(t)); + if (!is_constant(C)) + throw exception(sstream() << "invalid coercion, '" << f << "' does not have the expected type to be used as a coercion"); + return add_coercion(env, f, const_name(C), ios); +} + +bool is_coercion(environment const & env, expr const & f) { + if (is_constant(f)) { + coercion_ext const & ext = get_extension(env); + return ext.m_coercions.contains(const_name(f)); + } else { + return false; + } +} + +bool has_coercions_from(environment const & env, name const & C) { + coercion_ext const & ext = get_extension(env); + return ext.m_from.contains(C); +} + +bool has_coercions_from(environment const & env, expr const & C) { + expr const & C_fn = get_app_fn(C); + if (!is_constant(C_fn)) + return false; + coercion_ext const & ext = get_extension(env); + auto it = ext.m_from.find(const_name(C_fn)); + if (!it) + return false; + list const & cs = *it; + return + head(cs).m_num_args == get_app_num_args(C) && + length(head(cs).m_level_params) == length(const_levels(C_fn)); +} + +bool has_coercions_to(environment const & env, name const & D) { + coercion_ext const & ext = get_extension(env); + return ext.m_to.contains(coercion_class::mk_user(D)); +} + +bool has_coercions_to(environment const & env, expr const & D) { + expr const & D_fn = get_app_fn(D); + return is_constant(D_fn) && has_coercions_to(env, const_name(D_fn)); +} + +optional get_coercion(environment const & env, expr const & C, coercion_class const & D) { + buffer args; + expr const & C_fn = get_app_rev_args(C, args); + if (!is_constant(C_fn)) + return none_expr(); + coercion_ext const & ext = get_extension(env); + auto it = ext.m_from.find(const_name(C_fn)); + if (!it) + return none_expr(); + for (coercion_info const & info : *it) { + if (info.m_to == D && info.m_num_args == args.size() && length(info.m_level_params) == length(const_levels(C_fn))) { + expr f = instantiate_params(info.m_fun, info.m_level_params, const_levels(C_fn)); + return some_expr(apply_beta(f, args.size(), args.data())); + } + } + return none_expr(); +} + +optional get_coercion(environment const & env, expr const & C, name const & D) { + return get_coercion(env, C, coercion_class::mk_user(D)); +} + +optional get_coercion_to_sort(environment const & env, expr const & C) { + return get_coercion(env, C, coercion_class::mk_sort()); +} + +optional get_coercion_to_fun(environment const & env, expr const & C) { + return get_coercion(env, C, coercion_class::mk_fun()); +} + +bool get_user_coercions(environment const & env, expr const & C, buffer> & result) { + buffer args; + expr const & C_fn = get_app_rev_args(C, args); + if (!is_constant(C_fn)) + return false; + coercion_ext const & ext = get_extension(env); + auto it = ext.m_from.find(const_name(C_fn)); + if (!it) + return false; + bool r = false; + for (coercion_info const & info : *it) { + if (info.m_to.kind() == coercion_class_kind::User && + info.m_num_args == args.size() && + length(info.m_level_params) == length(const_levels(C_fn))) { + expr f = instantiate_params(info.m_fun, info.m_level_params, const_levels(C_fn)); + expr c = apply_beta(f, args.size(), args.data()); + result.emplace_back(c, info.m_to.get_name()); + r = true; + } + } + return r; +} +} diff --git a/src/library/coercion.h b/src/library/coercion.h index 1f49fd8c6..462baffcc 100644 --- a/src/library/coercion.h +++ b/src/library/coercion.h @@ -5,7 +5,11 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #pragma once +#include #include "kernel/environment.h" +#include "library/expr_pair.h" +#include "library/io_state.h" + namespace lean { /** \brief Add an new coercion in the given environment. @@ -28,11 +32,32 @@ namespace lean { {l1 ... lk} Pi (x1 : A1) ... (xn : An) (y: C.{l1 ... lk} x1 ... xn), Type.{L} \pre \c f is a constant defined in \c env. + + \remark \c ios is used to report warning messages. */ -environment add_coercion(environment const & env, expr const & f); +environment add_coercion(environment const & env, name & f, io_state const & ios); +environment add_coercion(environment const & env, name const & f, name const & C, io_state const & ios); bool is_coercion(environment const & env, expr const & f); -list get_coercions(environment const & env, expr const & from, expr const & from_type); -optional get_coercion(environment const & env, expr const & from, expr const & from_type, expr const & to_type); -optional get_coercion_to_sort(environment const & env, expr const & from, expr const & from_type); -optional get_coercion_to_fun(environment const & env, expr const & from, expr const & from_type); +/** \brief Return true iff the given environment has coercions from a user-class named \c C. */ +bool has_coercions_from(environment const & env, name const & C); +bool has_coercions_from(environment const & env, expr const & C); +/** \brief Return true iff the given environment has coercions to a user-class named \c D. */ +bool has_coercions_to(environment const & env, name const & D); +bool has_coercions_to(environment const & env, expr const & D); +/** + \brief Return a coercion (if it exists) from (C_name.{l1 lk} t_1 ... t_n) to the class named D. + The coercion is a unary function that takes a term of type (C_name.{l1 lk} t_1 ... t_n) and returns + and element of type (D.{L_1 L_o} s_1 ... s_m) +*/ +optional get_coercion(environment const & env, expr const & C, name const & D); +optional get_coercion_to_sort(environment const & env, expr const & C); +optional get_coercion_to_fun(environment const & env, expr const & C); +/** + \brief Return all user coercions C >-> D for the type C of the form (C_name.{l1 ... lk} t_1 ... t_n) + The result is a pair (coercion, user-class D), and is stored in the result buffer \c result. + The Boolean result is true if at least one pair is added to \c result. + + \remark The most recent coercions occur first. +*/ +bool get_user_coercions(environment const & env, expr const & C, buffer> & result); }