feat(frontends/lean): elaborate inductive datatypes and introduction rules as a single elaboration problem

This commit is contained in:
Leonardo de Moura 2014-10-13 18:03:45 -07:00
parent 2431de542f
commit 9edf780a00
3 changed files with 278 additions and 220 deletions

View file

@ -162,7 +162,7 @@ congruence.mk (λx y H, H)
-- ---------------------------------------------------------
inductive mp_like [class] {R : Type → Type → Prop} {a b : Type} (H : R a b) : Type :=
mk {} : (a → b) → @mp_like R a b H
mk {} : (a → b) → mp_like H
namespace mp_like

View file

@ -18,7 +18,7 @@ rfl
theorem cast_eq {A : Type} (H : A = A) (a : A) : cast H a = a :=
rfl
inductive heq.{l} {A : Type.{l}} (a : A) : Π {B : Type.{l}}, B → Prop :=
inductive heq {A : Type} (a : A) : Π {B : Type}, B → Prop :=
refl : heq a a
infixl `==`:50 := heq

View file

@ -15,6 +15,7 @@ Author: Leonardo de Moura
#include "kernel/free_vars.h"
#include "library/scoped_ext.h"
#include "library/locals.h"
#include "library/deep_copy.h"
#include "library/placeholder.h"
#include "library/aliases.h"
#include "library/protected.h"
@ -94,8 +95,9 @@ struct inductive_cmd_fn {
bool m_first; // true if parsing the first inductive type in a mutually recursive inductive decl.
buffer<name> m_explicit_levels;
buffer<name> m_levels;
bool m_using_explicit_levels; // true if the user is providing explicit universe levels
buffer<expr> m_params; // parameters
unsigned m_num_params; // number of parameters
bool m_using_explicit_levels; // true if the user is providing explicit universe levels
level m_u; // temporary auxiliary global universe used for inferring the result universe of
// an inductive datatype declaration.
bool m_infer_result_universe;
@ -233,30 +235,6 @@ struct inductive_cmd_fn {
return sort_level(d_type);
}
/** \brief Update the result sort of the given type */
expr update_result_sort(expr t, level const & l) {
t = m_tc->whnf(t).first;
if (is_pi(t)) {
return update_binding(t, binding_domain(t), update_result_sort(binding_body(t), l));
} else if (is_sort(t)) {
return update_sort(t, l);
} else {
lean_unreachable();
}
}
/** \brief Elaborate the type of an inductive declaration. */
std::tuple<expr, level_param_names> elaborate_inductive_type(expr d_type) {
level l = get_datatype_result_level(d_type);
if (is_placeholder(l)) {
if (m_using_explicit_levels)
throw_error("resultant universe must be provided, when using explicit universe levels");
d_type = update_result_sort(d_type, m_u);
m_infer_result_universe = true;
}
return m_p.elaborate_at(m_env, d_type);
}
/** \brief Create a local constant based on the given binding */
expr mk_local_for(expr const & b) {
return mk_local(m_p.mk_fresh_name(), binding_name(b), binding_domain(b), binding_info(b));
@ -292,25 +270,36 @@ struct inductive_cmd_fn {
"provide all parameters explicitly to fix the problem");
}
/** \brief Add the parameters of the inductive decl type to the local scoped.
This method is executed before parsing introduction rules.
*/
void add_params_to_local_scope(expr d_type, buffer<expr> & params) {
/** \brief Set explicit datatype parameters as local constants in m_params */
void update_params(expr d_type) {
// Remark: if m_params is not empty, then this function will reuse their names.
// Reason: reference to these names
lean_assert(m_num_params >= m_params.size());
buffer<name> param_names;
for (unsigned i = 0; i < m_num_params - m_params.size(); i++)
param_names.push_back(m_p.mk_fresh_name());
for (expr const & param : m_params)
param_names.push_back(mlocal_name(param)); // keep existing internal names
m_params.clear();
for (unsigned i = 0; i < m_num_params; i++) {
expr l = mk_local(m_p.mk_fresh_name(), binding_name(d_type), mk_as_is(binding_domain(d_type)),
binding_info(d_type));
m_p.add_local(l);
params.push_back(l);
expr l = mk_local(param_names[i], binding_name(d_type), binding_domain(d_type), binding_info(d_type));
m_params.push_back(l);
d_type = instantiate(binding_body(d_type), l);
}
}
/** \brief Parse introduction rules in the scope of the given parameters.
/** \brief Add the parameters (in m_params) to parser local scope */
void add_params_to_local_scope() {
for (expr const & l : m_params)
m_p.add_local(l);
}
/** \brief Parse introduction rules in the scope of m_params.
Introduction rules with the annotation '{}' are marked for relaxed (aka non-strict) implicit parameter inference.
Introduction rules with the annotation '()' are marked for no implicit parameter inference.
*/
list<intro_rule> parse_intro_rules(name const & ind_name, buffer<expr> & params) {
list<intro_rule> parse_intro_rules(name const & ind_name) {
buffer<intro_rule> intros;
while (true) {
name intro_name = parse_intro_decl_name(ind_name);
@ -325,8 +314,7 @@ struct inductive_cmd_fn {
m_implicit_infer_map.insert(intro_name, implicit_infer_kind::None);
}
m_p.check_token_next(get_colon_tk(), "invalid introduction rule, ':' expected");
expr intro_type = m_p.parse_scoped_expr(params, m_env);
intro_type = Pi(params, intro_type, m_p);
expr intro_type = m_p.parse_expr();
intros.push_back(intro_rule(intro_name, intro_type));
if (!m_p.curr_is_token(get_comma_tk()))
break;
@ -337,7 +325,6 @@ struct inductive_cmd_fn {
void parse_inductive_decls(buffer<inductive_decl> & decls) {
optional<expr> first_d_type;
optional<level_param_names> first_d_lvls;
while (true) {
parser::local_scope l_scope(m_p);
pair<name, name> d_names = parse_inductive_decl_name();
@ -353,16 +340,12 @@ struct inductive_cmd_fn {
empty_type = false;
m_p.next();
}
level_param_names d_lvls;
std::tie(d_type, d_lvls) = elaborate_inductive_type(d_type);
if (m_first) {
m_levels.append(m_explicit_levels);
for (auto l : d_lvls) m_levels.push_back(l);
update_params(d_type);
} else {
lean_assert(first_d_type);
lean_assert(first_d_lvls);
check_params(d_type, *first_d_type);
check_levels(d_lvls, *first_d_lvls);
}
if (empty_type) {
decls.push_back(inductive_decl(d_name, d_type, list<intro_rule>()));
@ -370,9 +353,8 @@ struct inductive_cmd_fn {
expr d_const = mk_constant(d_name, param_names_to_levels(to_list(m_explicit_levels.begin(),
m_explicit_levels.end())));
m_p.add_local_expr(d_short_name, d_const);
buffer<expr> params;
add_params_to_local_scope(d_type, params);
auto d_intro_rules = parse_intro_rules(d_name, params);
add_params_to_local_scope();
auto d_intro_rules = parse_intro_rules(d_name);
decls.push_back(inductive_decl(d_name, d_type, d_intro_rules));
}
if (!m_p.curr_is_token(get_with_tk())) {
@ -381,15 +363,16 @@ struct inductive_cmd_fn {
m_p.next();
m_first = false;
first_d_type = d_type;
first_d_lvls = d_lvls;
}
}
/** \brief Include in m_levels any local level referenced by decls. */
void include_local_levels(buffer<inductive_decl> const & decls) {
void include_local_levels(buffer<inductive_decl> const & decls, buffer<expr> const & locals) {
if (!m_p.has_locals())
return;
name_set all_lvl_params;
for (auto const & local : locals) {
all_lvl_params = collect_univ_params(mlocal_type(local), all_lvl_params);
}
for (auto const & d : decls) {
all_lvl_params = collect_univ_params(inductive_decl_type(d), all_lvl_params);
for (auto const & ir : inductive_decl_intros(d)) {
@ -411,8 +394,8 @@ struct inductive_cmd_fn {
m_levels.append(new_levels);
}
/** \brief Collect local constants used in the inductive decls */
void collect_locals(buffer<inductive_decl> const & decls, expr_struct_set & ls) {
/** \brief Collect local constants used in the inductive decls. */
void collect_locals_core(buffer<inductive_decl> const & decls, expr_struct_set & ls) {
buffer<expr> include_vars;
m_p.get_include_variables(include_vars);
for (expr const & param : include_vars) {
@ -421,78 +404,112 @@ struct inductive_cmd_fn {
}
for (auto const & d : decls) {
::lean::collect_locals(inductive_decl_type(d), ls);
for (auto const & ir : inductive_decl_intros(d))
::lean::collect_locals(intro_rule_type(ir), ls);
for (auto const & ir : inductive_decl_intros(d)) {
expr ir_type = intro_rule_type(ir);
ir_type = Pi(m_params, ir_type);
::lean::collect_locals(ir_type, ls);
}
}
}
/** \brief Make sure that every occurrence of an inductive datatype (in decls) in \c type has
locals as arguments.
*/
expr fix_inductive_occs(expr const & type, buffer<inductive_decl> const & decls, buffer<expr> const & locals) {
if (locals.empty())
return type;
return replace(type, [&](expr const & e) {
if (!is_constant(e))
return none_expr();
if (!std::any_of(decls.begin(), decls.end(),
[&](inductive_decl const & d) { return const_name(e) == inductive_decl_name(d); }))
return none_expr();
// found target
expr r = mk_as_atomic(mk_app(mk_explicit(e), locals));
return some_expr(r);
});
}
/** \brief Include the used locals as additional arguments.
The locals are stored in \c locals
*/
void abstract_locals(buffer<inductive_decl> & decls, buffer<expr> & locals) {
/** \brief Collect local constants used in the declaration as extra parameters, and
update inductive datatype types with them. */
void collect_locals(buffer<inductive_decl> & decls, buffer<expr> & locals) {
if (!m_p.has_locals())
return;
expr_struct_set local_set;
collect_locals(decls, local_set);
collect_locals_core(decls, local_set);
if (local_set.empty())
return;
sort_locals(local_set, m_p, locals);
// First, add locals to inductive types type.
for (inductive_decl & d : decls) {
d = update_inductive_decl(d, Pi(locals, inductive_decl_type(d), m_p));
m_num_params += locals.size();
}
// Add locals to introduction rules type, and also "fix"
// occurrences of inductive types.
for (inductive_decl & d : decls) {
buffer<intro_rule> new_irs;
for (auto const & ir : inductive_decl_intros(d)) {
expr type = intro_rule_type(ir);
type = fix_inductive_occs(type, decls, locals);
type = Pi_as_is(locals, type, m_p);
new_irs.push_back(update_intro_rule(ir, type));
}
d = update_inductive_decl(d, new_irs);
/** \brief Update the result sort of the given type */
expr update_result_sort(expr t, level const & l) {
t = m_tc->whnf(t).first;
if (is_pi(t)) {
return update_binding(t, binding_domain(t), update_result_sort(binding_body(t), l));
} else if (is_sort(t)) {
return update_sort(t, l);
} else {
lean_unreachable();
}
}
/** \brief Declare inductive types in the scratch environment as var_decls.
We use this trick to be able to elaborate the introduction rules.
*/
void declare_inductive_types(buffer<inductive_decl> & decls) {
level_param_names ls = to_list(m_levels.begin(), m_levels.end());
for (auto const & d : decls) {
name d_name; expr d_type;
std::tie(d_name, d_type, std::ignore) = d;
m_env = m_env.add(check(m_env, mk_constant_assumption(d_name, ls, d_type)));
/** \brief Convert inductive datatype declarations into local constants, and store them into \c r and \c map.
\c map is a mapping from inductive datatype name into local constant. */
void inductive_types_to_locals(buffer<inductive_decl> const & decls, buffer<expr> & r, name_map<expr> & map) {
for (inductive_decl const & decl : decls) {
name const & n = inductive_decl_name(decl);
expr type = inductive_decl_type(decl);
for (unsigned i = 0; i < m_params.size(); i++) {
lean_assert(is_pi(type));
type = binding_body(type);
}
type = instantiate_rev(type, m_params.size(), m_params.data());
level l = get_datatype_result_level(type);
if (is_placeholder(l)) {
if (m_using_explicit_levels)
throw_error("resultant universe must be provided, when using explicit universe levels");
type = update_result_sort(type, m_u);
m_infer_result_universe = true;
}
expr local = mk_local(m_p.mk_fresh_name(), n, type, binder_info());
r.push_back(local);
map.insert(n, local);
}
m_tc = mk_type_checker(m_env, m_p.mk_ngen(), false);
}
/** \brief Traverse the introduction rule type and collect the universes where non-parameters reside in \c r_lvls.
// TODO(Leo): move to different file
static bool is_explicit(binder_info const & bi) {
return !bi.is_implicit() && !bi.is_strict_implicit() && !bi.is_inst_implicit();
}
/** \brief Replace every occurrences of the inductive datatypes (in decls) in \c type with a local constant */
expr fix_intro_rule_type(expr const & type, name_map<expr> const & ind_to_local) {
unsigned nparams = 0; // number of explicit parameters
for (expr const & param : m_params) {
if (is_explicit(local_info(param)))
nparams++;
}
return replace(type, [&](expr const & e) {
expr const & fn = get_app_fn(e);
if (!is_constant(fn))
return none_expr();
if (auto it = ind_to_local.find(const_name(fn))) {
buffer<expr> args;
get_app_args(e, args);
if (args.size() < nparams)
throw parser_error(sstream() << "invalide datatype declaration, "
<< "incorrect number of arguments to datatype '"
<< const_name(fn) << "'", m_p.pos_of(e));
pos_info pos = m_p.pos_of(e);
expr r = m_p.save_pos(copy(*it), pos);
for (unsigned i = nparams; i < args.size(); i++)
r = m_p.mk_app(r, args[i], pos);
return some_expr(r);
} else {
return none_expr();
}
});
}
void intro_rules_to_locals(buffer<inductive_decl> const & decls, name_map<expr> const & ind_to_local, buffer<expr> & r) {
for (inductive_decl const & decl : decls) {
for (intro_rule const & rule : inductive_decl_intros(decl)) {
expr type = fix_intro_rule_type(intro_rule_type(rule), ind_to_local);
expr local = mk_local(m_p.mk_fresh_name(), intro_rule_name(rule), type, binder_info());
r.push_back(local);
}
}
}
/** \brief Traverse the introduction rule type and collect the universes where arguments reside in \c r_lvls.
This information is used to compute the resultant universe level for the inductive datatype declaration.
*/
void accumulate_levels(expr intro_type, buffer<level> & r_lvls) {
unsigned i = 0;
while (is_pi(intro_type)) {
if (i >= m_num_params) {
expr s = m_tc->ensure_type(binding_domain(intro_type)).first;
level l = sort_level(s);
if (l == m_u) {
@ -503,85 +520,135 @@ struct inductive_cmd_fn {
} else if (std::find(r_lvls.begin(), r_lvls.end(), l) == r_lvls.end()) {
r_lvls.push_back(l);
}
}
intro_type = instantiate(binding_body(intro_type), mk_local_for(intro_type));
i++;
}
}
/** \brief Elaborate introduction rules and destructively update \c decls with the elaborated versions.
\remark This method is invoked only after all inductive datatype types have been elaborated and
inserted into the scratch environment m_env.
This method also store in r_lvls inferred levels that must be in the resultant universe.
/** \brief Given a sequence of introduction rules (encoded as local constants), compute the resultant
universe for the inductive datatype declaration.
*/
void elaborate_intro_rules(buffer<inductive_decl> & decls, buffer<level> & r_lvls) {
for (auto & d : decls) {
name d_name; expr d_type; list<intro_rule> d_intros;
std::tie(d_name, d_type, d_intros) = d;
buffer<intro_rule> new_irs;
for (auto const & ir : d_intros) {
name ir_name; expr ir_type;
std::tie(ir_name, ir_type) = ir;
level_param_names new_ls;
std::tie(ir_type, new_ls) = m_p.elaborate_at(m_env, ir_type);
for (auto l : new_ls) m_levels.push_back(l);
accumulate_levels(ir_type, r_lvls);
level infer_resultant_universe(unsigned num_intro_rules, expr const * intro_rules) {
lean_assert(m_infer_result_universe);
buffer<level> r_lvls;
for (unsigned i = 0; i < num_intro_rules; i++) {
accumulate_levels(mlocal_type(intro_rules[i]), r_lvls);
}
return mk_result_level(m_env, r_lvls);
}
/** \brief Create a mapping from inductive datatype temporary name (used in local constants) to an
application <tt>C.{ls} locals params</tt>, where \c C is the real name of the inductive datatype,
and \c ls are the universe level parameters in \c m_levels.
*/
name_map<expr> locals_to_inductive_types(buffer<expr> const & locals, unsigned nparams, expr const * params,
unsigned num_decls, expr const * decls) {
buffer<level> buffer_ls;
for (name const & l : m_levels) {
buffer_ls.push_back(mk_param_univ(l));
}
levels ls = to_list(buffer_ls.begin(), buffer_ls.end());
name_map<expr> r;
for (unsigned i = 0; i < num_decls; i++) {
expr c = mk_constant(local_pp_name(decls[i]), ls);
c = mk_app(c, locals);
c = mk_app(c, nparams, params);
r.insert(mlocal_name(decls[i]), c);
}
return r;
}
/** \brief Create the "final" introduction rule type. It will apply the mapping \c local_to_ind built using
locals_to_inductive_types, and abstract locals and parameters.
*/
expr mk_intro_rule_type(name const & ir_name,
buffer<expr> const & locals, unsigned nparams, expr const * params,
name_map<expr> const & local_to_ind, expr type) {
type = replace(type, [&](expr const & e) {
if (!is_local(e)) {
return none_expr();
} else if (auto it = local_to_ind.find(mlocal_name(e))) {
return some_expr(*it);
} else {
return none_expr();
}
});
type = Pi(nparams, params, type);
type = Pi(locals, type);
implicit_infer_kind k = get_implicit_infer_kind(ir_name);
switch (k) {
case implicit_infer_kind::Implicit: {
bool strict = true;
ir_type = infer_implicit(ir_type, m_num_params, strict);
break;
return infer_implicit(type, locals.size() + nparams, strict);
}
case implicit_infer_kind::RelaxedImplicit: {
bool strict = false;
ir_type = infer_implicit(ir_type, m_num_params, strict);
break;
return infer_implicit(type, locals.size() + nparams, strict);
}
case implicit_infer_kind::None:
break;
}
new_irs.push_back(intro_rule(ir_name, ir_type));
}
d = inductive_decl(d_name, d_type, to_list(new_irs.begin(), new_irs.end()));
return type;
}
lean_unreachable();
}
/** \brief If old_num_univ_params < m_levels.size(), then new universe params were collected when elaborating
the introduction rules. This method include them in all occurrences of the inductive datatype decls.
*/
void include_extra_univ_levels(buffer<inductive_decl> & decls, unsigned old_num_univ_params) {
if (m_levels.size() == old_num_univ_params)
return;
buffer<level> tmp;
for (auto l : m_levels) tmp.push_back(mk_param_univ(l));
levels new_ls = to_list(tmp.begin(), tmp.end());
for (auto & d : decls) {
/** \brief Elaborate inductive datatypes and their introduction rules. */
void elaborate_decls(buffer<inductive_decl> & decls, buffer<expr> const & locals) {
// We create an elaboration problem of the form
// Pi (params) (inductive_types) (intro_rules), Type
buffer<expr> to_elab;
to_elab.append(m_params);
name_map<expr> ind_to_local;
inductive_types_to_locals(decls, to_elab, ind_to_local);
intro_rules_to_locals(decls, ind_to_local, to_elab);
expr aux_type = Pi(to_elab, mk_Type(), m_p);
list<expr> locals_ctx;
for (expr const & local : locals)
locals_ctx = cons(local, locals_ctx);
level_param_names new_ls;
std::tie(aux_type, new_ls) = m_p.elaborate_type(aux_type, locals_ctx);
// save new levels
for (auto l : new_ls)
m_levels.push_back(l);
// update to_elab
for (expr & l : to_elab) {
l = update_mlocal(l, binding_domain(aux_type));
aux_type = instantiate(binding_body(aux_type), l);
}
unsigned nparams = m_params.size();
unsigned num_decls = decls.size();
unsigned first_intro_idx = nparams + num_decls;
lean_assert(first_intro_idx <= to_elab.size());
// compute resultant level
level resultant_level;
if (m_infer_result_universe) {
unsigned num_intros = to_elab.size() - first_intro_idx;
resultant_level = infer_resultant_universe(num_intros, to_elab.data() + first_intro_idx);
}
// update decls
unsigned i = nparams;
for (inductive_decl & decl : decls) {
expr type = mlocal_type(to_elab[i]);
if (m_infer_result_universe)
type = update_result_sort(type, resultant_level);
type = Pi(nparams, to_elab.data(), type);
type = Pi(locals, type);
decl = update_inductive_decl(decl, type);
i++;
}
// Create mapping for converting occurrences of inductive types (as local constants)
// into the real ones.
name_map<expr> local_to_ind = locals_to_inductive_types(locals,
nparams, to_elab.data(),
num_decls, to_elab.data() + nparams);
i = nparams + num_decls;
for (inductive_decl & decl : decls) {
buffer<intro_rule> new_irs;
for (auto & ir : inductive_decl_intros(d)) {
expr new_type = replace(intro_rule_type(ir), [&](expr const & e) {
if (!is_constant(e))
return none_expr();
if (!std::any_of(decls.begin(), decls.end(),
[&](inductive_decl const & d) { return const_name(e) == inductive_decl_name(d); }))
return none_expr();
// found target
expr r = update_constant(e, new_ls);
return some_expr(r);
});
new_irs.push_back(update_intro_rule(ir, new_type));
for (intro_rule const & ir : inductive_decl_intros(decl)) {
expr type = mlocal_type(to_elab[i]);
type = mk_intro_rule_type(intro_rule_name(ir), locals, nparams, to_elab.data(), local_to_ind, type);
new_irs.push_back(update_intro_rule(ir, type));
i++;
}
d = update_inductive_decl(d, new_irs);
}
}
/** \brief Update the resultant universe level of the inductive datatypes using the inferred universe \c r_lvl */
void update_resultant_universe(buffer<inductive_decl> & decls, level const & r_lvl) {
for (inductive_decl & d : decls) {
expr t = inductive_decl_type(d);
t = update_result_sort(t, r_lvl);
d = update_inductive_decl(d, t);
decl = update_inductive_decl(decl, new_irs);
}
}
@ -624,6 +691,23 @@ struct inductive_cmd_fn {
return env;
}
void update_declaration_index(environment const & env) {
name n, k; pos_info p;
for (auto const & info : m_decl_info) {
std::tie(n, k, p) = info;
expr type = env.get(n).get_type();
m_p.add_decl_index(n, p, k, type);
}
}
environment apply_modifiers(environment env) {
m_modifiers.for_each([&](name const & n, modifiers const & m) {
if (m.m_is_class)
env = add_class(env, n);
});
return env;
}
/** \brief Auxiliary method used for debugging */
void display(std::ostream & out, buffer<inductive_decl> const & decls) {
if (!m_levels.empty()) {
@ -644,23 +728,6 @@ struct inductive_cmd_fn {
out << "\n";
}
void update_declaration_index(environment const & env) {
name n, k; pos_info p;
for (auto const & info : m_decl_info) {
std::tie(n, k, p) = info;
expr type = env.get(n).get_type();
m_p.add_decl_index(n, p, k, type);
}
}
environment apply_modifiers(environment env) {
m_modifiers.for_each([&](name const & n, modifiers const & m) {
if (m.m_is_class)
env = add_class(env, n);
});
return env;
}
environment operator()() {
parser::no_undef_id_error_scope err_scope(m_p);
buffer<inductive_decl> decls;
@ -669,18 +736,9 @@ struct inductive_cmd_fn {
parse_inductive_decls(decls);
}
buffer<expr> locals;
abstract_locals(decls, locals);
include_local_levels(decls);
m_num_params += locals.size();
declare_inductive_types(decls);
unsigned num_univ_params = m_levels.size();
buffer<level> r_lvls;
elaborate_intro_rules(decls, r_lvls);
include_extra_univ_levels(decls, num_univ_params);
if (m_infer_result_universe) {
level r_lvl = mk_result_level(m_env, r_lvls);
update_resultant_universe(decls, r_lvl);
}
collect_locals(decls, locals);
include_local_levels(decls, locals);
elaborate_decls(decls, locals);
level_param_names ls = to_list(m_levels.begin(), m_levels.end());
environment env = module::add_inductive(m_p.env(), ls, m_num_params, to_list(decls.begin(), decls.end()));
update_declaration_index(env);