diff --git a/src/library/tactic/unfold_rec.cpp b/src/library/tactic/unfold_rec.cpp index 037778eb5..0cec773b3 100644 --- a/src/library/tactic/unfold_rec.cpp +++ b/src/library/tactic/unfold_rec.cpp @@ -172,12 +172,14 @@ class unfold_rec_fn : public replace_visitor_aux { unsigned m_major_idx; // position of the major premise in the recursor unsigned m_main_pos; // position of the (recursive) argument in the function being unfolded buffer const & m_rec_arg_pos; // position of the other arguments that are not fixed in the recursion + name m_prod_rec_name; fold_rec_fn(type_checker_ptr & tc, expr const & fn, buffer const & args, rec_kind k, name const & rec_name, unsigned main_pos, buffer const & rec_arg_pos): m_tc(tc), m_fn(fn), m_args(args), m_kind(k), m_rec_name(rec_name), m_major_idx(*inductive::get_elim_major_idx(m_tc->env(), rec_name)), m_main_pos(main_pos), m_rec_arg_pos(rec_arg_pos) { + m_prod_rec_name = inductive::get_elim_name(get_prod_name()); lean_assert(m_main_pos < args.size()); lean_assert(std::all_of(rec_arg_pos.begin(), rec_arg_pos.end(), [&](unsigned pos) { return pos < args.size(); })); } @@ -199,18 +201,21 @@ class unfold_rec_fn : public replace_visitor_aux { return folded_app; } - expr fold_brec(expr const & e, buffer const & args) { - if (args.size() != 3 + m_rec_arg_pos.size()) + expr fold_brec_core(expr const & e, buffer const & args, unsigned prefix_size, unsigned major_pos) { + if (args.size() != prefix_size + m_rec_arg_pos.size()) { throw fold_failed(); + } buffer nested_args; - get_app_args(args[1], nested_args); - if (nested_args.size() != m_major_idx+1) + get_app_args(args[major_pos], nested_args); + + if (nested_args.size() != m_major_idx+1) { throw fold_failed(); + } buffer new_args; new_args.append(m_args); new_args[m_main_pos] = nested_args[m_major_idx]; for (unsigned i = 0; i < m_rec_arg_pos.size(); i++) { - new_args[m_rec_arg_pos[i]] = args[3 + i]; + new_args[m_rec_arg_pos[i]] = args[prefix_size + i]; } expr folded_app = mk_app(m_fn, new_args); if (!m_tc->is_def_eq(folded_app, e).first) @@ -218,6 +223,14 @@ class unfold_rec_fn : public replace_visitor_aux { return folded_app; } + expr fold_brec_pr1(expr const & e, buffer const & args) { + return fold_brec_core(e, args, 3, 1); + } + + expr fold_brec_prod_rec(expr const & e, buffer const & args) { + return fold_brec_core(e, args, 5, 4); + } + virtual expr visit_app(expr const & e) { buffer args; expr fn = get_app_args(e, args); @@ -226,7 +239,12 @@ class unfold_rec_fn : public replace_visitor_aux { if (m_kind == BREC && is_constant(fn) && const_name(fn) == get_prod_pr1_name() && args.size() >= 3) { expr rec_fn = get_app_fn(args[1]); if (is_constant(rec_fn) && const_name(rec_fn) == m_rec_name) - return fold_brec(e, args); + return fold_brec_pr1(e, args); + } + if (m_kind == BREC && is_constant(fn) && const_name(fn) == m_prod_rec_name && args.size() >= 5) { + expr rec_fn = get_app_fn(args[4]); + if (is_constant(rec_fn) && const_name(rec_fn) == m_rec_name) + return fold_brec_prod_rec(e, args); } return visit_app_default(e, fn, args); } diff --git a/tests/lean/run/unfold_tac_bug1.lean b/tests/lean/run/unfold_tac_bug1.lean new file mode 100644 index 000000000..86ae5173f --- /dev/null +++ b/tests/lean/run/unfold_tac_bug1.lean @@ -0,0 +1,38 @@ +/- +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author: Leonardo de Moura + +Show that tail recursive fib is equal to standard one. +-/ +import data.nat +open nat + +definition fib : nat → nat +| 0 := 1 +| 1 := 1 +| (n+2) := fib (n+1) + fib n + +private definition fib_fast_aux : nat → nat → nat → nat +| 0 i j := j +| (succ n) i j := fib_fast_aux n j (j+i) + +lemma fib_fast_aux_lemma : ∀ n m, fib_fast_aux n (fib m) (fib (succ m)) = fib (succ (n + m)) +| 0 m := by rewrite zero_add +| (succ n) m := + begin + have ih : fib_fast_aux n (fib (succ m)) (fib (succ (succ m))) = fib (succ (n + succ m)), from fib_fast_aux_lemma n (succ m), + have h₁ : fib (succ m) + fib m = fib (succ (succ m)), from rfl, + unfold fib_fast_aux, rewrite [h₁, ih, succ_add, add_succ] + end + +definition fib_fast (n: nat) := +fib_fast_aux n 0 1 + +lemma fib_fast_eq_fib : ∀ n, fib_fast n = fib n +| 0 := rfl +| (succ n) := + begin + have h₁ : fib_fast_aux n (fib 0) (fib 1) = fib (succ n), from !fib_fast_aux_lemma, + unfold [fib_fast, fib_fast_aux], krewrite h₁ + end