From de369a0a0a0181a884ee6485a8ae1f487613816d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 1 May 2015 15:47:15 -0700 Subject: [PATCH] feat(library/tactic/injection_tactic): improve 'injection' tactic see issue #500 --- library/data/list/perm.lean | 6 +- src/library/tactic/injection_tactic.cpp | 294 ++++++++++++++---------- tests/lean/hott/inj_tac.hlean | 10 +- tests/lean/run/inj_tac.lean | 16 +- 4 files changed, 194 insertions(+), 132 deletions(-) diff --git a/library/data/list/perm.lean b/library/data/list/perm.lean index 0cbfb095b..6fbce67e1 100644 --- a/library/data/list/perm.lean +++ b/library/data/list/perm.lean @@ -297,15 +297,13 @@ match l₂ with | [h₁] := λ e H₁ H₂ H₃, begin rewrite [append_cons at e, append_nil_left at e], - injection e with a_eq_h₁ rest, - injection rest with b_eq_c l₁_eq_l₃, + injection e with a_eq_h₁ b_eq_c l₁_eq_l₃, rewrite [a_eq_h₁ at H₂, b_eq_c at H₂, l₁_eq_l₃ at H₂], exact H₂ rfl rfl rfl end | h₁::h₂::t₂ := λ e H₁ H₂ H₃, begin - injection e with a_eq_h₁ rest, - injection rest with b_eq_h₂ l₁_eq, + injection e with a_eq_h₁ b_eq_h₂ l₁_eq, rewrite [a_eq_h₁ at H₃, b_eq_h₂ at H₃], exact H₃ t₂ rfl l₁_eq end diff --git a/src/library/tactic/injection_tactic.cpp b/src/library/tactic/injection_tactic.cpp index 468f1d0cc..73919c54e 100644 --- a/src/library/tactic/injection_tactic.cpp +++ b/src/library/tactic/injection_tactic.cpp @@ -13,37 +13,50 @@ Author: Leonardo de Moura #include "library/tactic/elaborate.h" #include "library/tactic/expr_to_tactic.h" #include "library/tactic/apply_tactic.h" +#include "library/tactic/clear_tactic.h" namespace lean { +tactic injection_tactic_core(expr const & e, unsigned num, list const & ids, bool report_errors); + +// Return true iff lhs and rhs are of the form (f ...) f is a constructor +bool is_injection_target(type_checker & tc, expr lhs, expr rhs) { + environment const & env = tc.env(); + lhs = tc.whnf(lhs).first; + rhs = tc.whnf(rhs).first; + expr A = tc.whnf(tc.infer(lhs).first).first; + expr const & I = get_app_fn(A); + if (!is_constant(I) || !inductive::is_inductive_decl(env, const_name(I))) + return false; + expr lhs_fn = get_app_fn(lhs); + expr rhs_fn = get_app_fn(rhs); + return + is_constant(lhs_fn) && is_constant(rhs_fn) && + const_name(lhs_fn) == const_name(rhs_fn) && + inductive::is_intro_rule(env, const_name(lhs_fn)); +} + /** \brief Introduce num hypotheses, if _ns is not nil use it to name the hypothesis, New hypothesis of the form (a = a) and (a == a) are discarded. New hypothesis of the form (a == b) where (a b : A), are converted into (a = b). */ -tactic intros_num_tactic(list _ns, unsigned num) { - auto fn = [=](environment const & env, io_state const &, proof_state const & s) { +tactic intros_num_tactic(unsigned num, list _ns) { + auto fn = [=](environment const & env, io_state const & ios, proof_state const & s) { if (num == 0) - return some_proof_state(s); + return proof_state_seq(s); list ns = _ns; goals const & gs = s.get_goals(); - if (empty(gs)) { - throw_no_goal_if_enabled(s); - return optional(); - } + if (empty(gs)) + return proof_state_seq(); goal const & g = head(gs); name_generator ngen = s.get_ngen(); auto tc = mk_type_checker(env, ngen.mk_child(), s.relax_main_opaque()); expr t = g.get_type(); expr m = g.get_meta(); - buffer hyps; - g.get_hyps(hyps); - buffer new_hyps; // extra hypotheses for the new goal - buffer args; // arguments to be provided to new goal - buffer intros; // locals being introduced auto mk_name = [&](name const & n) { if (is_nil(ns)) { - return get_unused_name(n, new_hyps); + return g.get_unused_name(n); } else { name r = head(ns); ns = tail(ns); @@ -51,75 +64,151 @@ tactic intros_num_tactic(list _ns, unsigned num) { } }; - // introduce a value of type t - auto add_intro = [&](expr const & t) { - expr i = mk_local(ngen.next(), t); - intros.push_back(i); - return i; + auto keep_hyp = [&]() { + expr H = mk_local(mk_name(binding_name(t)), binding_domain(t)); + t = instantiate(binding_body(t), H); + m = mk_app(m, H); + proof_state new_s(s, cons(goal(m, t), tail(gs)), ngen); + return intros_num_tactic(num-1, ns)(env, ios, new_s); }; - auto add_hyp = [&](name const & n, expr const & t) { - expr l = mk_local(mk_name(n), t); - new_hyps.push_back(l); - intros.push_back(l); - args.push_back(l); - return l; - }; - - try { - for (unsigned i = 0; i < num; i++) { - t = tc->ensure_pi(t).first; - name const & Hname = binding_name(t); - constraint_seq Hcs; - expr Htype = tc->whnf(binding_domain(t), Hcs); - optional new_Htype; - expr A, B, lhs, rhs; - if (!closed(binding_body(t))) { - // rest depends on Hname : Htype - expr H = add_hyp(Hname, Htype); - t = instantiate(binding_body(t), H); - } else { - if (is_eq(Htype, lhs, rhs)) { - if (!tc->is_def_eq(lhs, rhs, justification(), Hcs) || Hcs) - add_hyp(Hname, Htype); - else - add_intro(Htype); // discard - } else if (is_standard(env) && is_heq(Htype, A, lhs, B, rhs)) { - if (tc->is_def_eq(A, B, justification(), Hcs) && !Hcs) { - if (!tc->is_def_eq(lhs, rhs, justification(), Hcs) || Hcs) { - // convert to homogenous equality - expr H = mk_local(ngen.next(), Htype); - expr newHtype = mk_eq(*tc, lhs, rhs); - expr newH = mk_local(mk_name(Hname), newHtype); - new_hyps.push_back(newH); - intros.push_back(H); - levels heq_lvl = const_levels(get_app_fn(Htype)); - args.push_back(mk_app(mk_constant(get_heq_to_eq_name(), heq_lvl), A, lhs, rhs, H)); - } else { - add_intro(Htype); // discard - } - } else { - add_hyp(Hname, Htype); - } - } else { - add_hyp(Hname, Htype); - } - t = binding_body(t); - } - } + auto discard_hyp = [&]() { + expr new_meta = g.mk_meta(ngen.next(), binding_body(t)); + goal new_goal(new_meta, binding_body(t)); substitution new_subst = s.get_subst(); - expr new_mvar = mk_metavar(ngen.next(), Pi(hyps, Pi(new_hyps, t))); - expr new_aux = mk_app(new_mvar, hyps); - expr new_meta = mk_app(new_aux, new_hyps); - goal new_goal(new_meta, t); - assign(new_subst, g, Fun(intros, mk_app(new_aux, args))); - return some_proof_state(proof_state(s, cons(new_goal, tail(gs)), new_subst, ngen)); - } catch (exception &) { - return none_proof_state(); + assign(new_subst, g, mk_lambda(binding_name(t), binding_domain(t), new_meta)); + proof_state new_s(s, cons(new_goal, tail(gs)), new_subst, ngen); + return intros_num_tactic(num-1, ns)(env, ios, new_s); + }; + + t = tc->ensure_pi(t).first; + + // if goal depends on hypothesis, we keep it + if (!closed(binding_body(t))) + return keep_hyp(); + + constraint_seq cs; + expr Htype = tc->whnf(binding_domain(t), cs); + + // new unification constraints were generated, so we keep hypothesis + if (cs) + return keep_hyp(); + + expr lhs, rhs; + if (is_eq(Htype, lhs, rhs)) { + // equalities of the form (a = a) are discarded + if (tc->is_def_eq(lhs, rhs, justification(), cs) && !cs) { + return discard_hyp(); + } else if (is_injection_target(*tc, lhs, rhs)) { + // apply injection recursively + name Hname = ngen.next(); + expr H = mk_local(Hname, binding_domain(t)); + t = binding_body(t); + m = mk_app(m, H); + proof_state new_s(s, cons(goal(m, t), tail(gs)), ngen); + return then(injection_tactic_core(H, num-1, ns, false), + clear_tactic(Hname))(env, ios, new_s); + } else { + return keep_hyp(); + } + } + + expr A, B; + if (is_standard(env) && is_heq(Htype, A, lhs, B, rhs)) { + if (tc->is_def_eq(A, B, justification(), cs) && !cs) { + // since types A and B are definitionally equal, we convert to homogeneous + expr new_eq = mk_eq(*tc, lhs, rhs); + expr new_type = mk_pi(binding_name(t), new_eq, binding_body(t)); + expr new_meta = g.mk_meta(ngen.next(), new_type); + goal new_goal(new_meta, new_type); + expr H = mk_local(ngen.next(), binding_domain(t)); + levels heq_lvl = const_levels(get_app_fn(Htype)); + expr arg = mk_app(mk_constant(get_heq_to_eq_name(), heq_lvl), A, lhs, rhs, H); + expr V = Fun(H, mk_app(new_meta, arg)); + substitution new_subst = s.get_subst(); + assign(new_subst, g, V); + proof_state new_s(s, cons(new_goal, tail(gs)), new_subst, ngen); + return intros_num_tactic(num, ns)(env, ios, new_s); + } else { + return keep_hyp(); + } + } + + // hypothesis is not an equality + return keep_hyp(); + }; + return tactic(fn); +} + +tactic injection_tactic_core(expr const & e, unsigned num, list const & ids, bool report_errors) { + auto fn = [=](environment const & env, io_state const & ios, proof_state const & s) { + goals const & gs = s.get_goals(); + if (!gs) { + throw_no_goal_if_enabled(s); + return proof_state_seq(); + } + expr t = head(gs).get_type(); + constraint_seq cs; + name_generator ngen = s.get_ngen(); + auto tc = mk_type_checker(env, ngen.mk_child(), s.relax_main_opaque()); + expr e_type = tc->whnf(tc->infer(e, cs), cs); + expr lhs, rhs; + if (!is_eq(e_type, lhs, rhs)) { + if (report_errors) { + throw_tactic_exception_if_enabled(s, "invalid 'injection' tactic, " + "given argument is not an equality proof"); + return proof_state_seq(); + } + return intros_num_tactic(num, ids)(env, ios, s); + } + lhs = tc->whnf(lhs, cs); + rhs = tc->whnf(rhs, cs); + expr A = tc->whnf(tc->infer(lhs, cs), cs); + buffer I_args; + expr I = get_app_args(A, I_args); + if (!is_constant(I) || !inductive::is_inductive_decl(env, const_name(I))) { + if (report_errors) { + throw_tactic_exception_if_enabled(s, "invalid 'injection' tactic, " + "it is not an equality between inductive values"); + return proof_state_seq(); + } + return intros_num_tactic(num, ids)(env, ios, s); + } + expr lhs_fn = get_app_fn(lhs); + expr rhs_fn = get_app_fn(rhs); + if (!is_constant(lhs_fn) || !is_constant(rhs_fn) || const_name(lhs_fn) != const_name(rhs_fn) || + !inductive::is_intro_rule(env, const_name(lhs_fn))) { + if (report_errors) { + throw_tactic_exception_if_enabled(s, "invalid 'injection' tactic, " + "the given equality is not of the form (f ...) = (f ...) " + "where f is a constructor"); + return proof_state_seq(); + } + return intros_num_tactic(num, ids)(env, ios, s); + } + unsigned num_params = *inductive::get_num_params(env, const_name(I)); + unsigned cnstr_arity = get_arity(env.get(const_name(lhs_fn)).get_type()); + lean_assert(cnstr_arity >= num_params); + unsigned num_new_eqs = cnstr_arity - num_params; + level t_lvl = sort_level(tc->ensure_type(t, cs)); + expr N = mk_constant(name(const_name(I), "no_confusion"), cons(t_lvl, const_levels(I))); + N = mk_app(mk_app(N, I_args), t, lhs, rhs, e); + proof_state new_s(s, ngen); + if (is_standard(env)) { + tactic tac = then(take(apply_tactic_core(N, cs), 1), + intros_num_tactic(num + num_new_eqs, ids)); + return tac(env, ios, new_s); + } else { + level n_lvl = mk_meta_univ(tc->mk_fresh_name()); + expr lift_down = mk_app(mk_constant(get_lift_down_name(), {t_lvl, n_lvl}), t); + tactic tac = then(take(apply_tactic_core(lift_down), 1), + then(take(apply_tactic_core(N, cs), 1), + intros_num_tactic(num + num_new_eqs, ids))); + return tac(env, ios, new_s); } }; - return tactic01(fn); -} + return tactic(fn); +}; tactic injection_tactic(elaborate_fn const & elab, expr const & e, list const & ids) { auto fn = [=](environment const & env, io_state const & ios, proof_state const & s) { @@ -134,54 +223,7 @@ tactic injection_tactic(elaborate_fn const & elab, expr const & e, list co bool enforce_type = false; if (optional new_e = elaborate_with_respect_to(env, ios, elab, new_s, e, none_expr(), report_unassigned, enforce_type)) { - constraint_seq cs; - name_generator ngen = new_s.get_ngen(); - auto tc = mk_type_checker(env, ngen.mk_child(), new_s.relax_main_opaque()); - expr new_e_type = tc->whnf(tc->infer(*new_e, cs), cs); - expr lhs, rhs; - if (!is_eq(new_e_type, lhs, rhs)) { - throw_tactic_exception_if_enabled(new_s, "invalid 'injection' tactic, " - "given argument is not an equality proof"); - return proof_state_seq(); - } - lhs = tc->whnf(lhs, cs); - rhs = tc->whnf(rhs, cs); - expr A = tc->whnf(tc->infer(lhs, cs), cs); - buffer I_args; - expr I = get_app_args(A, I_args); - if (!is_constant(I) || !inductive::is_inductive_decl(env, const_name(I))) { - throw_tactic_exception_if_enabled(new_s, "invalid 'injection' tactic, " - "it is not an equality between inductive values"); - return proof_state_seq(); - } - expr lhs_fn = get_app_fn(lhs); - expr rhs_fn = get_app_fn(rhs); - if (!is_constant(lhs_fn) || !is_constant(rhs_fn) || const_name(lhs_fn) != const_name(rhs_fn) || - !inductive::is_intro_rule(env, const_name(lhs_fn))) { - throw_tactic_exception_if_enabled(new_s, "invalid 'injection' tactic, " - "the given equality is not of the form (f ...) = (f ...) " - "where f is a constructor"); - return proof_state_seq(); - } - unsigned num_params = *inductive::get_num_params(env, const_name(I)); - unsigned cnstr_arity = get_arity(env.get(const_name(lhs_fn)).get_type()); - lean_assert(cnstr_arity >= num_params); - unsigned num_new_eqs = cnstr_arity - num_params; - level t_lvl = sort_level(tc->ensure_type(t, cs)); - expr N = mk_constant(name(const_name(I), "no_confusion"), cons(t_lvl, const_levels(I))); - N = mk_app(mk_app(N, I_args), t, lhs, rhs, *new_e); - if (is_standard(env)) { - tactic tac = then(apply_tactic_core(N, cs), - intros_num_tactic(ids, num_new_eqs)); - return tac(env, ios, new_s); - } else { - level n_lvl = mk_meta_univ(tc->mk_fresh_name()); - expr lift_down = mk_app(mk_constant(get_lift_down_name(), {t_lvl, n_lvl}), t); - tactic tac = then(apply_tactic_core(lift_down), - then(apply_tactic_core(N, cs), - intros_num_tactic(ids, num_new_eqs))); - return tac(env, ios, new_s); - } + return injection_tactic_core(*new_e, 0, ids, true)(env, ios, new_s); } else { return proof_state_seq(); } @@ -195,7 +237,7 @@ void initialize_injection_tactic() { check_tactic_expr(app_arg(app_fn(e)), "invalid 'injection' tactic, invalid argument"); buffer ids; get_tactic_id_list_elements(app_arg(e), ids, "invalid 'injection' tactic, list of identifiers expected"); - return injection_tactic(fn, get_tactic_expr_expr(app_arg(app_fn(e))), to_list(ids)); + return take(injection_tactic(fn, get_tactic_expr_expr(app_arg(app_fn(e))), to_list(ids)), 1); }); } diff --git a/tests/lean/hott/inj_tac.hlean b/tests/lean/hott/inj_tac.hlean index a556713c0..622bb1123 100644 --- a/tests/lean/hott/inj_tac.hlean +++ b/tests/lean/hott/inj_tac.hlean @@ -19,6 +19,14 @@ open prod example (A : Type) (a₁ a₂ a₃ b₁ b₂ b₃ : A) : (a₁, a₂, a₃) = (b₁, b₂, b₃) → b₁ = a₁ := begin - intro H, injection H with H₁, injection H₁ with a₁b₁, + intro H, injection H with a₁b₁ a₂b₂ a₃b₃, rewrite a₁b₁ end + +example (a₁ a₂ a₃ b₁ b₂ b₃ : nat) : (a₁+2, a₂+3, a₃+1) = (b₁+2, b₂+2, b₃+2) → b₁ = a₁ × a₃ = b₃+1 := +begin + intro H, injection H with a₁b₁ sa₂b₂ a₃sb₃, + esimp at *, + rewrite [a₁b₁, a₃sb₃], split, + repeat trivial +end diff --git a/tests/lean/run/inj_tac.lean b/tests/lean/run/inj_tac.lean index 40f7790c9..d85ed066d 100644 --- a/tests/lean/run/inj_tac.lean +++ b/tests/lean/run/inj_tac.lean @@ -24,6 +24,20 @@ end example (A : Type) (a₁ a₂ a₃ b₁ b₂ b₃ : A) : (a₁, a₂, a₃) = (b₁, b₂, b₃) → b₁ = a₁ := begin - intro H, injection H with H₁, injection H₁ with a₁b₁, + intro H, injection H with a₁b₁ a₂b₂ a₃b₃, rewrite a₁b₁ end + +example (A : Type) (a₁ a₂ a₃ b₁ b₂ b₃ : A) : (a₁ :: [], a₂, a₃) = (b₁ :: [], b₂, b₃) → b₁ = a₁ := +begin + intro H, injection H with a₁b₁ a₂b₂ a₃b₃, + rewrite a₁b₁ +end + +example (a₁ a₂ a₃ b₁ b₂ b₃ : nat) : (a₁+2, a₂+3, a₃+1) = (b₁+2, b₂+2, b₃+2) → b₁ = a₁ ∧ a₃ = b₃+1 ∧ b₂ = succ a₂ := +begin + intro H, injection H with a₁b₁ sa₂b₂ a₃sb₃, + esimp at *, + rewrite [a₁b₁, a₃sb₃, sa₂b₂], + repeat (split | esimp) +end