diff --git a/src/frontends/lean/decl_attributes.cpp b/src/frontends/lean/decl_attributes.cpp index 43c620ea6..5ab4a7c58 100644 --- a/src/frontends/lean/decl_attributes.cpp +++ b/src/frontends/lean/decl_attributes.cpp @@ -112,11 +112,17 @@ void decl_attributes::parse(buffer const & ns, parser & p) { m_constructor_hint = true; } else if (p.curr_is_token(get_unfold_tk())) { p.next(); - unsigned r = p.parse_small_nat(); - if (r == 0) - throw parser_error("invalid '[unfold]' attribute, value must be greater than 0", pos); - m_unfold_hint = r - 1; - p.check_token_next(get_rbracket_tk(), "invalid 'unfold', ']' expected"); + buffer idxs; + while (true) { + unsigned r = p.parse_small_nat(); + if (r == 0) + throw parser_error("invalid '[unfold]' attribute, value must be greater than 0", pos); + idxs.push_back(r-1); + if (p.curr_is_token(get_rbracket_tk())) + break; + } + p.next(); + m_unfold_hint = to_list(idxs); } else if (p.curr_is_token(get_symm_tk())) { p.next(); m_symm = true; @@ -193,7 +199,7 @@ environment decl_attributes::apply(environment env, io_state const & ios, name c if (m_is_quasireducible) env = set_reducible(env, d, reducible_status::Quasireducible, m_persistent); if (m_unfold_hint) - env = add_unfold_hint(env, d, *m_unfold_hint, m_persistent); + env = add_unfold_hint(env, d, m_unfold_hint, m_persistent); if (m_unfold_full_hint) env = add_unfold_full_hint(env, d, m_persistent); } @@ -223,7 +229,8 @@ void decl_attributes::write(serializer & s) const { << m_is_reducible << m_is_irreducible << m_is_semireducible << m_is_quasireducible << m_is_class << m_is_parsing_only << m_has_multiple_instances << m_unfold_full_hint << m_constructor_hint << m_symm << m_trans << m_refl << m_subst << m_recursor - << m_rewrite << m_recursor_major_pos << m_priority << m_unfold_hint; + << m_rewrite << m_recursor_major_pos << m_priority; + write_list(s, m_unfold_hint); } void decl_attributes::read(deserializer & d) { @@ -231,6 +238,7 @@ void decl_attributes::read(deserializer & d) { >> m_is_reducible >> m_is_irreducible >> m_is_semireducible >> m_is_quasireducible >> m_is_class >> m_is_parsing_only >> m_has_multiple_instances >> m_unfold_full_hint >> m_constructor_hint >> m_symm >> m_trans >> m_refl >> m_subst >> m_recursor - >> m_rewrite >> m_recursor_major_pos >> m_priority >> m_unfold_hint; + >> m_rewrite >> m_recursor_major_pos >> m_priority; + m_unfold_hint = read_list(d); } } diff --git a/src/frontends/lean/decl_attributes.h b/src/frontends/lean/decl_attributes.h index 5473dae89..d9a6d6340 100644 --- a/src/frontends/lean/decl_attributes.h +++ b/src/frontends/lean/decl_attributes.h @@ -32,7 +32,7 @@ class decl_attributes { bool m_rewrite; optional m_recursor_major_pos; optional m_priority; - optional m_unfold_hint; + list m_unfold_hint; void parse(name const & n, parser & p); public: diff --git a/src/frontends/lean/structure_cmd.cpp b/src/frontends/lean/structure_cmd.cpp index 0c42c5d7b..1a9dfb019 100644 --- a/src/frontends/lean/structure_cmd.cpp +++ b/src/frontends/lean/structure_cmd.cpp @@ -724,8 +724,8 @@ struct structure_cmd_fn { rec_on_decl.get_type(), rec_on_decl.get_value()); m_env = module::add(m_env, check(m_env, new_decl)); m_env = set_reducible(m_env, n, reducible_status::Reducible); - if (optional idx = has_unfold_hint(m_env, rec_on_name)) - m_env = add_unfold_hint(m_env, n, *idx); + if (list idx = has_unfold_hint(m_env, rec_on_name)) + m_env = add_unfold_hint(m_env, n, idx); save_def_info(n); add_alias(n); } diff --git a/src/library/normalize.cpp b/src/library/normalize.cpp index 20b9ff412..200ced26d 100644 --- a/src/library/normalize.cpp +++ b/src/library/normalize.cpp @@ -30,29 +30,31 @@ namespace lean { */ struct unfold_hint_entry { enum kind {Unfold, UnfoldFull, Constructor}; - kind m_kind; //!< true if it is an unfold_c hint - bool m_add; //!< add/remove hint - name m_decl_name; - unsigned m_arg_idx; - unfold_hint_entry():m_kind(Unfold), m_add(false), m_arg_idx(0) {} - unfold_hint_entry(kind k, bool add, name const & n, unsigned idx): - m_kind(k), m_add(add), m_decl_name(n), m_arg_idx(idx) {} + kind m_kind; //!< true if it is an unfold_c hint + bool m_add; //!< add/remove hint + name m_decl_name; + list m_arg_idxs; //!< only relevant if m_kind == Unfold + unfold_hint_entry():m_kind(Unfold), m_add(false) {} + unfold_hint_entry(kind k, bool add, name const & n): + m_kind(k), m_add(add), m_decl_name(n) {} + unfold_hint_entry(bool add, name const & n, list const & idxs): + m_kind(Unfold), m_add(add), m_decl_name(n), m_arg_idxs(idxs) {} }; -unfold_hint_entry mk_add_unfold_entry(name const & n, unsigned idx) { return unfold_hint_entry(unfold_hint_entry::Unfold, true, n, idx); } -unfold_hint_entry mk_erase_unfold_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Unfold, false, n, 0); } -unfold_hint_entry mk_add_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, true, n, 0); } -unfold_hint_entry mk_erase_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, false, n, 0); } -unfold_hint_entry mk_add_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, true, n, 0); } -unfold_hint_entry mk_erase_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, false, n, 0); } +unfold_hint_entry mk_add_unfold_entry(name const & n, list const & idxs) { return unfold_hint_entry(true, n, idxs); } +unfold_hint_entry mk_erase_unfold_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Unfold, false, n); } +unfold_hint_entry mk_add_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, true, n); } +unfold_hint_entry mk_erase_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, false, n); } +unfold_hint_entry mk_add_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, true, n); } +unfold_hint_entry mk_erase_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, false, n); } static name * g_unfold_hint_name = nullptr; static std::string * g_key = nullptr; struct unfold_hint_state { - name_map m_unfold; - name_set m_unfold_full; - name_set m_constructor; + name_map> m_unfold; + name_set m_unfold_full; + name_set m_constructor; }; struct unfold_hint_config { @@ -63,7 +65,7 @@ struct unfold_hint_config { switch (e.m_kind) { case unfold_hint_entry::Unfold: if (e.m_add) - s.m_unfold.insert(e.m_decl_name, e.m_arg_idx); + s.m_unfold.insert(e.m_decl_name, e.m_arg_idxs); else s.m_unfold.erase(e.m_decl_name); break; @@ -88,13 +90,17 @@ struct unfold_hint_config { return *g_key; } static void write_entry(serializer & s, entry const & e) { - s << static_cast(e.m_kind) << e.m_add << e.m_decl_name << e.m_arg_idx; + s << static_cast(e.m_kind) << e.m_add << e.m_decl_name; + if (e.m_kind == unfold_hint_entry::Unfold) + write_list(s, e.m_arg_idxs); } static entry read_entry(deserializer & d) { char k; entry e; - d >> k >> e.m_add >> e.m_decl_name >> e.m_arg_idx; + d >> k >> e.m_add >> e.m_decl_name; e.m_kind = static_cast(k); + if (e.m_kind == unfold_hint_entry::Unfold) + e.m_arg_idxs = read_list(d); return e; } static optional get_fingerprint(entry const & e) { @@ -105,19 +111,20 @@ struct unfold_hint_config { template class scoped_ext; typedef scoped_ext unfold_hint_ext; -environment add_unfold_hint(environment const & env, name const & n, unsigned idx, bool persistent) { +environment add_unfold_hint(environment const & env, name const & n, list const & idxs, bool persistent) { + lean_assert(idxs); declaration const & d = env.get(n); if (!d.is_definition()) throw exception("invalid [unfold] hint, declaration must be a non-opaque definition"); - return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_add_unfold_entry(n, idx), persistent); + return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_add_unfold_entry(n, idxs), persistent); } -optional has_unfold_hint(environment const & env, name const & d) { +list has_unfold_hint(environment const & env, name const & d) { unfold_hint_state const & s = unfold_hint_ext::get_state(env); if (auto it = s.m_unfold.find(d)) - return optional(*it); + return list(*it); else - return optional(); + return list(); } environment erase_unfold_hint(environment const & env, name const & n, bool persistent) { @@ -246,9 +253,9 @@ class normalize_fn { return update_binding(e, d, b); } - optional has_unfold_hint(expr const & f) { + list has_unfold_hint(expr const & f) { if (!is_constant(f)) - return optional(); + return list(); return ::lean::has_unfold_hint(env(), const_name(f)); } @@ -270,27 +277,39 @@ class normalize_fn { } } - optional unfold_recursor_core(expr const & f, unsigned idx, buffer & args, bool is_rec) { - if (idx < args.size()) { + optional unfold_recursor_core(expr const & f, unsigned i, + buffer const & idxs, buffer & args, bool is_rec) { + if (i == idxs.size()) { + expr new_app = mk_rev_app(f, args); + if (is_rec) + return some_expr(normalize(new_app)); + else if (optional r = unfold_app(env(), new_app)) + return some_expr(normalize(*r)); + else + return none_expr(); + } else { + unsigned idx = idxs[i]; + if (idx >= args.size()) + return none_expr(); expr & arg = args[args.size() - idx - 1]; - if (optional new_arg = is_constructor_like(arg)) { - flet set_arg(arg, *new_arg); - expr new_app = mk_rev_app(f, args); - if (is_rec) - return some_expr(normalize(new_app)); - else if (optional r = unfold_app(env(), new_app)) - return some_expr(normalize(*r)); - } + optional new_arg = is_constructor_like(arg); + if (!new_arg) + return none_expr(); + flet set_arg(arg, *new_arg); + return unfold_recursor_core(f, i+1, idxs, args, is_rec); } - return none_expr(); } - optional unfold_recursor_like(expr const & f, unsigned idx, buffer & args) { - return unfold_recursor_core(f, idx, args, false); + optional unfold_recursor_like(expr const & f, list const & idx_lst, buffer & args) { + buffer idxs; + to_buffer(idx_lst, idxs); + return unfold_recursor_core(f, 0, idxs, args, false); } optional unfold_recursor_major(expr const & f, unsigned idx, buffer & args) { - return unfold_recursor_core(f, idx, args, true); + buffer idxs; + idxs.push_back(idx); + return unfold_recursor_core(f, 0, idxs, args, true); } expr normalize_app(expr const & e) { @@ -311,8 +330,8 @@ class normalize_fn { return normalize(*r); } } - if (auto idx = has_unfold_hint(f)) { - if (auto r = unfold_recursor_like(f, *idx, args)) + if (auto idxs = has_unfold_hint(f)) { + if (auto r = unfold_recursor_like(f, idxs, args)) return *r; } if (is_constant(f)) { diff --git a/src/library/normalize.h b/src/library/normalize.h index 4cc331c71..68c760d60 100644 --- a/src/library/normalize.h +++ b/src/library/normalize.h @@ -40,10 +40,10 @@ expr normalize(type_checker & tc, expr const & e, std::function const & idxs, bool persistent = true); environment erase_unfold_hint(environment const & env, name const & n, bool persistent = true); /** \brief Retrieve the hint added with the procedure add_unfold_hint. */ -optional has_unfold_hint(environment const & env, name const & d); +list has_unfold_hint(environment const & env, name const & d); /** \brief [unfold-full] hint instructs normalizer (and simplifier) that function application (f a_1 ... a_n) should be unfolded when it is fully applied */ diff --git a/tests/lean/693.lean b/tests/lean/693.lean new file mode 100644 index 000000000..ef3008ecb --- /dev/null +++ b/tests/lean/693.lean @@ -0,0 +1,36 @@ +open nat + +definition foo [unfold 1 3] (a : nat) (b : nat) (c :nat) : nat := +(a + c) * b + +example (c : nat) : c = 1 → foo 1 c 0 = foo 1 1 0 := +begin + intro h, + esimp, + state, + subst c +end + +example (b c : nat) : c = 1 → foo 1 c b = foo 1 1 b := +begin + intro h, + esimp, -- should not unfold foo + state, + subst c +end + +example (b c : nat) : c = 1 → foo b c 0 = foo b 1 0 := +begin + intro h, + esimp, -- should not unfold foo + state, + subst c +end + +example (b c : nat) : c = 1 → foo 1 c 1 = foo c 1 1 := +begin + intro h, + esimp, -- should fold only first foo + state, + subst c +end diff --git a/tests/lean/693.lean.expected.out b/tests/lean/693.lean.expected.out new file mode 100644 index 000000000..1e7240e2f --- /dev/null +++ b/tests/lean/693.lean.expected.out @@ -0,0 +1,16 @@ +693.lean:10:2: proof state +c : ℕ, +h : c = 1 +⊢ (1 + 0) * c = (1 + 0) * 1 +693.lean:18:2: proof state +b c : ℕ, +h : c = 1 +⊢ foo 1 c b = foo 1 1 b +693.lean:26:2: proof state +b c : ℕ, +h : c = 1 +⊢ foo b c 0 = foo b 1 0 +693.lean:34:2: proof state +b c : ℕ, +h : c = 1 +⊢ (1 + 1) * c = foo c 1 1