refactor(frontends/lean/inductive_cmd): redesign inductive datatype elaboration, use the new elaborator, and use simpler algorithm to infer the resulting universe
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
0adacb5191
commit
193ce35419
8 changed files with 122 additions and 87 deletions
|
@ -54,7 +54,7 @@ theorem or_elim (a b c : Bool) (H1 : a ∨ b) (H2 : a → c) (H3 : b → c) : c
|
|||
:= or_rec H2 H3 H1
|
||||
|
||||
inductive eq {A : Type} (a : A) : A → Bool :=
|
||||
| refl : eq A a a -- TODO: use elaborator in inductive_cmd module, we should not need to type A here
|
||||
| refl : eq a a
|
||||
|
||||
infix `=` 50 := eq
|
||||
|
||||
|
|
|
@ -70,13 +70,17 @@ static level mk_result_level(bool impredicative, buffer<level> const & ls) {
|
|||
level r = ls[0];
|
||||
for (unsigned i = 1; i < ls.size(); i++)
|
||||
r = mk_max(r, ls[i]);
|
||||
return impredicative ? mk_max(r, mk_level_one()) : r;
|
||||
if (is_not_zero(r))
|
||||
return r;
|
||||
else
|
||||
return impredicative ? mk_max(r, mk_level_one()) : r;
|
||||
}
|
||||
}
|
||||
|
||||
static expr update_result_sort(expr const & t, level const & l) {
|
||||
static expr update_result_sort(type_checker & tc, expr t, level const & l) {
|
||||
t = tc.whnf(t);
|
||||
if (is_pi(t)) {
|
||||
return update_binding(t, binding_domain(t), update_result_sort(binding_body(t), l));
|
||||
return update_binding(t, binding_domain(t), update_result_sort(tc, binding_body(t), l));
|
||||
} else if (is_sort(t)) {
|
||||
return update_sort(t, l);
|
||||
} else {
|
||||
|
@ -84,75 +88,110 @@ static expr update_result_sort(expr const & t, level const & l) {
|
|||
}
|
||||
}
|
||||
|
||||
/** \brief Return the universe level of the given inductive datatype declaration. */
|
||||
level get_datatype_result_level(type_checker & tc, inductive_decl const & d) {
|
||||
expr d_t = tc.whnf(inductive_decl_type(d));
|
||||
while (is_pi(d_t)) {
|
||||
d_t = tc.whnf(binding_body(d_t));
|
||||
}
|
||||
if (!is_sort(d_t)) {
|
||||
std::cout << "ERROR: " << inductive_decl_type(d) << "\n";
|
||||
throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', "
|
||||
"resultant type is not a sort");
|
||||
}
|
||||
return sort_level(d_t);
|
||||
}
|
||||
|
||||
/** \brief Return true if \c u occurs in \c l */
|
||||
bool occurs(level const & u, level const & l) {
|
||||
bool found = false;
|
||||
for_each(l, [&](level const & l) {
|
||||
if (found) return false;
|
||||
if (l == u) { found = true; return false; }
|
||||
return true;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
static name g_tmp_prefix = name::mk_internal_unique_name();
|
||||
static void set_result_universes(buffer<inductive_decl> & decls, level_param_names const & lvls, unsigned num_params, parser & p) {
|
||||
if (std::all_of(decls.begin(), decls.end(), [](inductive_decl const & d) {
|
||||
return !has_placeholder(inductive_decl_type(d));
|
||||
}))
|
||||
return; // nothing to be done
|
||||
// We can't infer the type of intro rule arguments because we did declare the inductive datatypes.
|
||||
// So, we use the following trick, we create a "draft" environment where the inductive datatypes
|
||||
// are asserted as variable declarations, and keep doing that until we reach a "fix" point.
|
||||
unsigned num_rounds = 0;
|
||||
while (true) {
|
||||
if (num_rounds > 2*decls.size() + 1) {
|
||||
// TODO(Leo): this is test is a hack to avoid non-termination.
|
||||
// We should use a better termination condition
|
||||
throw exception("failed to compute resultant universe level for inductive datatypes, "
|
||||
"provide explicit universe levels");
|
||||
}
|
||||
num_rounds++;
|
||||
bool progress = false;
|
||||
environment env = p.env();
|
||||
bool impredicative = env.impredicative();
|
||||
// first assert inductive types that do not have placeholders
|
||||
for (auto const & d : decls) {
|
||||
expr type = inductive_decl_type(d);
|
||||
if (!has_placeholder(type))
|
||||
env = env.add(check(env, mk_var_decl(inductive_decl_name(d), lvls, inductive_decl_type(d))));
|
||||
}
|
||||
type_checker tc(env);
|
||||
name_generator ngen(g_tmp_prefix);
|
||||
// try to update resultant universe levels
|
||||
for (auto & d : decls) {
|
||||
expr d_t = inductive_decl_type(d);
|
||||
while (is_pi(d_t)) {
|
||||
d_t = binding_body(d_t);
|
||||
}
|
||||
if (!is_sort(d_t))
|
||||
throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', "
|
||||
"resultant type is not a sort");
|
||||
level r_lvl = sort_level(d_t);
|
||||
if (impredicative && is_zero(r_lvl))
|
||||
continue;
|
||||
buffer<level> lvls;
|
||||
for (intro_rule const & ir : inductive_decl_intros(d)) {
|
||||
expr t = intro_rule_type(ir);
|
||||
unsigned i = 0;
|
||||
while (is_pi(t)) {
|
||||
if (i >= num_params) {
|
||||
try {
|
||||
expr s = tc.ensure_type(binding_domain(t));
|
||||
level lvl = sort_level(s);
|
||||
if (std::find(lvls.begin(), lvls.end(), lvl) == lvls.end())
|
||||
lvls.push_back(lvl);
|
||||
} catch (...) {
|
||||
}
|
||||
}
|
||||
t = instantiate(binding_body(t), mk_local(ngen.next(), binding_name(t), binding_domain(t)));
|
||||
i++;
|
||||
}
|
||||
}
|
||||
level m_lvl = normalize(mk_result_level(impredicative, lvls));
|
||||
if (is_placeholder(r_lvl) || !(is_geq(r_lvl, m_lvl))) {
|
||||
progress = true;
|
||||
// update result level
|
||||
expr new_type = update_result_sort(inductive_decl_type(d), m_lvl);
|
||||
d = inductive_decl(inductive_decl_name(d), new_type, inductive_decl_intros(d));
|
||||
/**
|
||||
\brief Given a type \c t for an introduction rule, store the universe of the types of non-parameters in \c ls.
|
||||
|
||||
\remark aux_u is an temporary universe used for inductive decls. It should be ignored.
|
||||
*/
|
||||
static void accumulate_levels(type_checker & tc, expr t, unsigned num_params, level const & aux_u, buffer<level> & ls) {
|
||||
name_generator ngen(g_tmp_prefix);
|
||||
unsigned i = 0;
|
||||
while (is_pi(t)) {
|
||||
if (i >= num_params) {
|
||||
expr s = tc.ensure_type(binding_domain(t));
|
||||
level l = sort_level(s);
|
||||
if (l == aux_u) {
|
||||
// ignore, this is the auxiliary level
|
||||
} else if (occurs(aux_u, l)) {
|
||||
throw exception("failed to infer inductive datatype resultant universe, provide the universe levels explicitly");
|
||||
} else if (std::find(ls.begin(), ls.end(), l) == ls.end()) {
|
||||
ls.push_back(l);
|
||||
}
|
||||
}
|
||||
if (!progress)
|
||||
break;
|
||||
t = instantiate(binding_body(t), mk_local(ngen.next(), binding_name(t), binding_domain(t)));
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
void throw_all_or_nothing() {
|
||||
throw exception("invalid mutually recursive datatype declaration, "
|
||||
"if the universe level of one type is provided, then all of them should be");
|
||||
}
|
||||
|
||||
static void elaborate_inductive(buffer<inductive_decl> & decls, level_param_names const & lvls, unsigned num_params, parser & p) {
|
||||
// temporary environment used during elaboration
|
||||
environment env = p.env();
|
||||
// add fake universe level
|
||||
name u_name(g_tmp_prefix, "u");
|
||||
env = env.add_universe(u_name);
|
||||
level u = mk_global_univ(u_name);
|
||||
std::unique_ptr<type_checker> tc(new type_checker(env));
|
||||
bool infer_result_universe = false;
|
||||
unsigned first = true;
|
||||
// elaborate inductive datatype types, and declare them in temporary environment.
|
||||
for (inductive_decl & d : decls) {
|
||||
level l = get_datatype_result_level(*tc, d);
|
||||
expr t = inductive_decl_type(d);
|
||||
if (is_placeholder(l)) {
|
||||
if (first)
|
||||
infer_result_universe = true;
|
||||
else if (!infer_result_universe)
|
||||
throw_all_or_nothing();
|
||||
t = update_result_sort(*tc, t, u);
|
||||
} else if (!first && infer_result_universe) {
|
||||
throw_all_or_nothing();
|
||||
}
|
||||
t = p.elaborate(env, t);
|
||||
env = env.add(check(env, mk_var_decl(inductive_decl_name(d), lvls, t)));
|
||||
d = inductive_decl(inductive_decl_name(d), t, inductive_decl_intros(d));
|
||||
first = false;
|
||||
}
|
||||
tc.reset(new type_checker(env));
|
||||
buffer<level> r_lvls; // used for inferring the universe level of resultant datatypes.
|
||||
// elaborate introduction rules using temporary environment
|
||||
for (inductive_decl & d : decls) {
|
||||
buffer<intro_rule> intros;
|
||||
for (intro_rule const & ir : inductive_decl_intros(d)) {
|
||||
expr t = p.elaborate(env, intro_rule_type(ir));
|
||||
if (infer_result_universe)
|
||||
accumulate_levels(*tc, t, num_params, u, r_lvls);
|
||||
intros.push_back(intro_rule(intro_rule_name(ir), t));
|
||||
}
|
||||
d = inductive_decl(inductive_decl_name(d), inductive_decl_type(d), to_list(intros.begin(), intros.end()));
|
||||
}
|
||||
if (infer_result_universe) {
|
||||
level r_lvl = normalize(mk_result_level(env.impredicative(), r_lvls));
|
||||
for (inductive_decl & d : decls) {
|
||||
expr t = inductive_decl_type(d);
|
||||
t = update_result_sort(*tc, t, r_lvl);
|
||||
d = inductive_decl(inductive_decl_name(d), t, inductive_decl_intros(d));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -319,16 +358,7 @@ environment inductive_cmd(parser & p) {
|
|||
num_params += section_params.size();
|
||||
level_param_names ls = to_list(ls_buffer.begin(), ls_buffer.end());
|
||||
|
||||
// Check if introduction rules do not have placeholders
|
||||
for (inductive_decl const & d : decls) {
|
||||
for (auto const & ir : inductive_decl_intros(d)) {
|
||||
if (has_placeholder(intro_rule_type(ir)))
|
||||
throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', "
|
||||
<< "introduction rule '" << intro_rule_name(ir) << "' has placeholders");
|
||||
}
|
||||
}
|
||||
// "Fix" the inductive type resultant type universe level, if it was not explicitly provided.
|
||||
set_result_universes(decls, ls, num_params, p);
|
||||
elaborate_inductive(decls, ls, num_params, p);
|
||||
env = module::add_inductive(env, ls, num_params, to_list(decls.begin(), decls.end()));
|
||||
// Create aliases/local refs
|
||||
levels section_levels = collect_section_levels(ls, p);
|
||||
|
|
|
@ -549,6 +549,10 @@ expr parser::elaborate(expr const & e) {
|
|||
return ::lean::elaborate(m_env, m_ios, e);
|
||||
}
|
||||
|
||||
expr parser::elaborate(environment const & env, expr const & e) {
|
||||
return ::lean::elaborate(env, m_ios, e);
|
||||
}
|
||||
|
||||
std::pair<expr, expr> parser::elaborate(name const & n, expr const & t, expr const & v) {
|
||||
return ::lean::elaborate(m_env, m_ios, n, t, v);
|
||||
}
|
||||
|
|
|
@ -263,6 +263,7 @@ public:
|
|||
struct no_undef_id_error_scope { parser & m_p; bool m_old; no_undef_id_error_scope(parser &); ~no_undef_id_error_scope(); };
|
||||
|
||||
expr elaborate(expr const & e);
|
||||
expr elaborate(environment const & env, expr const & e);
|
||||
std::pair<expr, expr> elaborate(name const & n, expr const & t, expr const & v);
|
||||
|
||||
/** parse all commands in the input stream */
|
||||
|
|
|
@ -81,6 +81,8 @@ public:
|
|||
type_checker(environment const & env);
|
||||
~type_checker();
|
||||
|
||||
environment const & env() const { return m_env; }
|
||||
|
||||
/**
|
||||
\brief Return the type of \c t.
|
||||
|
||||
|
|
|
@ -139,7 +139,7 @@ std::pair<unify_status, substitution> unify_simple(substitution const & s, expr
|
|||
}
|
||||
|
||||
// Return true if m occurs in e
|
||||
bool occurs(level const & m, level const & e) {
|
||||
bool occurs_meta(level const & m, level const & e) {
|
||||
lean_assert(is_meta(m));
|
||||
bool contains = false;
|
||||
for_each(e, [&](level const & l) {
|
||||
|
@ -156,7 +156,7 @@ bool occurs(level const & m, level const & e) {
|
|||
|
||||
std::pair<unify_status, substitution> unify_simple_core(substitution const & s, level const & lhs, level const & rhs, justification const & j) {
|
||||
lean_assert(is_meta(lhs));
|
||||
bool contains = occurs(lhs, rhs);
|
||||
bool contains = occurs_meta(lhs, rhs);
|
||||
if (contains) {
|
||||
if (is_succ(rhs))
|
||||
return mk_pair(unify_status::Failed, s);
|
||||
|
@ -620,7 +620,7 @@ struct unifier_fn {
|
|||
status process_metavar_eq(level const & lhs, level const & rhs, justification const & j) {
|
||||
if (!is_meta(lhs))
|
||||
return Continue;
|
||||
bool contains = occurs(lhs, rhs);
|
||||
bool contains = occurs_meta(lhs, rhs);
|
||||
if (contains) {
|
||||
if (is_succ(rhs))
|
||||
return Failed;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
inductive tree (A : Type) : Type :=
|
||||
| node : A → forest A → tree A
|
||||
with forest {A : Type} : Type :=
|
||||
with forest (A : Type) : Type :=
|
||||
| nil : forest A
|
||||
| cons : tree A → forest A → forest A
|
||||
|
||||
|
@ -17,5 +17,3 @@ inductive group : Type :=
|
|||
check group.{1}
|
||||
check group.{2}
|
||||
check group_rec.{1 1}
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
inductive tree.{u} (A : Type.{u}) : Type.{max u 1} :=
|
||||
| node : A → forest.{u} A → tree.{u} A
|
||||
with forest.{u} {A : Type.{u}} : Type.{max u 1} :=
|
||||
with forest.{u} (A : Type.{u}) : Type.{max u 1} :=
|
||||
| nil : forest.{u} A
|
||||
| cons : tree.{u} A → forest.{u} A → forest.{u} A
|
||||
|
||||
|
|
Loading…
Reference in a new issue