feat(library/definitional/equations): mutually recursive functions for mutually recursive datatypes

This commit is contained in:
Leonardo de Moura 2015-01-06 14:07:17 -08:00
parent fb1cb3c623
commit a3a6697f44
4 changed files with 154 additions and 22 deletions

View file

@ -387,7 +387,7 @@ static void throw_invalid_equation_lhs(name const & n, pos_info const & p) {
<< n << "' in the left-hand-side does not correspond to function(s) being defined", p);
}
expr parse_equations(parser & p, name const & n, expr const & type, buffer<expr> & auxs,
expr parse_equations(parser & p, name const & n, expr const & type, buffer<name> & auxs,
optional<local_environment> const & lenv, buffer<expr> const & ps) {
buffer<expr> eqns;
buffer<expr> fns;
@ -404,7 +404,7 @@ expr parse_equations(parser & p, name const & n, expr const & type, buffer<expr>
p.check_token_next(get_colon_tk(), "invalid declaration, ':' expected");
expr g_type = p.parse_expr();
expr g = mk_local(g_name, g_type);
auxs.push_back(g);
auxs.push_back(g_name);
fns.push_back(g);
}
}
@ -475,8 +475,9 @@ class definition_cmd_fn {
decl_modifiers m_modifiers;
name m_real_name; // real name for this declaration
buffer<name> m_ls_buffer;
buffer<expr> m_aux_decls;
buffer<name> m_aux_decls; // user provided names for aux_decls
buffer<name> m_real_aux_names; // real names for aux_decls
buffer<expr> m_aux_types; // types of auxiliary decls
expr m_type;
expr m_value;
level_param_names m_ls;
@ -564,16 +565,16 @@ class definition_cmd_fn {
auto env_n = add_private_name(m_env, m_name, optional<unsigned>(h));
m_env = env_n.first;
m_real_name = env_n.second;
for (expr const & aux : m_aux_decls) {
auto env_n = add_private_name(m_env, local_pp_name(aux), optional<unsigned>(h));
for (name const & aux : m_aux_decls) {
auto env_n = add_private_name(m_env, aux, optional<unsigned>(h));
m_env = env_n.first;
m_real_aux_names.push_back(env_n.second);
}
} else {
name const & ns = get_namespace(m_env);
m_real_name = ns + m_name;
for (expr const & aux : m_aux_decls)
m_real_aux_names.push_back(ns + local_pp_name(aux));
for (name const & aux : m_aux_decls)
m_real_aux_names.push_back(ns + aux);
}
}
@ -646,30 +647,95 @@ class definition_cmd_fn {
return false;
}
void register_decl() {
void register_decl(name const & n, name const & real_n, expr const & type) {
if (m_kind != Example) {
// TODO(Leo): register aux_decls
if (!m_is_private)
m_p.add_decl_index(m_real_name, m_pos, m_p.get_cmd_token(), m_type);
if (m_real_name != m_name)
m_env = add_expr_alias_rec(m_env, m_name, m_real_name);
m_p.add_decl_index(real_n, m_pos, m_p.get_cmd_token(), type);
if (n != real_n)
m_env = add_expr_alias_rec(m_env, n, real_n);
if (m_modifiers.m_is_instance) {
bool persistent = true;
if (m_modifiers.m_priority) {
#if defined(__GNUC__) && !defined(__CLANG__)
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
m_env = add_instance(m_env, m_real_name, *m_modifiers.m_priority, persistent);
m_env = add_instance(m_env, real_n, *m_modifiers.m_priority, persistent);
} else {
m_env = add_instance(m_env, m_real_name, persistent);
m_env = add_instance(m_env, real_n, persistent);
}
}
if (m_modifiers.m_is_coercion)
m_env = add_coercion(m_env, m_real_name, m_p.ios());
m_env = add_coercion(m_env, real_n, m_p.ios());
if (m_is_protected)
m_env = add_protected(m_env, m_real_name);
m_env = add_protected(m_env, real_n);
if (m_modifiers.m_is_reducible)
m_env = set_reducible(m_env, m_real_name, reducible_status::On);
m_env = set_reducible(m_env, real_n, reducible_status::On);
}
}
void register_decl() {
register_decl(m_name, m_real_name, m_type);
for (unsigned i = 0; i < m_aux_decls.size(); i++) {
register_decl(m_aux_decls[i], m_real_aux_names[i], m_aux_types[i]);
}
}
// When compiling mutually recursive equations or equations based on well-found recursion,
// the equations compiler produces a result that combines different definitions.
// We say the result is "tangled". This method untangles it.
// The tangled result is of the form
// Fun (a : A), [equations_result main-value aux-type-1 aux-value-1 ... aux-type-i aux-value-i]
//
// The result is the updated value. The auxiliary definitions (type and value) are stored at m_aux_types and aux_values
expr untangle_definitions(expr const & type, expr const & value, buffer<expr> & aux_values) {
if (is_lambda(value)) {
lean_assert(is_pi(type));
expr r = untangle_definitions(binding_body(type), binding_body(value), aux_values);
lean_assert(aux_values.size() == m_aux_types.size());
for (unsigned i = 0; i < aux_values.size(); i++) {
m_aux_types[i] = mk_pi(binding_name(type), binding_domain(type), m_aux_types[i], binding_info(type));
aux_values[i] = mk_lambda(binding_name(value), binding_domain(value), aux_values[i], binding_info(value));
}
return update_binding(value, binding_domain(value), r);
} else if (is_equations_result(value)) {
lean_assert(get_equations_result_size(value) > 1);
lean_assert(get_equations_result_size(value) % 2 == 1);
for (unsigned i = 1; i < get_equations_result_size(value); i+=2) {
m_aux_types.push_back(get_equations_result(value, i));
aux_values.push_back(get_equations_result(value, i+1));
}
return get_equations_result(value, 0);
} else {
throw exception("invalid declaration, unexpected result produced by equations compiler");
}
}
// Elaborate definitions that contain auxiliary ones nested inside.
// Remark: we do not cache this kind of definition.
// This method will also initialize m_aux_types
void elaborate_multi() {
lean_assert(!m_aux_decls.empty());
level_param_names new_ls;
std::tie(m_type, m_value, new_ls) = m_p.elaborate_definition(m_name, m_type, m_value, m_is_opaque);
new_ls = append(m_ls, new_ls);
lean_assert(m_aux_types.empty());
buffer<expr> aux_values;
m_value = untangle_definitions(m_type, m_value, aux_values);
lean_assert(aux_values.size() == m_aux_types.size());
if (aux_values.size() != m_real_aux_names.size())
throw exception("invalid declaration, failed to compile auxiliary declarations");
if (is_definition()) {
m_env = module::add(m_env, check(m_env, mk_definition(m_env, m_real_name, new_ls,
m_type, m_value, m_is_opaque)));
for (unsigned i = 0; i < aux_values.size(); i++)
m_env = module::add(m_env, check(m_env, mk_definition(m_env, m_real_aux_names[i], new_ls,
m_aux_types[i], aux_values[i], m_is_opaque)));
} else {
m_env = module::add(m_env, check(m_env, mk_theorem(m_real_name, new_ls, m_type, m_value)));
for (unsigned i = 0; i < aux_values.size(); i++)
m_env = module::add(m_env, check(m_env, mk_theorem(m_real_aux_names[i], new_ls,
m_aux_types[i], aux_values[i])));
}
}
@ -681,11 +747,7 @@ class definition_cmd_fn {
m_p.remove_proof_state_info(m_pos, m_p.pos());
if (!m_aux_decls.empty()) {
// TODO(Leo): split equations_result
std::tie(m_type, m_value, new_ls) = m_p.elaborate_definition(m_name, m_type, m_value, m_is_opaque);
new_ls = append(m_ls, new_ls);
m_env = module::add(m_env, check(m_env, mk_definition(m_env, m_real_name, new_ls,
m_type, m_value, m_is_opaque)));
// Remark: we do not cache mutually recursive declarations.
elaborate_multi();
} else if (!is_definition()) {
// Theorems and Examples
auto type_pos = m_p.pos_of(m_type);

View file

@ -1553,6 +1553,8 @@ class equation_compiler_fn {
move_params(prgs, arg_pos);
buffer<expr> rs;
for (unsigned i = 0; i < prgs.size(); i++) {
if (i > 0)
rs.push_back(mlocal_type(prgs[i].m_fn));
// Remark: this loop is very hackish.
// We are "compiling" the code prgs.size() times!
// This is wasteful. We should rewrite this.

View file

@ -42,7 +42,7 @@ expr compile_equations(type_checker & tc, io_state const & ios, expr const & eqn
/** \brief Return true if \c e is an auxiliary macro used to store the result of mutually recursive declarations.
For example, if a set of recursive equations is defining \c n mutually recursive functions, we wrap
the \c n resulting functions with an \c equations_result macro.
the \c n resulting functions (and their types) with an \c equations_result macro.
*/
bool is_equations_result(expr const & e);
unsigned get_equations_result_size(expr const & e);

68
tests/lean/run/eq24.lean Normal file
View file

@ -0,0 +1,68 @@
open nat
inductive tree (A : Type) :=
leaf : A → tree A,
node : tree_list A → tree A
with tree_list :=
nil : tree_list A,
cons : tree A → tree_list A → tree_list A
namespace tree
open tree_list
definition size {A : Type} : tree A → nat
with size_l : tree_list A → nat,
size (leaf a) := 1,
size (node l) := size_l l,
size_l !nil := 0,
size_l (cons t l) := size t + size_l l
variables {A : Type}
theorem size_leaf (a : A) : size (leaf a) = 1 :=
rfl
theorem size_node (l : tree_list A) : size (node l) = size_l l :=
rfl
theorem size_l_nil : size_l (nil A) = 0 :=
rfl
theorem size_l_cons (t : tree A) (l : tree_list A) : size_l (cons t l) = size t + size_l l :=
rfl
definition eq_tree {A : Type} : tree A → tree A → Prop
with eq_tree_list : tree_list A → tree_list A → Prop,
eq_tree (leaf a₁) (leaf a₂) := a₁ = a₂,
eq_tree (node l₁) (node l₂) := eq_tree_list l₁ l₂,
eq_tree _ _ := false,
eq_tree_list !nil !nil := true,
eq_tree_list (cons t₁ l₁) (cons t₂ l₂) := eq_tree t₁ t₂ ∧ eq_tree_list l₁ l₂,
eq_tree_list _ _ := false
theorem eq_tree_leaf (a₁ a₂ : A) : eq_tree (leaf a₁) (leaf a₂) = (a₁ = a₂) :=
rfl
theorem eq_tree_node (l₁ l₂ : tree_list A) : eq_tree (node l₁) (node l₂) = eq_tree_list l₁ l₂ :=
rfl
theorem eq_tree_leaf_node (a₁ : A) (l₂ : tree_list A) : eq_tree (leaf a₁) (node l₂) = false :=
rfl
theorem eq_tree_node_leaf (l₁ : tree_list A) (a₂ : A) : eq_tree (node l₁) (leaf a₂) = false :=
rfl
theorem eq_tree_list_nil : eq_tree_list (nil A) (nil A) = true :=
rfl
theorem eq_tree_list_cons (t₁ t₂ : tree A) (l₁ l₂ : tree_list A) :
eq_tree_list (cons t₁ l₁) (cons t₂ l₂) = (eq_tree t₁ t₂ ∧ eq_tree_list l₁ l₂) :=
rfl
theorem eq_tree_list_cons_nil (t : tree A) (l : tree_list A) : eq_tree_list (cons t l) (nil A) = false :=
rfl
theorem eq_tree_list_nil_cons (t : tree A) (l : tree_list A) : eq_tree_list (nil A) (cons t l) = false :=
rfl
end tree