fix(library/tactic/unfold_rec): add annother brec pattern that should be checked in the unfold recursive definition tactic

This commit is contained in:
Leonardo de Moura 2015-07-10 22:16:23 -04:00
parent d939509135
commit 554a42b407
2 changed files with 62 additions and 6 deletions

View file

@ -172,12 +172,14 @@ class unfold_rec_fn : public replace_visitor_aux {
unsigned m_major_idx; // position of the major premise in the recursor
unsigned m_main_pos; // position of the (recursive) argument in the function being unfolded
buffer<unsigned> const & m_rec_arg_pos; // position of the other arguments that are not fixed in the recursion
name m_prod_rec_name;
fold_rec_fn(type_checker_ptr & tc, expr const & fn, buffer<expr> const & args, rec_kind k, name const & rec_name,
unsigned main_pos, buffer<unsigned> const & rec_arg_pos):
m_tc(tc), m_fn(fn), m_args(args), m_kind(k), m_rec_name(rec_name),
m_major_idx(*inductive::get_elim_major_idx(m_tc->env(), rec_name)),
m_main_pos(main_pos), m_rec_arg_pos(rec_arg_pos) {
m_prod_rec_name = inductive::get_elim_name(get_prod_name());
lean_assert(m_main_pos < args.size());
lean_assert(std::all_of(rec_arg_pos.begin(), rec_arg_pos.end(), [&](unsigned pos) { return pos < args.size(); }));
}
@ -199,18 +201,21 @@ class unfold_rec_fn : public replace_visitor_aux {
return folded_app;
}
expr fold_brec(expr const & e, buffer<expr> const & args) {
if (args.size() != 3 + m_rec_arg_pos.size())
expr fold_brec_core(expr const & e, buffer<expr> const & args, unsigned prefix_size, unsigned major_pos) {
if (args.size() != prefix_size + m_rec_arg_pos.size()) {
throw fold_failed();
}
buffer<expr> nested_args;
get_app_args(args[1], nested_args);
if (nested_args.size() != m_major_idx+1)
get_app_args(args[major_pos], nested_args);
if (nested_args.size() != m_major_idx+1) {
throw fold_failed();
}
buffer<expr> new_args;
new_args.append(m_args);
new_args[m_main_pos] = nested_args[m_major_idx];
for (unsigned i = 0; i < m_rec_arg_pos.size(); i++) {
new_args[m_rec_arg_pos[i]] = args[3 + i];
new_args[m_rec_arg_pos[i]] = args[prefix_size + i];
}
expr folded_app = mk_app(m_fn, new_args);
if (!m_tc->is_def_eq(folded_app, e).first)
@ -218,6 +223,14 @@ class unfold_rec_fn : public replace_visitor_aux {
return folded_app;
}
expr fold_brec_pr1(expr const & e, buffer<expr> const & args) {
return fold_brec_core(e, args, 3, 1);
}
expr fold_brec_prod_rec(expr const & e, buffer<expr> const & args) {
return fold_brec_core(e, args, 5, 4);
}
virtual expr visit_app(expr const & e) {
buffer<expr> args;
expr fn = get_app_args(e, args);
@ -226,7 +239,12 @@ class unfold_rec_fn : public replace_visitor_aux {
if (m_kind == BREC && is_constant(fn) && const_name(fn) == get_prod_pr1_name() && args.size() >= 3) {
expr rec_fn = get_app_fn(args[1]);
if (is_constant(rec_fn) && const_name(rec_fn) == m_rec_name)
return fold_brec(e, args);
return fold_brec_pr1(e, args);
}
if (m_kind == BREC && is_constant(fn) && const_name(fn) == m_prod_rec_name && args.size() >= 5) {
expr rec_fn = get_app_fn(args[4]);
if (is_constant(rec_fn) && const_name(rec_fn) == m_rec_name)
return fold_brec_prod_rec(e, args);
}
return visit_app_default(e, fn, args);
}

View file

@ -0,0 +1,38 @@
/-
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
Show that tail recursive fib is equal to standard one.
-/
import data.nat
open nat
definition fib : nat → nat
| 0 := 1
| 1 := 1
| (n+2) := fib (n+1) + fib n
private definition fib_fast_aux : nat → nat → nat → nat
| 0 i j := j
| (succ n) i j := fib_fast_aux n j (j+i)
lemma fib_fast_aux_lemma : ∀ n m, fib_fast_aux n (fib m) (fib (succ m)) = fib (succ (n + m))
| 0 m := by rewrite zero_add
| (succ n) m :=
begin
have ih : fib_fast_aux n (fib (succ m)) (fib (succ (succ m))) = fib (succ (n + succ m)), from fib_fast_aux_lemma n (succ m),
have h₁ : fib (succ m) + fib m = fib (succ (succ m)), from rfl,
unfold fib_fast_aux, rewrite [h₁, ih, succ_add, add_succ]
end
definition fib_fast (n: nat) :=
fib_fast_aux n 0 1
lemma fib_fast_eq_fib : ∀ n, fib_fast n = fib n
| 0 := rfl
| (succ n) :=
begin
have h₁ : fib_fast_aux n (fib 0) (fib 1) = fib (succ n), from !fib_fast_aux_lemma,
unfold [fib_fast, fib_fast_aux], krewrite h₁
end