From a9515ac7a47891d863a2136fe9048d267441ffd5 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 8 Jul 2015 21:08:24 -0400 Subject: [PATCH] feat(library/tactic/rewrite_tactic): try to fold nested recursive applications after unfolding a recursive function See issue #692. The implementation still has some rough spots. It is not clear what the right semantic is. Moreover, the folds in e_closure could not be eliminated automatically. --- hott/types/hprop_trunc.hlean | 2 +- hott/types/pointed.hlean | 4 +- library/data/real/order.lean | 5 +- library/data/stream.lean | 6 +- library/logic/examples/cont.lean | 2 +- src/frontends/lean/parse_rewrite_tactic.cpp | 24 +- src/library/tactic/CMakeLists.txt | 2 +- src/library/tactic/rewrite_tactic.cpp | 82 +++-- src/library/tactic/rewrite_tactic.h | 2 +- src/library/tactic/unfold_rec.cpp | 338 ++++++++++++++++++++ src/library/tactic/unfold_rec.h | 11 + tests/lean/run/esimp1.lean | 2 + tests/lean/unfold_rec.lean | 50 +++ tests/lean/unfold_rec.lean.expected.out | 20 ++ 14 files changed, 498 insertions(+), 52 deletions(-) create mode 100644 src/library/tactic/unfold_rec.cpp create mode 100644 src/library/tactic/unfold_rec.h create mode 100644 tests/lean/unfold_rec.lean create mode 100644 tests/lean/unfold_rec.lean.expected.out diff --git a/hott/types/hprop_trunc.hlean b/hott/types/hprop_trunc.hlean index a1e60dd53..44060c3d3 100644 --- a/hott/types/hprop_trunc.hlean +++ b/hott/types/hprop_trunc.hlean @@ -31,7 +31,7 @@ namespace is_trunc { intro H x y, apply is_trunc_eq}, { intro H, cases H, apply idp}, { intro P, apply eq_of_homotopy, intro a, apply eq_of_homotopy, intro b, - esimp [function.id,compose,is_trunc_succ_intro,is_trunc_eq], + esimp [is_trunc_eq], esimp[compose,is_trunc_succ_intro], generalize (P a b), intro H, cases H, apply idp}, end diff --git a/hott/types/pointed.hlean b/hott/types/pointed.hlean index 1de0449ae..c3e13f645 100644 --- a/hott/types/pointed.hlean +++ b/hott/types/pointed.hlean @@ -206,9 +206,7 @@ namespace pointed begin revert A B f, induction n with n IH, { intros A B f, exact f}, - { intros A B f, rewrite [↑Iterated_loop_space,↓Iterated_loop_space n (Ω A), - ↑Iterated_loop_space, ↓Iterated_loop_space n (Ω B)], - apply IH (Ω A), + { intros A B f, esimp [Iterated_loop_space], apply IH (Ω A), { esimp, fconstructor, intro q, refine !respect_pt⁻¹ ⬝ ap f q ⬝ !respect_pt, esimp, apply con.left_inv}} diff --git a/library/data/real/order.lean b/library/data/real/order.lean index 6a2099a78..666319da2 100644 --- a/library/data/real/order.lean +++ b/library/data/real/order.lean @@ -725,11 +725,12 @@ theorem not_lt_self (s : seq) : ¬ s_lt s s := intro Hlt, rewrite [↑s_lt at Hlt, ↑pos at Hlt], apply exists.elim Hlt, - intro n Hn, - rewrite [↑sadd at Hn, ↑sneg at Hn, sub_self at Hn], + intro n Hn, esimp at Hn, + rewrite [↑sadd at Hn,↑sneg at Hn, sub_self at Hn], apply absurd Hn (rat.not_lt_of_ge (rat.le_of_lt !inv_pos)) end + theorem not_sep_self (s : seq) : ¬ s ≢ s := begin intro Hsep, diff --git a/library/data/stream.lean b/library/data/stream.lean index 1793a2216..135b01924 100644 --- a/library/data/stream.lean +++ b/library/data/stream.lean @@ -251,13 +251,15 @@ coinduction rfl (λ B fr ch, by rewrite [tail_iterate, tail_const]; exact ch) + +local attribute stream [reducible] theorem map_iterate (f : A → A) (a : A) : iterate f (f a) = map f (iterate f a) := begin apply funext, intro n, induction n with n' IH, {reflexivity}, - {esimp [map, iterate, nth] at *, - rewrite IH} + { esimp [map, iterate, nth] at *, + rewrite IH } end section corec diff --git a/library/logic/examples/cont.lean b/library/logic/examples/cont.lean index b8f8401c1..67e3322b8 100644 --- a/library/logic/examples/cont.lean +++ b/library/logic/examples/cont.lean @@ -59,7 +59,7 @@ lemma not_all_continuous : false := let β := znkω (M f + 1) 1 in let α := znkω m (M f + 1) in assert βeq₁ : zω =[M f + 1] β, from - λ (a : nat) (h : a < M f + 1), begin esimp [zω, znkω], rewrite [if_pos h] end, + λ (a : nat) (h : a < M f + 1), begin unfold zω, unfold znkω, rewrite [if_pos h] end, assert βeq₂ : zω =[M f] β, from pred_beq βeq₁, assert m_eq_fβ : m = f β, from M_spec f β βeq₂, assert aux : ∀ α, zω =[m] α → β 0 = β (α m), by rewrite m_eq_fβ at {1}; exact (β0_eq β), diff --git a/src/frontends/lean/parse_rewrite_tactic.cpp b/src/frontends/lean/parse_rewrite_tactic.cpp index 6080071af..d80047855 100644 --- a/src/frontends/lean/parse_rewrite_tactic.cpp +++ b/src/frontends/lean/parse_rewrite_tactic.cpp @@ -44,7 +44,7 @@ static void check_not_in_theorem_queue(parser & p, name const & n, pos_info cons } } -static expr parse_rewrite_unfold_core(parser & p) { +static expr parse_rewrite_unfold_core(parser & p, bool force_unfold) { buffer to_unfold; if (p.curr_is_token(get_lbracket_tk())) { p.next(); @@ -63,19 +63,21 @@ static expr parse_rewrite_unfold_core(parser & p) { check_not_in_theorem_queue(p, to_unfold.back(), pos); } location loc = parse_tactic_location(p); - return mk_rewrite_unfold(to_list(to_unfold), loc); + return mk_rewrite_unfold(to_list(to_unfold), force_unfold, loc); } -static expr parse_rewrite_unfold(parser & p) { +static expr parse_rewrite_unfold(parser & p, bool force_unfold) { lean_assert(p.curr_is_token(get_up_tk()) || p.curr_is_token(get_caret_tk())); p.next(); - return parse_rewrite_unfold_core(p); + return parse_rewrite_unfold_core(p, force_unfold); } // If use_paren is true, then lemmas must be identifiers or be wrapped with parenthesis static expr parse_rewrite_element(parser & p, bool use_paren) { - if (p.curr_is_token(get_up_tk()) || p.curr_is_token(get_caret_tk())) - return parse_rewrite_unfold(p); + if (p.curr_is_token(get_up_tk()) || p.curr_is_token(get_caret_tk())) { + bool force_unfold = false; + return parse_rewrite_unfold(p, force_unfold); + } if (p.curr_is_token(get_down_tk())) { p.next(); expr e = p.parse_tactic_expr_arg(); @@ -170,10 +172,11 @@ expr parse_krewrite_tactic(parser & p) { expr parse_esimp_tactic(parser & p) { buffer elems; auto pos = p.pos(); + bool force_unfold = false; if (p.curr_is_token(get_up_tk()) || p.curr_is_token(get_caret_tk())) { - elems.push_back(p.save_pos(parse_rewrite_unfold(p), pos)); + elems.push_back(p.save_pos(parse_rewrite_unfold(p, force_unfold), pos)); } else if (p.curr_is_token(get_lbracket_tk())) { - elems.push_back(p.save_pos(parse_rewrite_unfold_core(p), pos)); + elems.push_back(p.save_pos(parse_rewrite_unfold_core(p, force_unfold), pos)); } else { location loc = parse_tactic_location(p); elems.push_back(p.save_pos(mk_rewrite_reduce(loc), pos)); @@ -184,13 +187,14 @@ expr parse_esimp_tactic(parser & p) { expr parse_unfold_tactic(parser & p) { buffer elems; auto pos = p.pos(); + bool force_unfold = true; if (p.curr_is_identifier()) { name c = p.check_constant_next("invalid unfold tactic, identifier expected"); check_not_in_theorem_queue(p, c, pos); location loc = parse_tactic_location(p); - elems.push_back(p.save_pos(mk_rewrite_unfold(to_list(c), loc), pos)); + elems.push_back(p.save_pos(mk_rewrite_unfold(to_list(c), force_unfold, loc), pos)); } else if (p.curr_is_token(get_lbracket_tk())) { - elems.push_back(p.save_pos(parse_rewrite_unfold_core(p), pos)); + elems.push_back(p.save_pos(parse_rewrite_unfold_core(p, force_unfold), pos)); } else { throw parser_error("invalid unfold tactic, identifier or `[` expected", pos); } diff --git a/src/library/tactic/CMakeLists.txt b/src/library/tactic/CMakeLists.txt index 4dcf672c0..25f2ca6d2 100644 --- a/src/library/tactic/CMakeLists.txt +++ b/src/library/tactic/CMakeLists.txt @@ -6,6 +6,6 @@ expr_to_tactic.cpp location.cpp rewrite_tactic.cpp util.cpp init_module.cpp change_tactic.cpp check_expr_tactic.cpp let_tactic.cpp contradiction_tactic.cpp exfalso_tactic.cpp constructor_tactic.cpp injection_tactic.cpp congruence_tactic.cpp relation_tactics.cpp -induction_tactic.cpp subst_tactic.cpp) +induction_tactic.cpp subst_tactic.cpp unfold_rec.cpp) target_link_libraries(tactic ${LEAN_LIBS}) diff --git a/src/library/tactic/rewrite_tactic.cpp b/src/library/tactic/rewrite_tactic.cpp index d74385c76..adc99e38c 100644 --- a/src/library/tactic/rewrite_tactic.cpp +++ b/src/library/tactic/rewrite_tactic.cpp @@ -38,7 +38,7 @@ Author: Leonardo de Moura #include "library/tactic/rewrite_tactic.h" #include "library/tactic/expr_to_tactic.h" #include "library/tactic/relation_tactics.h" - +#include "library/tactic/unfold_rec.h" // #define TRACE_MATCH_PLUGIN #ifndef LEAN_DEFAULT_REWRITER_MAX_ITERATIONS @@ -57,7 +57,6 @@ Author: Leonardo de Moura #define LEAN_DEFAULT_REWRITER_BETA_ETA true #endif - namespace lean { static name * g_rewriter_max_iterations = nullptr; static name * g_rewriter_syntactic = nullptr; @@ -82,24 +81,27 @@ bool get_rewriter_beta_eta(options const & opts) { class unfold_info { list m_names; + bool m_force_unfold; location m_location; public: unfold_info() {} - unfold_info(list const & l, location const & loc):m_names(l), m_location(loc) {} + unfold_info(list const & l, bool force_unfold, location const & loc): + m_names(l), m_force_unfold(force_unfold), m_location(loc) {} list const & get_names() const { return m_names; } location const & get_location() const { return m_location; } + bool force_unfold() const { return m_force_unfold; } friend serializer & operator<<(serializer & s, unfold_info const & e) { write_list(s, e.m_names); - s << e.m_location; + s << e.m_force_unfold << e.m_location; return s; } friend deserializer & operator>>(deserializer & d, unfold_info & e) { e.m_names = read_list(d); - d >> e.m_location; + d >> e.m_force_unfold >> e.m_location; return d; } - bool operator==(unfold_info const & i) const { return m_names == i.m_names && m_location == i.m_location; } + bool operator==(unfold_info const & i) const { return m_names == i.m_names && m_location == i.m_location && m_force_unfold == i.m_force_unfold; } bool operator!=(unfold_info const & i) const { return !operator==(i); } }; @@ -308,8 +310,8 @@ public: } }; -expr mk_rewrite_unfold(list const & ns, location const & loc) { - macro_definition def(new rewrite_unfold_macro_cell(unfold_info(ns, loc))); +expr mk_rewrite_unfold(list const & ns, bool force_unfold, location const & loc) { + macro_definition def(new rewrite_unfold_macro_cell(unfold_info(ns, force_unfold, loc))); return mk_macro(def); } @@ -626,16 +628,30 @@ class rewrite_fn { } }; - optional reduce(expr const & e, list const & to_unfold) { - bool unfolded = !to_unfold; - type_checker_ptr tc(new type_checker(m_env, m_ngen.mk_child(), - std::unique_ptr(new rewriter_converter(m_env, to_unfold, unfolded)))); + optional reduce(expr const & e, list const & to_unfold, bool force_unfold) { constraint_seq cs; - bool use_eta = true; - expr r = normalize(*tc, e, cs, use_eta); - if (!unfolded || cs) // FAIL if didn't unfolded or generated constraints - return none_expr(); - return some_expr(r); + bool unfolded = !to_unfold; + bool use_eta = true; + // TODO(Leo): should we add add an option that will not try to fold recursive applications + if (to_unfold) { + auto new_e = unfold_rec(m_env, m_ngen.mk_child(), force_unfold, e, to_unfold); + if (!new_e) + return none_expr(); + type_checker_ptr tc(new type_checker(m_env, m_ngen.mk_child(), + std::unique_ptr(new rewriter_converter(m_env, list(), unfolded)))); + expr r = normalize(*tc, *new_e, cs, use_eta); + if (cs) // FAIL if generated constraints + return none_expr(); + return some_expr(r); + } else { + type_checker_ptr tc(new type_checker(m_env, m_ngen.mk_child(), + std::unique_ptr(new rewriter_converter(m_env, to_unfold, unfolded)))); + + expr r = normalize(*tc, e, cs, use_eta); + if (!unfolded || cs) // FAIL if didn't unfolded or generated constraints + return none_expr(); + return some_expr(r); + } } // Replace goal with definitionally equal one @@ -646,8 +662,8 @@ class rewrite_fn { update_goal(new_g); } - bool process_reduce_goal(list const & to_unfold) { - if (auto new_type = reduce(m_g.get_type(), to_unfold)) { + bool process_reduce_goal(list const & to_unfold, bool force_unfold) { + if (auto new_type = reduce(m_g.get_type(), to_unfold, force_unfold)) { replace_goal(*new_type); return true; } else { @@ -683,9 +699,9 @@ class rewrite_fn { update_goal(new_g); } - bool process_reduce_hypothesis(name const & hyp_internal_name, list const & to_unfold) { + bool process_reduce_hypothesis(name const & hyp_internal_name, list const & to_unfold, bool force_unfold) { expr hyp = m_g.find_hyp_from_internal_name(hyp_internal_name)->first; - if (auto new_hyp_type = reduce(mlocal_type(hyp), to_unfold)) { + if (auto new_hyp_type = reduce(mlocal_type(hyp), to_unfold, force_unfold)) { replace_hypothesis(hyp, *new_hyp_type); return true; } else { @@ -693,20 +709,20 @@ class rewrite_fn { } } - bool process_reduce_step(list const & to_unfold, location const & loc) { + bool process_reduce_step(list const & to_unfold, bool force_unfold, location const & loc) { if (loc.is_goal_only()) - return process_reduce_goal(to_unfold); + return process_reduce_goal(to_unfold, force_unfold); bool progress = false; buffer hyps; m_g.get_hyps(hyps); for (expr const & h : hyps) { if (!loc.includes_hypothesis(local_pp_name(h))) continue; - if (process_reduce_hypothesis(mlocal_name(h), to_unfold)) + if (process_reduce_hypothesis(mlocal_name(h), to_unfold, force_unfold)) progress = true; } if (loc.includes_goal()) { - if (process_reduce_goal(to_unfold)) + if (process_reduce_goal(to_unfold, force_unfold)) progress = true; } return progress; @@ -715,7 +731,7 @@ class rewrite_fn { bool process_unfold_step(expr const & elem) { lean_assert(is_rewrite_unfold_step(elem)); auto info = get_rewrite_unfold_info(elem); - return process_reduce_step(info.get_names(), info.get_location()); + return process_reduce_step(info.get_names(), info.force_unfold(), info.get_location()); } optional> elaborate_core(expr const & e, bool fail_if_cnstrs) { @@ -874,7 +890,7 @@ class rewrite_fn { lean_assert(is_rewrite_reduce_step(elem)); if (macro_num_args(elem) == 0) { auto info = get_rewrite_reduce_info(elem); - return process_reduce_step(list(), info.get_location()); + return process_reduce_step(list(), false, info.get_location()); } else { auto info = get_rewrite_reduce_info(elem); return process_reduce_to_step(macro_arg(elem, 0), info.get_location()); @@ -946,7 +962,7 @@ class rewrite_fn { expr reduce_rule_type(expr const & e) { if (m_apply_reduce) { - if (auto it = reduce(e, list())) + if (auto it = reduce(e, list(), false)) return *it; else // TODO(Leo): we should fail instead of doing trying again return head_beta_reduce(e); @@ -1257,6 +1273,10 @@ class rewrite_fn { add_target_failure(src, t, failure::Unification); return unify_result(); } + } catch (kernel_exception & ex) { + regular(m_env, m_ios) << ">> " << ex << "\n"; + add_target_failure(orig_elem, t, ex.what()); + return unify_result(); } catch (exception & ex) { add_target_failure(orig_elem, t, ex.what()); return unify_result(); @@ -1620,9 +1640,9 @@ public: lean_assert(gs); update_goal(head(gs)); m_subst = m_ps.get_subst(); - m_max_iter = get_rewriter_max_iterations(ios.get_options()); - m_use_trace = get_rewriter_trace(ios.get_options()); - m_beta_eta = get_rewriter_beta_eta(ios.get_options()); + m_max_iter = get_rewriter_max_iterations(ios.get_options()); + m_use_trace = get_rewriter_trace(ios.get_options()); + m_beta_eta = get_rewriter_beta_eta(ios.get_options()); m_apply_reduce = false; } diff --git a/src/library/tactic/rewrite_tactic.h b/src/library/tactic/rewrite_tactic.h index 95fc5f7c4..ed2d8adbd 100644 --- a/src/library/tactic/rewrite_tactic.h +++ b/src/library/tactic/rewrite_tactic.h @@ -9,7 +9,7 @@ Author: Leonardo de Moura #include "library/tactic/location.h" namespace lean { -expr mk_rewrite_unfold(list const & ns, location const & loc); +expr mk_rewrite_unfold(list const & ns, bool force_unfold, location const & loc); expr mk_rewrite_reduce(location const & loc); expr mk_rewrite_reduce_to(expr const & e, location const & loc); expr mk_rewrite_fold(expr const & e, location const & loc); diff --git a/src/library/tactic/unfold_rec.cpp b/src/library/tactic/unfold_rec.cpp new file mode 100644 index 000000000..037778eb5 --- /dev/null +++ b/src/library/tactic/unfold_rec.cpp @@ -0,0 +1,338 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "kernel/type_checker.h" +#include "kernel/abstract.h" +#include "kernel/instantiate.h" +#include "kernel/inductive/inductive.h" +#include "library/util.h" +#include "library/replace_visitor.h" +#include "library/constants.h" + +extern void pp_detail(lean::environment const & env, lean::expr const & e); +extern void pp(lean::environment const & env, lean::expr const & e); + +namespace lean { +// Auxiliary visitor the implements the common parts for unfold_rec_fn and fold_rec_fn +class replace_visitor_aux : public replace_visitor { +protected: + virtual name mk_fresh_name() = 0; + + expr visit_app_default(expr const & e, expr const & fn, buffer & args) { + bool modified = false; + for (expr & arg : args) { + expr new_arg = visit(arg); + if (arg != new_arg) + modified = true; + arg = new_arg; + } + if (!modified) + return e; + return mk_app(fn, args); + } + + virtual expr visit_binding(expr const & b) { + expr new_domain = visit(binding_domain(b)); + expr l = mk_local(mk_fresh_name(), new_domain); + expr new_body = abstract(visit(instantiate(binding_body(b), l)), l); + return update_binding(b, new_domain, new_body); + } +}; + + +class unfold_rec_fn : public replace_visitor_aux { + environment const & m_env; + name_generator m_ngen; + bool m_force_unfold; + type_checker_ptr m_tc; + type_checker_ptr m_norm_decl_tc; + list m_to_unfold; + + virtual name mk_fresh_name() { return m_ngen.next(); } + + static void throw_ill_formed() { + throw exception("ill-formed expression"); + } + + static bool is_rec_building_part(name const & n) { + if (n == get_prod_pr1_name() || n == get_prod_pr2_name()) + return true; + if (n.is_atomic() || !n.is_string()) + return false; + char const * str = n.get_string(); + return + strcmp(str, "rec_on") == 0 || + strcmp(str, "cases_on") == 0 || + strcmp(str, "brec_on") == 0 || + strcmp(str, "below") == 0 || + strcmp(str, "no_confusion") == 0; + } + + optional get_local_pos(buffer const & locals, expr const & e) { + if (!is_local(e)) + return optional(); + unsigned i = 0; + for (expr const & local : locals) { + if (mlocal_name(local) == mlocal_name(e)) + return optional(i); + i++; + } + return optional(); + } + + // return true if e is of the form (C.rec ...) + bool is_rec_app(expr const & e, buffer const & locals, name & rec_name, unsigned & main_arg_pos, buffer & rec_arg_pos) { + buffer args; + expr fn = get_app_args(e, args); + if (!is_constant(fn)) + return false; + optional I = inductive::is_elim_rule(m_env, const_name(fn)); + rec_name = const_name(fn); + if (!I) + return false; + if (!is_recursive_datatype(m_env, *I)) + return false; + unsigned major_idx = *inductive::get_elim_major_idx(m_env, const_name(fn)); + if (major_idx >= args.size()) + return false; + if (auto it = get_local_pos(locals, args[major_idx])) { + main_arg_pos = *it; + for (unsigned i = major_idx+1; i < args.size(); i++) { + if (auto it2 = get_local_pos(locals, args[i])) { + rec_arg_pos.push_back(*it2); + } else { + return false; + } + } + return true; + } + return false; + } + + enum rec_kind { BREC, REC, NOREC }; + + // try to detect the kind of recursive definition + rec_kind get_rec_kind(expr const & e, buffer const & locals, name & rec_name, unsigned & main_arg_pos, buffer & rec_arg_pos) { + if (is_rec_app(e, locals, rec_name, main_arg_pos, rec_arg_pos)) + return REC; + buffer args; + expr fn = get_app_args(e, args); + if (is_constant(fn) && const_name(fn) == inductive::get_elim_name(get_prod_name()) && + args.size() >= 5) { + // try do detect brec_on pattern + if (is_rec_app(args[4], locals, rec_name, main_arg_pos, rec_arg_pos)) { + for (unsigned i = 5; i < args.size(); i++) { + if (auto it2 = get_local_pos(locals, args[i])) { + rec_arg_pos.push_back(*it2); + } else { + return NOREC; + } + } + return BREC; + } + } + return NOREC; + } + + // just unfold the application without trying to fold recursive call + expr unfold_simple(expr const & fn, buffer & args) { + expr new_app = mk_app(fn, args); + if (auto r = unfold_term(m_env, new_app)) { + return visit(*r); + } else { + return new_app; + } + } + + expr get_fn_decl(expr const & fn, buffer & locals) { + lean_assert(is_constant(fn)); + declaration decl = m_env.get(const_name(fn)); + if (length(const_levels(fn)) != decl.get_num_univ_params()) + throw_ill_formed(); + expr fn_body = instantiate_value_univ_params(decl, const_levels(fn)); + while (is_lambda(fn_body)) { + expr local = mk_local(m_ngen.next(), binding_domain(fn_body)); + locals.push_back(local); + fn_body = instantiate(binding_body(fn_body), local); + } + return m_norm_decl_tc->whnf(fn_body).first; + } + + struct fold_failed {}; + + struct fold_rec_fn : public replace_visitor_aux { + type_checker_ptr & m_tc; + expr m_fn; // function being unfolded + buffer const & m_args; // arguments of the function being unfolded + rec_kind m_kind; + name m_rec_name; + unsigned m_major_idx; // position of the major premise in the recursor + unsigned m_main_pos; // position of the (recursive) argument in the function being unfolded + buffer const & m_rec_arg_pos; // position of the other arguments that are not fixed in the recursion + + fold_rec_fn(type_checker_ptr & tc, expr const & fn, buffer const & args, rec_kind k, name const & rec_name, + unsigned main_pos, buffer const & rec_arg_pos): + m_tc(tc), m_fn(fn), m_args(args), m_kind(k), m_rec_name(rec_name), + m_major_idx(*inductive::get_elim_major_idx(m_tc->env(), rec_name)), + m_main_pos(main_pos), m_rec_arg_pos(rec_arg_pos) { + lean_assert(m_main_pos < args.size()); + lean_assert(std::all_of(rec_arg_pos.begin(), rec_arg_pos.end(), [&](unsigned pos) { return pos < args.size(); })); + } + + virtual name mk_fresh_name() { return m_tc->mk_fresh_name(); } + + expr fold_rec(expr const & e, buffer const & args) { + if (args.size() != m_major_idx + 1 + m_rec_arg_pos.size()) + throw fold_failed(); + buffer new_args; + new_args.append(m_args); + new_args[m_main_pos] = args[m_major_idx]; + for (unsigned i = 0; i < m_rec_arg_pos.size(); i++) { + new_args[m_rec_arg_pos[i]] = args[m_major_idx + 1 + i]; + } + expr folded_app = mk_app(m_fn, new_args); + if (!m_tc->is_def_eq(folded_app, e).first) + throw fold_failed(); + return folded_app; + } + + expr fold_brec(expr const & e, buffer const & args) { + if (args.size() != 3 + m_rec_arg_pos.size()) + throw fold_failed(); + buffer nested_args; + get_app_args(args[1], nested_args); + if (nested_args.size() != m_major_idx+1) + throw fold_failed(); + buffer new_args; + new_args.append(m_args); + new_args[m_main_pos] = nested_args[m_major_idx]; + for (unsigned i = 0; i < m_rec_arg_pos.size(); i++) { + new_args[m_rec_arg_pos[i]] = args[3 + i]; + } + expr folded_app = mk_app(m_fn, new_args); + if (!m_tc->is_def_eq(folded_app, e).first) + throw fold_failed(); + return folded_app; + } + + virtual expr visit_app(expr const & e) { + buffer args; + expr fn = get_app_args(e, args); + if (m_kind == REC && is_constant(fn) && const_name(fn) == m_rec_name) + return fold_rec(e, args); + if (m_kind == BREC && is_constant(fn) && const_name(fn) == get_prod_pr1_name() && args.size() >= 3) { + expr rec_fn = get_app_fn(args[1]); + if (is_constant(rec_fn) && const_name(rec_fn) == m_rec_name) + return fold_brec(e, args); + } + return visit_app_default(e, fn, args); + } + }; + + expr unfold(expr const & fn, buffer args) { + buffer fn_locals; + expr fn_body = get_fn_decl(fn, fn_locals); + if (args.size() < fn_locals.size()) { + // insufficient args + return unfold_simple(fn, args); + } + name rec_name; + unsigned main_pos; + buffer rec_arg_pos; + rec_kind k = get_rec_kind(fn_body, fn_locals, rec_name, main_pos, rec_arg_pos); + if (k == NOREC) { + // norecursive definition + return unfold_simple(fn, args); + } + for (unsigned i = fn_locals.size(); i < args.size(); i++) + rec_arg_pos.push_back(i); + auto new_main_cs = m_tc->whnf(args[main_pos]); + if (!is_constructor_app(m_env, new_main_cs.first) || new_main_cs.second) { + // argument is not a constructor or constraints were generated + throw fold_failed(); + } + args[main_pos] = new_main_cs.first; + expr fn_body_abst = abstract_locals(fn_body, fn_locals.size(), fn_locals.data()); + expr new_e = instantiate_rev(fn_body_abst, fn_locals.size(), args.data()); + new_e = mk_app(new_e, args.size() - fn_locals.size(), args.data() + fn_locals.size()); + auto new_e_cs = m_norm_decl_tc->whnf(new_e); + if (new_e_cs.second) { + // constraints were generated + throw fold_failed(); + } + new_e = new_e_cs.first; + expr const new_head = get_app_fn(new_e); + // TODO(Leo): create an option for the following conditions? + // if (is_constant(new_head) && inductive::is_elim_rule(m_env, const_name(new_head))) { + // //head is a recursor... so the unfold is probably not generating a nice result... + // throw fold_failed(); + // } + return fold_rec_fn(m_tc, fn, args, k, rec_name, main_pos, rec_arg_pos)(new_e); + } + + bool unfold_cnst(expr const & e) { + return is_constant(e) && std::find(m_to_unfold.begin(), m_to_unfold.end(), const_name(e)) != m_to_unfold.end(); + } + + virtual expr visit_app(expr const & e) { + buffer args; + expr fn = get_app_args(e, args); + bool modified = false; + for (expr & arg : args) { + expr new_arg = visit(arg); + if (arg != new_arg) + modified = true; + arg = new_arg; + } + if (unfold_cnst(fn)) { + try { + return unfold(fn, args); + } catch (fold_failed &) { + if (m_force_unfold) + return unfold_simple(fn, args); + } + } + if (!modified) { + return e; + } else { + return mk_app(fn, args); + } + } + + virtual expr visit_constant(expr const & e) { + if (unfold_cnst(e)) { + if (auto r = unfold_term(m_env, e)) + return *r; + } + return e; + } + +public: + unfold_rec_fn(environment const & env, name_generator && ngen, bool force_unfold, list const & to_unfold): + m_env(env), + m_ngen(ngen), + m_force_unfold(force_unfold), + m_tc(mk_type_checker(m_env, m_ngen.mk_child(), [](name const &) { return false; })), + m_norm_decl_tc(mk_type_checker(m_env, m_ngen.mk_child(), [](name const & n) { return !is_rec_building_part(n); })), + m_to_unfold(to_unfold) + {} + + expr operator()(expr const & e) { + return replace_visitor_aux::operator()(e); + } +}; + +optional unfold_rec(environment const & env, name_generator && ngen, bool force_unfold, expr const & e, list const & to_unfold) { + try { + expr r = unfold_rec_fn(env, std::move(ngen), force_unfold, to_unfold)(e); + if (r == e) + return none_expr(); + return some_expr(r); + } catch (exception &) { + return none_expr(); + } +} +} diff --git a/src/library/tactic/unfold_rec.h b/src/library/tactic/unfold_rec.h new file mode 100644 index 000000000..8e21109ae --- /dev/null +++ b/src/library/tactic/unfold_rec.h @@ -0,0 +1,11 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "kernel/environment.h" +namespace lean { +optional unfold_rec(environment const & env, name_generator && ngen, bool force_unfold, expr const & e, list const & to_unfold); +} diff --git a/tests/lean/run/esimp1.lean b/tests/lean/run/esimp1.lean index e52084b96..b7ce1e4d6 100644 --- a/tests/lean/run/esimp1.lean +++ b/tests/lean/run/esimp1.lean @@ -10,12 +10,14 @@ definition foo [irreducible] (a : nat) := a example (x y : nat) (H : (fun (a : nat), sigma.pr1 ⟨foo a, y⟩) x = 0) : x = 0 := begin + esimp at H, esimp ↑foo at H, exact H end example (x y : nat) (H : x = 0) : (fun (a : nat), sigma.pr1 ⟨foo a, y⟩) x = 0 := begin + esimp, esimp ↑foo, exact H end diff --git a/tests/lean/unfold_rec.lean b/tests/lean/unfold_rec.lean new file mode 100644 index 000000000..a19b624eb --- /dev/null +++ b/tests/lean/unfold_rec.lean @@ -0,0 +1,50 @@ +import data.vector +open nat vector + +variables {A B : Type} +variable {n : nat} + +theorem tst1 : ∀ n m, succ n + succ m = succ (succ (n + m)) := +begin + intro n m, + esimp [add], + state, + rewrite [succ_add] +end + +definition add2 (x y : nat) : nat := +nat.rec_on x (λ y, y) (λ x r y, succ (r y)) y + +local infix + := add2 + +theorem tst2 : ∀ n m, succ n + succ m = succ (succ (n + m)) := +begin + intro n m, + esimp [add2], + state, + apply sorry +end + +definition fib (A : Type) : nat → nat → nat → nat +| b 0 c := b +| b 1 c := c +| b (succ (succ a)) c := fib b a c + fib b (succ a) c + +theorem fibgt0 : ∀ b n c, fib nat b n c > 0 +| b 0 c := sorry +| b 1 c := sorry +| b (succ (succ m)) c := +begin + unfold fib, + state, + apply sorry +end + +theorem unzip_zip : ∀ {n : nat} (v₁ : vector A n) (v₂ : vector B n), unzip (zip v₁ v₂) = (v₁, v₂) +| 0 [] [] := rfl +| (succ m) (a::va) (b::vb) := +begin + unfold [zip, unzip], + state, + rewrite [unzip_zip] +end diff --git a/tests/lean/unfold_rec.lean.expected.out b/tests/lean/unfold_rec.lean.expected.out new file mode 100644 index 000000000..0961d8d4d --- /dev/null +++ b/tests/lean/unfold_rec.lean.expected.out @@ -0,0 +1,20 @@ +unfold_rec.lean:11:2: proof state +n m : ℕ +⊢ succ (succ n + m) = succ (succ (n + m)) +unfold_rec.lean:24:2: proof state +n m : ℕ +⊢ succ (n + succ m) = succ (succ (n + m)) +unfold_rec.lean:39:2: proof state +fibgt0 : ∀ (b n c : ℕ), fib ℕ b n c > 0, +b m c : ℕ +⊢ fib ℕ b m c + fib ℕ b (succ m) c > 0 +unfold_rec.lean:48:2: proof state +A : Type, +B : Type, +unzip_zip : ∀ {n : ℕ} (v₁ : vector A n) (v₂ : vector B n), unzip (zip v₁ v₂) = (v₁, v₂), +m : ℕ, +a : A, +va : vector A m, +b : B, +vb : vector B m +⊢ (a :: prod.pr1 (unzip (zip va vb)), b :: prod.pr2 (unzip (zip va vb))) = (a :: va, b :: vb)