refactor(frontends/lean/inductive_cmd): redesign inductive datatype elaboration, use the new elaborator, and use simpler algorithm to infer the resulting universe
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
0adacb5191
commit
193ce35419
8 changed files with 122 additions and 87 deletions
|
@ -54,7 +54,7 @@ theorem or_elim (a b c : Bool) (H1 : a ∨ b) (H2 : a → c) (H3 : b → c) : c
|
||||||
:= or_rec H2 H3 H1
|
:= or_rec H2 H3 H1
|
||||||
|
|
||||||
inductive eq {A : Type} (a : A) : A → Bool :=
|
inductive eq {A : Type} (a : A) : A → Bool :=
|
||||||
| refl : eq A a a -- TODO: use elaborator in inductive_cmd module, we should not need to type A here
|
| refl : eq a a
|
||||||
|
|
||||||
infix `=` 50 := eq
|
infix `=` 50 := eq
|
||||||
|
|
||||||
|
|
|
@ -70,13 +70,17 @@ static level mk_result_level(bool impredicative, buffer<level> const & ls) {
|
||||||
level r = ls[0];
|
level r = ls[0];
|
||||||
for (unsigned i = 1; i < ls.size(); i++)
|
for (unsigned i = 1; i < ls.size(); i++)
|
||||||
r = mk_max(r, ls[i]);
|
r = mk_max(r, ls[i]);
|
||||||
return impredicative ? mk_max(r, mk_level_one()) : r;
|
if (is_not_zero(r))
|
||||||
|
return r;
|
||||||
|
else
|
||||||
|
return impredicative ? mk_max(r, mk_level_one()) : r;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static expr update_result_sort(expr const & t, level const & l) {
|
static expr update_result_sort(type_checker & tc, expr t, level const & l) {
|
||||||
|
t = tc.whnf(t);
|
||||||
if (is_pi(t)) {
|
if (is_pi(t)) {
|
||||||
return update_binding(t, binding_domain(t), update_result_sort(binding_body(t), l));
|
return update_binding(t, binding_domain(t), update_result_sort(tc, binding_body(t), l));
|
||||||
} else if (is_sort(t)) {
|
} else if (is_sort(t)) {
|
||||||
return update_sort(t, l);
|
return update_sort(t, l);
|
||||||
} else {
|
} else {
|
||||||
|
@ -84,75 +88,110 @@ static expr update_result_sort(expr const & t, level const & l) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** \brief Return the universe level of the given inductive datatype declaration. */
|
||||||
|
level get_datatype_result_level(type_checker & tc, inductive_decl const & d) {
|
||||||
|
expr d_t = tc.whnf(inductive_decl_type(d));
|
||||||
|
while (is_pi(d_t)) {
|
||||||
|
d_t = tc.whnf(binding_body(d_t));
|
||||||
|
}
|
||||||
|
if (!is_sort(d_t)) {
|
||||||
|
std::cout << "ERROR: " << inductive_decl_type(d) << "\n";
|
||||||
|
throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', "
|
||||||
|
"resultant type is not a sort");
|
||||||
|
}
|
||||||
|
return sort_level(d_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** \brief Return true if \c u occurs in \c l */
|
||||||
|
bool occurs(level const & u, level const & l) {
|
||||||
|
bool found = false;
|
||||||
|
for_each(l, [&](level const & l) {
|
||||||
|
if (found) return false;
|
||||||
|
if (l == u) { found = true; return false; }
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
return found;
|
||||||
|
}
|
||||||
|
|
||||||
static name g_tmp_prefix = name::mk_internal_unique_name();
|
static name g_tmp_prefix = name::mk_internal_unique_name();
|
||||||
static void set_result_universes(buffer<inductive_decl> & decls, level_param_names const & lvls, unsigned num_params, parser & p) {
|
/**
|
||||||
if (std::all_of(decls.begin(), decls.end(), [](inductive_decl const & d) {
|
\brief Given a type \c t for an introduction rule, store the universe of the types of non-parameters in \c ls.
|
||||||
return !has_placeholder(inductive_decl_type(d));
|
|
||||||
}))
|
\remark aux_u is an temporary universe used for inductive decls. It should be ignored.
|
||||||
return; // nothing to be done
|
*/
|
||||||
// We can't infer the type of intro rule arguments because we did declare the inductive datatypes.
|
static void accumulate_levels(type_checker & tc, expr t, unsigned num_params, level const & aux_u, buffer<level> & ls) {
|
||||||
// So, we use the following trick, we create a "draft" environment where the inductive datatypes
|
name_generator ngen(g_tmp_prefix);
|
||||||
// are asserted as variable declarations, and keep doing that until we reach a "fix" point.
|
unsigned i = 0;
|
||||||
unsigned num_rounds = 0;
|
while (is_pi(t)) {
|
||||||
while (true) {
|
if (i >= num_params) {
|
||||||
if (num_rounds > 2*decls.size() + 1) {
|
expr s = tc.ensure_type(binding_domain(t));
|
||||||
// TODO(Leo): this is test is a hack to avoid non-termination.
|
level l = sort_level(s);
|
||||||
// We should use a better termination condition
|
if (l == aux_u) {
|
||||||
throw exception("failed to compute resultant universe level for inductive datatypes, "
|
// ignore, this is the auxiliary level
|
||||||
"provide explicit universe levels");
|
} else if (occurs(aux_u, l)) {
|
||||||
}
|
throw exception("failed to infer inductive datatype resultant universe, provide the universe levels explicitly");
|
||||||
num_rounds++;
|
} else if (std::find(ls.begin(), ls.end(), l) == ls.end()) {
|
||||||
bool progress = false;
|
ls.push_back(l);
|
||||||
environment env = p.env();
|
|
||||||
bool impredicative = env.impredicative();
|
|
||||||
// first assert inductive types that do not have placeholders
|
|
||||||
for (auto const & d : decls) {
|
|
||||||
expr type = inductive_decl_type(d);
|
|
||||||
if (!has_placeholder(type))
|
|
||||||
env = env.add(check(env, mk_var_decl(inductive_decl_name(d), lvls, inductive_decl_type(d))));
|
|
||||||
}
|
|
||||||
type_checker tc(env);
|
|
||||||
name_generator ngen(g_tmp_prefix);
|
|
||||||
// try to update resultant universe levels
|
|
||||||
for (auto & d : decls) {
|
|
||||||
expr d_t = inductive_decl_type(d);
|
|
||||||
while (is_pi(d_t)) {
|
|
||||||
d_t = binding_body(d_t);
|
|
||||||
}
|
|
||||||
if (!is_sort(d_t))
|
|
||||||
throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', "
|
|
||||||
"resultant type is not a sort");
|
|
||||||
level r_lvl = sort_level(d_t);
|
|
||||||
if (impredicative && is_zero(r_lvl))
|
|
||||||
continue;
|
|
||||||
buffer<level> lvls;
|
|
||||||
for (intro_rule const & ir : inductive_decl_intros(d)) {
|
|
||||||
expr t = intro_rule_type(ir);
|
|
||||||
unsigned i = 0;
|
|
||||||
while (is_pi(t)) {
|
|
||||||
if (i >= num_params) {
|
|
||||||
try {
|
|
||||||
expr s = tc.ensure_type(binding_domain(t));
|
|
||||||
level lvl = sort_level(s);
|
|
||||||
if (std::find(lvls.begin(), lvls.end(), lvl) == lvls.end())
|
|
||||||
lvls.push_back(lvl);
|
|
||||||
} catch (...) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t = instantiate(binding_body(t), mk_local(ngen.next(), binding_name(t), binding_domain(t)));
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
level m_lvl = normalize(mk_result_level(impredicative, lvls));
|
|
||||||
if (is_placeholder(r_lvl) || !(is_geq(r_lvl, m_lvl))) {
|
|
||||||
progress = true;
|
|
||||||
// update result level
|
|
||||||
expr new_type = update_result_sort(inductive_decl_type(d), m_lvl);
|
|
||||||
d = inductive_decl(inductive_decl_name(d), new_type, inductive_decl_intros(d));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!progress)
|
t = instantiate(binding_body(t), mk_local(ngen.next(), binding_name(t), binding_domain(t)));
|
||||||
break;
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void throw_all_or_nothing() {
|
||||||
|
throw exception("invalid mutually recursive datatype declaration, "
|
||||||
|
"if the universe level of one type is provided, then all of them should be");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void elaborate_inductive(buffer<inductive_decl> & decls, level_param_names const & lvls, unsigned num_params, parser & p) {
|
||||||
|
// temporary environment used during elaboration
|
||||||
|
environment env = p.env();
|
||||||
|
// add fake universe level
|
||||||
|
name u_name(g_tmp_prefix, "u");
|
||||||
|
env = env.add_universe(u_name);
|
||||||
|
level u = mk_global_univ(u_name);
|
||||||
|
std::unique_ptr<type_checker> tc(new type_checker(env));
|
||||||
|
bool infer_result_universe = false;
|
||||||
|
unsigned first = true;
|
||||||
|
// elaborate inductive datatype types, and declare them in temporary environment.
|
||||||
|
for (inductive_decl & d : decls) {
|
||||||
|
level l = get_datatype_result_level(*tc, d);
|
||||||
|
expr t = inductive_decl_type(d);
|
||||||
|
if (is_placeholder(l)) {
|
||||||
|
if (first)
|
||||||
|
infer_result_universe = true;
|
||||||
|
else if (!infer_result_universe)
|
||||||
|
throw_all_or_nothing();
|
||||||
|
t = update_result_sort(*tc, t, u);
|
||||||
|
} else if (!first && infer_result_universe) {
|
||||||
|
throw_all_or_nothing();
|
||||||
|
}
|
||||||
|
t = p.elaborate(env, t);
|
||||||
|
env = env.add(check(env, mk_var_decl(inductive_decl_name(d), lvls, t)));
|
||||||
|
d = inductive_decl(inductive_decl_name(d), t, inductive_decl_intros(d));
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
tc.reset(new type_checker(env));
|
||||||
|
buffer<level> r_lvls; // used for inferring the universe level of resultant datatypes.
|
||||||
|
// elaborate introduction rules using temporary environment
|
||||||
|
for (inductive_decl & d : decls) {
|
||||||
|
buffer<intro_rule> intros;
|
||||||
|
for (intro_rule const & ir : inductive_decl_intros(d)) {
|
||||||
|
expr t = p.elaborate(env, intro_rule_type(ir));
|
||||||
|
if (infer_result_universe)
|
||||||
|
accumulate_levels(*tc, t, num_params, u, r_lvls);
|
||||||
|
intros.push_back(intro_rule(intro_rule_name(ir), t));
|
||||||
|
}
|
||||||
|
d = inductive_decl(inductive_decl_name(d), inductive_decl_type(d), to_list(intros.begin(), intros.end()));
|
||||||
|
}
|
||||||
|
if (infer_result_universe) {
|
||||||
|
level r_lvl = normalize(mk_result_level(env.impredicative(), r_lvls));
|
||||||
|
for (inductive_decl & d : decls) {
|
||||||
|
expr t = inductive_decl_type(d);
|
||||||
|
t = update_result_sort(*tc, t, r_lvl);
|
||||||
|
d = inductive_decl(inductive_decl_name(d), t, inductive_decl_intros(d));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -319,16 +358,7 @@ environment inductive_cmd(parser & p) {
|
||||||
num_params += section_params.size();
|
num_params += section_params.size();
|
||||||
level_param_names ls = to_list(ls_buffer.begin(), ls_buffer.end());
|
level_param_names ls = to_list(ls_buffer.begin(), ls_buffer.end());
|
||||||
|
|
||||||
// Check if introduction rules do not have placeholders
|
elaborate_inductive(decls, ls, num_params, p);
|
||||||
for (inductive_decl const & d : decls) {
|
|
||||||
for (auto const & ir : inductive_decl_intros(d)) {
|
|
||||||
if (has_placeholder(intro_rule_type(ir)))
|
|
||||||
throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', "
|
|
||||||
<< "introduction rule '" << intro_rule_name(ir) << "' has placeholders");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// "Fix" the inductive type resultant type universe level, if it was not explicitly provided.
|
|
||||||
set_result_universes(decls, ls, num_params, p);
|
|
||||||
env = module::add_inductive(env, ls, num_params, to_list(decls.begin(), decls.end()));
|
env = module::add_inductive(env, ls, num_params, to_list(decls.begin(), decls.end()));
|
||||||
// Create aliases/local refs
|
// Create aliases/local refs
|
||||||
levels section_levels = collect_section_levels(ls, p);
|
levels section_levels = collect_section_levels(ls, p);
|
||||||
|
|
|
@ -549,6 +549,10 @@ expr parser::elaborate(expr const & e) {
|
||||||
return ::lean::elaborate(m_env, m_ios, e);
|
return ::lean::elaborate(m_env, m_ios, e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expr parser::elaborate(environment const & env, expr const & e) {
|
||||||
|
return ::lean::elaborate(env, m_ios, e);
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<expr, expr> parser::elaborate(name const & n, expr const & t, expr const & v) {
|
std::pair<expr, expr> parser::elaborate(name const & n, expr const & t, expr const & v) {
|
||||||
return ::lean::elaborate(m_env, m_ios, n, t, v);
|
return ::lean::elaborate(m_env, m_ios, n, t, v);
|
||||||
}
|
}
|
||||||
|
|
|
@ -263,6 +263,7 @@ public:
|
||||||
struct no_undef_id_error_scope { parser & m_p; bool m_old; no_undef_id_error_scope(parser &); ~no_undef_id_error_scope(); };
|
struct no_undef_id_error_scope { parser & m_p; bool m_old; no_undef_id_error_scope(parser &); ~no_undef_id_error_scope(); };
|
||||||
|
|
||||||
expr elaborate(expr const & e);
|
expr elaborate(expr const & e);
|
||||||
|
expr elaborate(environment const & env, expr const & e);
|
||||||
std::pair<expr, expr> elaborate(name const & n, expr const & t, expr const & v);
|
std::pair<expr, expr> elaborate(name const & n, expr const & t, expr const & v);
|
||||||
|
|
||||||
/** parse all commands in the input stream */
|
/** parse all commands in the input stream */
|
||||||
|
|
|
@ -81,6 +81,8 @@ public:
|
||||||
type_checker(environment const & env);
|
type_checker(environment const & env);
|
||||||
~type_checker();
|
~type_checker();
|
||||||
|
|
||||||
|
environment const & env() const { return m_env; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
\brief Return the type of \c t.
|
\brief Return the type of \c t.
|
||||||
|
|
||||||
|
|
|
@ -139,7 +139,7 @@ std::pair<unify_status, substitution> unify_simple(substitution const & s, expr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return true if m occurs in e
|
// Return true if m occurs in e
|
||||||
bool occurs(level const & m, level const & e) {
|
bool occurs_meta(level const & m, level const & e) {
|
||||||
lean_assert(is_meta(m));
|
lean_assert(is_meta(m));
|
||||||
bool contains = false;
|
bool contains = false;
|
||||||
for_each(e, [&](level const & l) {
|
for_each(e, [&](level const & l) {
|
||||||
|
@ -156,7 +156,7 @@ bool occurs(level const & m, level const & e) {
|
||||||
|
|
||||||
std::pair<unify_status, substitution> unify_simple_core(substitution const & s, level const & lhs, level const & rhs, justification const & j) {
|
std::pair<unify_status, substitution> unify_simple_core(substitution const & s, level const & lhs, level const & rhs, justification const & j) {
|
||||||
lean_assert(is_meta(lhs));
|
lean_assert(is_meta(lhs));
|
||||||
bool contains = occurs(lhs, rhs);
|
bool contains = occurs_meta(lhs, rhs);
|
||||||
if (contains) {
|
if (contains) {
|
||||||
if (is_succ(rhs))
|
if (is_succ(rhs))
|
||||||
return mk_pair(unify_status::Failed, s);
|
return mk_pair(unify_status::Failed, s);
|
||||||
|
@ -620,7 +620,7 @@ struct unifier_fn {
|
||||||
status process_metavar_eq(level const & lhs, level const & rhs, justification const & j) {
|
status process_metavar_eq(level const & lhs, level const & rhs, justification const & j) {
|
||||||
if (!is_meta(lhs))
|
if (!is_meta(lhs))
|
||||||
return Continue;
|
return Continue;
|
||||||
bool contains = occurs(lhs, rhs);
|
bool contains = occurs_meta(lhs, rhs);
|
||||||
if (contains) {
|
if (contains) {
|
||||||
if (is_succ(rhs))
|
if (is_succ(rhs))
|
||||||
return Failed;
|
return Failed;
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
inductive tree (A : Type) : Type :=
|
inductive tree (A : Type) : Type :=
|
||||||
| node : A → forest A → tree A
|
| node : A → forest A → tree A
|
||||||
with forest {A : Type} : Type :=
|
with forest (A : Type) : Type :=
|
||||||
| nil : forest A
|
| nil : forest A
|
||||||
| cons : tree A → forest A → forest A
|
| cons : tree A → forest A → forest A
|
||||||
|
|
||||||
|
@ -17,5 +17,3 @@ inductive group : Type :=
|
||||||
check group.{1}
|
check group.{1}
|
||||||
check group.{2}
|
check group.{2}
|
||||||
check group_rec.{1 1}
|
check group_rec.{1 1}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
inductive tree.{u} (A : Type.{u}) : Type.{max u 1} :=
|
inductive tree.{u} (A : Type.{u}) : Type.{max u 1} :=
|
||||||
| node : A → forest.{u} A → tree.{u} A
|
| node : A → forest.{u} A → tree.{u} A
|
||||||
with forest.{u} {A : Type.{u}} : Type.{max u 1} :=
|
with forest.{u} (A : Type.{u}) : Type.{max u 1} :=
|
||||||
| nil : forest.{u} A
|
| nil : forest.{u} A
|
||||||
| cons : tree.{u} A → forest.{u} A → forest.{u} A
|
| cons : tree.{u} A → forest.{u} A → forest.{u} A
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue