From 0adacb51910c2597db49fc88eca6be45ee4b875e Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 28 Jun 2014 13:57:36 -0700 Subject: [PATCH] feat(kernel): add infer implicit, and use it to infer implicit arguments of inductive datatype eliminators, and tag whether parameters should be implicit or not in introduction rules in the module inductive_cmd Signed-off-by: Leonardo de Moura --- library/standard/logic.lean | 7 ++--- src/frontends/lean/inductive_cmd.cpp | 23 +++++++++++------ src/kernel/expr.cpp | 38 ++++++++++++++++++++++++++++ src/kernel/expr.h | 13 ++++++++++ src/kernel/inductive/inductive.cpp | 18 ++++--------- tests/lean/run/e14.lean | 8 +++--- tests/lean/run/e15.lean | 9 +++---- tests/lean/run/e16.lean | 9 +++---- tests/lean/run/e17.lean | 4 +-- 9 files changed, 87 insertions(+), 42 deletions(-) diff --git a/library/standard/logic.lean b/library/standard/logic.lean index 020e5ee7c..977e8c126 100644 --- a/library/standard/logic.lean +++ b/library/standard/logic.lean @@ -4,7 +4,7 @@ inductive false : Bool := -- No constructors theorem false_elim (c : Bool) (H : false) -:= @false_rec c H +:= false_rec c H inductive true : Bool := | trivial : true @@ -54,13 +54,10 @@ theorem or_elim (a b c : Bool) (H1 : a ∨ b) (H2 : a → c) (H3 : b → c) : c := or_rec H2 H3 H1 inductive eq {A : Type} (a : A) : A → Bool := -| eq_intro : eq A a a -- TODO: use elaborator in inductive_cmd module, we should not need to type A here +| refl : eq A a a -- TODO: use elaborator in inductive_cmd module, we should not need to type A here infix `=` 50 := eq -theorem refl {A : Type} (a : A) : a = a -:= @eq_intro A a - theorem subst {A : Type} {a b : A} {P : A → Bool} (H1 : a = b) (H2 : P a) : P b := eq_rec H2 H1 diff --git a/src/frontends/lean/inductive_cmd.cpp b/src/frontends/lean/inductive_cmd.cpp index a1a691239..dd9491ee0 100644 --- a/src/frontends/lean/inductive_cmd.cpp +++ b/src/frontends/lean/inductive_cmd.cpp @@ -10,6 +10,7 @@ Author: Leonardo de Moura #include "kernel/type_checker.h" #include "kernel/instantiate.h" #include "kernel/inductive/inductive.h" +#include "kernel/free_vars.h" #include "library/scoped_ext.h" #include "library/locals.h" #include "library/placeholder.h" @@ -24,6 +25,8 @@ static name g_assign(":="); static name g_with("with"); static name g_colon(":"); static name g_bar("|"); +static name g_lcurly("{"); +static name g_rcurly("}"); using inductive::intro_rule; using inductive::inductive_decl; @@ -33,12 +36,6 @@ using inductive::inductive_decl_intros; using inductive::intro_rule_name; using inductive::intro_rule_type; -// Mark all parameters as implicit -static void make_implicit(buffer & ps) { - for (parameter & p : ps) - p.m_bi = mk_implicit_binder_info(); -} - // Make sure that every inductive datatype (in decls) occurring in \c type has // the universe levels \c lvl_params and section parameters \c section_params static expr fix_inductive_occs(expr const & type, buffer const & decls, @@ -181,6 +178,8 @@ environment inductive_cmd(parser & p) { bool first = true; buffer ls_buffer; name_map id_to_short_id; + // store intro rule name that are markes for relaxed implicit argument inference. + name_set relaxed_implicit_inference; unsigned num_params = 0; bool explicit_levels = false; buffer decls; @@ -250,7 +249,6 @@ environment inductive_cmd(parser & p) { params_pos); } } - make_implicit(ps); // parameters are implicit for introduction rules // parse introduction rules p.check_token_next(g_assign, "invalid inductive declaration, ':=' expected"); buffer intros; @@ -260,9 +258,17 @@ environment inductive_cmd(parser & p) { check_atomic(intro_id); name full_intro_id = ns + intro_id; id_to_short_id.insert(full_intro_id, intro_id); + bool strict = true; + if (p.curr_is_token(g_lcurly)) { + p.next(); + p.check_token_next(g_rcurly, "invalid introduction rule, '}' expected"); + strict = false; + relaxed_implicit_inference.insert(full_intro_id); + } p.check_token_next(g_colon, "invalid introduction rule, ':' expected"); expr intro_type = p.parse_scoped_expr(ps, lenv); intro_type = p.pi_abstract(ps, intro_type); + intro_type = infer_implicit(intro_type, ps.size(), strict); intros.push_back(intro_rule(full_intro_id, intro_type)); } decls.push_back(inductive_decl(full_id, type, to_list(intros.begin(), intros.end()))); @@ -294,7 +300,6 @@ environment inductive_cmd(parser & p) { p.pi_abstract(section_params, inductive_decl_type(d)), inductive_decl_intros(d)); } - make_implicit(section_params); // Add section_params to introduction rules type, and also "fix" // occurrences of inductive types. for (inductive_decl & d : decls) { @@ -303,6 +308,8 @@ environment inductive_cmd(parser & p) { expr type = intro_rule_type(ir); type = fix_inductive_occs(type, decls, ls_buffer, section_params); type = p.pi_abstract(section_params, type); + bool strict = relaxed_implicit_inference.contains(intro_rule_name(ir)); + type = infer_implicit(type, section_params.size(), strict); new_irs.push_back(intro_rule(intro_rule_name(ir), type)); } d = inductive_decl(inductive_decl_name(d), diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index ef2751731..b48b72750 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -489,6 +489,13 @@ expr update_binding(expr const & e, expr const & new_domain, expr const & new_bo return e; } +expr update_binding(expr const & e, expr const & new_domain, expr const & new_body, binder_info const & bi) { + if (!is_eqp(binding_domain(e), new_domain) || !is_eqp(binding_body(e), new_body) || bi != binding_info(e)) + return copy_tag(e, mk_binding(e.kind(), binding_name(e), new_domain, new_body, bi)); + else + return e; +} + expr update_mlocal(expr const & e, expr const & new_type) { if (is_eqp(mlocal_type(e), new_type)) return e; @@ -583,4 +590,35 @@ static macro_definition g_let_macro_definition(new let_macro_definition_cell()); expr mk_let_macro(expr const & e) { return mk_macro(g_let_macro_definition, 1, &e); } bool is_let_macro(expr const & e) { return is_macro(e) && macro_def(e) == g_let_macro_definition; } expr let_macro_arg(expr const & e) { lean_assert(is_let_macro(e)); return macro_arg(e, 0); } + +static bool has_free_var_in_domain(expr const & b, unsigned vidx) { + if (is_pi(b)) { + return has_free_var(binding_domain(b), vidx) || has_free_var_in_domain(binding_body(b), vidx+1); + } else { + return false; + } +} + +expr infer_implicit(expr const & t, unsigned num_params, bool strict) { + if (num_params == 0) { + return t; + } else if (is_pi(t)) { + expr new_body = infer_implicit(binding_body(t), num_params-1, strict); + if (binding_info(t).is_implicit() || binding_info(t).is_strict_implicit()) { + // argument is already marked as implicit + return update_binding(t, binding_domain(t), new_body); + } else if ((strict && has_free_var_in_domain(new_body, 0)) || + (!strict && has_free_var(new_body, 0))) { + return update_binding(t, binding_domain(t), new_body, mk_implicit_binder_info()); + } else { + return update_binding(t, binding_domain(t), new_body); + } + } else { + return t; + } +} + +expr infer_implicit(expr const & t, bool strict) { + return infer_implicit(t, std::numeric_limits::max(), strict); +} } diff --git a/src/kernel/expr.h b/src/kernel/expr.h index d9b5c49cc..23cfc90ce 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -641,6 +641,7 @@ expr update_app(expr const & e, expr const & new_fn, expr const & new_arg); expr update_rev_app(expr const & e, unsigned num, expr const * new_args); template expr update_rev_app(expr const & e, C const & c) { return update_rev_app(e, c.size(), c.data()); } expr update_binding(expr const & e, expr const & new_domain, expr const & new_body); +expr update_binding(expr const & e, expr const & new_domain, expr const & new_body, binder_info const & bi); expr update_mlocal(expr const & e, expr const & new_type); expr update_sort(expr const & e, level const & new_level); expr update_constant(expr const & e, levels const & new_levels); @@ -655,5 +656,17 @@ expr let_macro_arg(expr const & e); std::string const & get_let_macro_opcode(); // ======================================= +// ======================================= +// Implicit argument inference +/** + \brief Given \c t of the form Pi (x_1 : A_1) ... (x_k : A_k), B, + mark the first \c num_params as implicit if they are not already marked, and + they occur in the remaining arguments. If \c strict is false, then we + also mark it implicit if it occurs in \c B. +*/ +expr infer_implicit(expr const & t, unsigned num_params, bool strict); +expr infer_implicit(expr const & t, bool strict); +// ======================================= + std::ostream & operator<<(std::ostream & out, expr const & e); } diff --git a/src/kernel/inductive/inductive.cpp b/src/kernel/inductive/inductive.cpp index 827cd4c1a..c39e04b48 100644 --- a/src/kernel/inductive/inductive.cpp +++ b/src/kernel/inductive/inductive.cpp @@ -622,15 +622,10 @@ struct add_inductive_fn { expr C_app = mk_app(info.m_C, info.m_indices); if (m_dep_elim) C_app = mk_app(C_app, info.m_major_premise); - expr elim_ty = Pi(info.m_major_premise, C_app); - unsigned i = info.m_indices.size(); - while (i > 0) { - --i; - elim_ty = Pi(info.m_indices[i], elim_ty, mk_implicit_binder_info()); - } + elim_ty = Pi(info.m_indices, elim_ty); // abstract all introduction rules - i = get_num_its(); + unsigned i = get_num_its(); while (i > 0) { --i; unsigned j = m_elim_info[i].m_minor_premises.size(); @@ -643,13 +638,10 @@ struct add_inductive_fn { i = get_num_its(); while (i > 0) { --i; - elim_ty = Pi(m_elim_info[i].m_C, elim_ty, mk_implicit_binder_info()); - } - i = m_param_consts.size(); - while (i > 0) { - --i; - elim_ty = Pi(m_param_consts[i], elim_ty, mk_implicit_binder_info()); + elim_ty = Pi(m_elim_info[i].m_C, elim_ty); } + elim_ty = Pi(m_param_consts, elim_ty); + elim_ty = infer_implicit(elim_ty, true /* strict */); m_env = m_env.add(check(m_env, mk_var_decl(get_elim_name(d), get_elim_level_param_names(), elim_ty))); } diff --git a/tests/lean/run/e14.lean b/tests/lean/run/e14.lean index 5e430f615..40bc50785 100644 --- a/tests/lean/run/e14.lean +++ b/tests/lean/run/e14.lean @@ -3,8 +3,8 @@ inductive nat : Type := | succ : nat → nat inductive list (A : Type) : Type := -| nil : list A -| cons : A → list A → list A +| nil {} : list A +| cons : A → list A → list A check nil check nil.{1} @@ -14,8 +14,8 @@ check @nil nat check cons zero nil inductive vector (A : Type) : nat → Type := -| vnil : vector A zero -| vcons : forall {n : nat}, A → vector A n → vector A (succ n) +| vnil {} : vector A zero +| vcons : forall {n : nat}, A → vector A n → vector A (succ n) check vcons zero vnil variable n : nat diff --git a/tests/lean/run/e15.lean b/tests/lean/run/e15.lean index b775260d7..1bcac11f1 100644 --- a/tests/lean/run/e15.lean +++ b/tests/lean/run/e15.lean @@ -3,8 +3,8 @@ inductive nat : Type := | succ : nat → nat inductive list (A : Type) : Type := -| nil : list A -| cons : A → list A → list A +| nil {} : list A +| cons : A → list A → list A check nil check nil.{1} @@ -14,8 +14,8 @@ check @nil nat check cons zero nil inductive vector (A : Type) : nat → Type := -| vnil : vector A zero -| vcons : forall {n : nat}, A → vector A n → vector A (succ n) +| vnil {} : vector A zero +| vcons : forall {n : nat}, A → vector A n → vector A (succ n) check vcons zero vnil variable n : nat @@ -25,4 +25,3 @@ check vector_rec definition vector_to_list {A : Type} {n : nat} (v : vector A n) : list A := vector_rec nil (fun (n : nat) (a : A) (v : vector A n) (l : list A), cons a l) v - diff --git a/tests/lean/run/e16.lean b/tests/lean/run/e16.lean index 81fa79c6e..2484963ec 100644 --- a/tests/lean/run/e16.lean +++ b/tests/lean/run/e16.lean @@ -3,8 +3,8 @@ inductive nat : Type := | succ : nat → nat inductive list (A : Type) : Type := -| nil : list A -| cons : A → list A → list A +| nil {} : list A +| cons : A → list A → list A check nil check nil.{1} @@ -14,8 +14,8 @@ check @nil nat check cons zero nil inductive vector (A : Type) : nat → Type := -| vnil : vector A zero -| vcons : forall {n : nat}, A → vector A n → vector A (succ n) +| vnil {} : vector A zero +| vcons : forall {n : nat}, A → vector A n → vector A (succ n) check vcons zero vnil variable n : nat @@ -25,4 +25,3 @@ check vector_rec definition vector_to_list {A : Type} {n : nat} (v : vector A n) : list A := vector_rec nil (fun n a v l, cons a l) v - diff --git a/tests/lean/run/e17.lean b/tests/lean/run/e17.lean index 6d4e14d5b..849db414b 100644 --- a/tests/lean/run/e17.lean +++ b/tests/lean/run/e17.lean @@ -3,8 +3,8 @@ inductive nat : Type := | succ : nat → nat inductive list (A : Type) : Type := -| nil : list A -| cons : A → list A → list A +| nil {} : list A +| cons : A → list A → list A inductive int : Type := | of_nat : nat → int