refactor(frontends/lean/inductive_cmd): redesign inductive datatype elaboration, use the new elaborator, and use simpler algorithm to infer the resulting universe

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-28 15:33:56 -07:00
parent 0adacb5191
commit 193ce35419
8 changed files with 122 additions and 87 deletions

View file

@ -54,7 +54,7 @@ theorem or_elim (a b c : Bool) (H1 : a b) (H2 : a → c) (H3 : b → c) : c
:= or_rec H2 H3 H1 := or_rec H2 H3 H1
inductive eq {A : Type} (a : A) : A → Bool := inductive eq {A : Type} (a : A) : A → Bool :=
| refl : eq A a a -- TODO: use elaborator in inductive_cmd module, we should not need to type A here | refl : eq a a
infix `=` 50 := eq infix `=` 50 := eq

View file

@ -70,13 +70,17 @@ static level mk_result_level(bool impredicative, buffer<level> const & ls) {
level r = ls[0]; level r = ls[0];
for (unsigned i = 1; i < ls.size(); i++) for (unsigned i = 1; i < ls.size(); i++)
r = mk_max(r, ls[i]); r = mk_max(r, ls[i]);
return impredicative ? mk_max(r, mk_level_one()) : r; if (is_not_zero(r))
return r;
else
return impredicative ? mk_max(r, mk_level_one()) : r;
} }
} }
static expr update_result_sort(expr const & t, level const & l) { static expr update_result_sort(type_checker & tc, expr t, level const & l) {
t = tc.whnf(t);
if (is_pi(t)) { if (is_pi(t)) {
return update_binding(t, binding_domain(t), update_result_sort(binding_body(t), l)); return update_binding(t, binding_domain(t), update_result_sort(tc, binding_body(t), l));
} else if (is_sort(t)) { } else if (is_sort(t)) {
return update_sort(t, l); return update_sort(t, l);
} else { } else {
@ -84,75 +88,110 @@ static expr update_result_sort(expr const & t, level const & l) {
} }
} }
/** \brief Return the universe level of the given inductive datatype declaration. */
level get_datatype_result_level(type_checker & tc, inductive_decl const & d) {
expr d_t = tc.whnf(inductive_decl_type(d));
while (is_pi(d_t)) {
d_t = tc.whnf(binding_body(d_t));
}
if (!is_sort(d_t)) {
std::cout << "ERROR: " << inductive_decl_type(d) << "\n";
throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', "
"resultant type is not a sort");
}
return sort_level(d_t);
}
/** \brief Return true if \c u occurs in \c l */
bool occurs(level const & u, level const & l) {
bool found = false;
for_each(l, [&](level const & l) {
if (found) return false;
if (l == u) { found = true; return false; }
return true;
});
return found;
}
static name g_tmp_prefix = name::mk_internal_unique_name(); static name g_tmp_prefix = name::mk_internal_unique_name();
static void set_result_universes(buffer<inductive_decl> & decls, level_param_names const & lvls, unsigned num_params, parser & p) { /**
if (std::all_of(decls.begin(), decls.end(), [](inductive_decl const & d) { \brief Given a type \c t for an introduction rule, store the universe of the types of non-parameters in \c ls.
return !has_placeholder(inductive_decl_type(d));
})) \remark aux_u is an temporary universe used for inductive decls. It should be ignored.
return; // nothing to be done */
// We can't infer the type of intro rule arguments because we did declare the inductive datatypes. static void accumulate_levels(type_checker & tc, expr t, unsigned num_params, level const & aux_u, buffer<level> & ls) {
// So, we use the following trick, we create a "draft" environment where the inductive datatypes name_generator ngen(g_tmp_prefix);
// are asserted as variable declarations, and keep doing that until we reach a "fix" point. unsigned i = 0;
unsigned num_rounds = 0; while (is_pi(t)) {
while (true) { if (i >= num_params) {
if (num_rounds > 2*decls.size() + 1) { expr s = tc.ensure_type(binding_domain(t));
// TODO(Leo): this is test is a hack to avoid non-termination. level l = sort_level(s);
// We should use a better termination condition if (l == aux_u) {
throw exception("failed to compute resultant universe level for inductive datatypes, " // ignore, this is the auxiliary level
"provide explicit universe levels"); } else if (occurs(aux_u, l)) {
} throw exception("failed to infer inductive datatype resultant universe, provide the universe levels explicitly");
num_rounds++; } else if (std::find(ls.begin(), ls.end(), l) == ls.end()) {
bool progress = false; ls.push_back(l);
environment env = p.env();
bool impredicative = env.impredicative();
// first assert inductive types that do not have placeholders
for (auto const & d : decls) {
expr type = inductive_decl_type(d);
if (!has_placeholder(type))
env = env.add(check(env, mk_var_decl(inductive_decl_name(d), lvls, inductive_decl_type(d))));
}
type_checker tc(env);
name_generator ngen(g_tmp_prefix);
// try to update resultant universe levels
for (auto & d : decls) {
expr d_t = inductive_decl_type(d);
while (is_pi(d_t)) {
d_t = binding_body(d_t);
}
if (!is_sort(d_t))
throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', "
"resultant type is not a sort");
level r_lvl = sort_level(d_t);
if (impredicative && is_zero(r_lvl))
continue;
buffer<level> lvls;
for (intro_rule const & ir : inductive_decl_intros(d)) {
expr t = intro_rule_type(ir);
unsigned i = 0;
while (is_pi(t)) {
if (i >= num_params) {
try {
expr s = tc.ensure_type(binding_domain(t));
level lvl = sort_level(s);
if (std::find(lvls.begin(), lvls.end(), lvl) == lvls.end())
lvls.push_back(lvl);
} catch (...) {
}
}
t = instantiate(binding_body(t), mk_local(ngen.next(), binding_name(t), binding_domain(t)));
i++;
}
}
level m_lvl = normalize(mk_result_level(impredicative, lvls));
if (is_placeholder(r_lvl) || !(is_geq(r_lvl, m_lvl))) {
progress = true;
// update result level
expr new_type = update_result_sort(inductive_decl_type(d), m_lvl);
d = inductive_decl(inductive_decl_name(d), new_type, inductive_decl_intros(d));
} }
} }
if (!progress) t = instantiate(binding_body(t), mk_local(ngen.next(), binding_name(t), binding_domain(t)));
break; i++;
}
}
void throw_all_or_nothing() {
throw exception("invalid mutually recursive datatype declaration, "
"if the universe level of one type is provided, then all of them should be");
}
static void elaborate_inductive(buffer<inductive_decl> & decls, level_param_names const & lvls, unsigned num_params, parser & p) {
// temporary environment used during elaboration
environment env = p.env();
// add fake universe level
name u_name(g_tmp_prefix, "u");
env = env.add_universe(u_name);
level u = mk_global_univ(u_name);
std::unique_ptr<type_checker> tc(new type_checker(env));
bool infer_result_universe = false;
unsigned first = true;
// elaborate inductive datatype types, and declare them in temporary environment.
for (inductive_decl & d : decls) {
level l = get_datatype_result_level(*tc, d);
expr t = inductive_decl_type(d);
if (is_placeholder(l)) {
if (first)
infer_result_universe = true;
else if (!infer_result_universe)
throw_all_or_nothing();
t = update_result_sort(*tc, t, u);
} else if (!first && infer_result_universe) {
throw_all_or_nothing();
}
t = p.elaborate(env, t);
env = env.add(check(env, mk_var_decl(inductive_decl_name(d), lvls, t)));
d = inductive_decl(inductive_decl_name(d), t, inductive_decl_intros(d));
first = false;
}
tc.reset(new type_checker(env));
buffer<level> r_lvls; // used for inferring the universe level of resultant datatypes.
// elaborate introduction rules using temporary environment
for (inductive_decl & d : decls) {
buffer<intro_rule> intros;
for (intro_rule const & ir : inductive_decl_intros(d)) {
expr t = p.elaborate(env, intro_rule_type(ir));
if (infer_result_universe)
accumulate_levels(*tc, t, num_params, u, r_lvls);
intros.push_back(intro_rule(intro_rule_name(ir), t));
}
d = inductive_decl(inductive_decl_name(d), inductive_decl_type(d), to_list(intros.begin(), intros.end()));
}
if (infer_result_universe) {
level r_lvl = normalize(mk_result_level(env.impredicative(), r_lvls));
for (inductive_decl & d : decls) {
expr t = inductive_decl_type(d);
t = update_result_sort(*tc, t, r_lvl);
d = inductive_decl(inductive_decl_name(d), t, inductive_decl_intros(d));
}
} }
} }
@ -319,16 +358,7 @@ environment inductive_cmd(parser & p) {
num_params += section_params.size(); num_params += section_params.size();
level_param_names ls = to_list(ls_buffer.begin(), ls_buffer.end()); level_param_names ls = to_list(ls_buffer.begin(), ls_buffer.end());
// Check if introduction rules do not have placeholders elaborate_inductive(decls, ls, num_params, p);
for (inductive_decl const & d : decls) {
for (auto const & ir : inductive_decl_intros(d)) {
if (has_placeholder(intro_rule_type(ir)))
throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', "
<< "introduction rule '" << intro_rule_name(ir) << "' has placeholders");
}
}
// "Fix" the inductive type resultant type universe level, if it was not explicitly provided.
set_result_universes(decls, ls, num_params, p);
env = module::add_inductive(env, ls, num_params, to_list(decls.begin(), decls.end())); env = module::add_inductive(env, ls, num_params, to_list(decls.begin(), decls.end()));
// Create aliases/local refs // Create aliases/local refs
levels section_levels = collect_section_levels(ls, p); levels section_levels = collect_section_levels(ls, p);

View file

@ -549,6 +549,10 @@ expr parser::elaborate(expr const & e) {
return ::lean::elaborate(m_env, m_ios, e); return ::lean::elaborate(m_env, m_ios, e);
} }
expr parser::elaborate(environment const & env, expr const & e) {
return ::lean::elaborate(env, m_ios, e);
}
std::pair<expr, expr> parser::elaborate(name const & n, expr const & t, expr const & v) { std::pair<expr, expr> parser::elaborate(name const & n, expr const & t, expr const & v) {
return ::lean::elaborate(m_env, m_ios, n, t, v); return ::lean::elaborate(m_env, m_ios, n, t, v);
} }

View file

@ -263,6 +263,7 @@ public:
struct no_undef_id_error_scope { parser & m_p; bool m_old; no_undef_id_error_scope(parser &); ~no_undef_id_error_scope(); }; struct no_undef_id_error_scope { parser & m_p; bool m_old; no_undef_id_error_scope(parser &); ~no_undef_id_error_scope(); };
expr elaborate(expr const & e); expr elaborate(expr const & e);
expr elaborate(environment const & env, expr const & e);
std::pair<expr, expr> elaborate(name const & n, expr const & t, expr const & v); std::pair<expr, expr> elaborate(name const & n, expr const & t, expr const & v);
/** parse all commands in the input stream */ /** parse all commands in the input stream */

View file

@ -81,6 +81,8 @@ public:
type_checker(environment const & env); type_checker(environment const & env);
~type_checker(); ~type_checker();
environment const & env() const { return m_env; }
/** /**
\brief Return the type of \c t. \brief Return the type of \c t.

View file

@ -139,7 +139,7 @@ std::pair<unify_status, substitution> unify_simple(substitution const & s, expr
} }
// Return true if m occurs in e // Return true if m occurs in e
bool occurs(level const & m, level const & e) { bool occurs_meta(level const & m, level const & e) {
lean_assert(is_meta(m)); lean_assert(is_meta(m));
bool contains = false; bool contains = false;
for_each(e, [&](level const & l) { for_each(e, [&](level const & l) {
@ -156,7 +156,7 @@ bool occurs(level const & m, level const & e) {
std::pair<unify_status, substitution> unify_simple_core(substitution const & s, level const & lhs, level const & rhs, justification const & j) { std::pair<unify_status, substitution> unify_simple_core(substitution const & s, level const & lhs, level const & rhs, justification const & j) {
lean_assert(is_meta(lhs)); lean_assert(is_meta(lhs));
bool contains = occurs(lhs, rhs); bool contains = occurs_meta(lhs, rhs);
if (contains) { if (contains) {
if (is_succ(rhs)) if (is_succ(rhs))
return mk_pair(unify_status::Failed, s); return mk_pair(unify_status::Failed, s);
@ -620,7 +620,7 @@ struct unifier_fn {
status process_metavar_eq(level const & lhs, level const & rhs, justification const & j) { status process_metavar_eq(level const & lhs, level const & rhs, justification const & j) {
if (!is_meta(lhs)) if (!is_meta(lhs))
return Continue; return Continue;
bool contains = occurs(lhs, rhs); bool contains = occurs_meta(lhs, rhs);
if (contains) { if (contains) {
if (is_succ(rhs)) if (is_succ(rhs))
return Failed; return Failed;

View file

@ -1,6 +1,6 @@
inductive tree (A : Type) : Type := inductive tree (A : Type) : Type :=
| node : A → forest A → tree A | node : A → forest A → tree A
with forest {A : Type} : Type := with forest (A : Type) : Type :=
| nil : forest A | nil : forest A
| cons : tree A → forest A → forest A | cons : tree A → forest A → forest A
@ -17,5 +17,3 @@ inductive group : Type :=
check group.{1} check group.{1}
check group.{2} check group.{2}
check group_rec.{1 1} check group_rec.{1 1}

View file

@ -1,6 +1,6 @@
inductive tree.{u} (A : Type.{u}) : Type.{max u 1} := inductive tree.{u} (A : Type.{u}) : Type.{max u 1} :=
| node : A → forest.{u} A → tree.{u} A | node : A → forest.{u} A → tree.{u} A
with forest.{u} {A : Type.{u}} : Type.{max u 1} := with forest.{u} (A : Type.{u}) : Type.{max u 1} :=
| nil : forest.{u} A | nil : forest.{u} A
| cons : tree.{u} A → forest.{u} A → forest.{u} A | cons : tree.{u} A → forest.{u} A → forest.{u} A