diff --git a/src/emacs/lean-syntax.el b/src/emacs/lean-syntax.el index 2c4e0ac7f..9c995213a 100644 --- a/src/emacs/lean-syntax.el +++ b/src/emacs/lean-syntax.el @@ -118,7 +118,7 @@ (,(rx symbol-start "_" symbol-end) . 'font-lock-preprocessor-face) ;; modifiers (,(rx (or "\[persistent\]" "\[notation\]" "\[visible\]" "\[instance\]" "\[class\]" "\[parsing-only\]" - "\[coercion\]" "\[reducible\]" "\[irreducible\]" "\[semireducible\]" "\[quasireducible\]" "\[wf\]" + "\[coercion\]" "\[unfold-f\]" "\[reducible\]" "\[irreducible\]" "\[semireducible\]" "\[quasireducible\]" "\[wf\]" "\[whnf\]" "\[multiple-instances\]" "\[none\]" "\[decls\]" "\[declarations\]" "\[all-transparent\]" "\[coercions\]" "\[classes\]" "\[notations\]" "\[abbreviations\]" "\[begin-end-hints\]" "\[tactic-hints\]" "\[reduce-hints\]")) diff --git a/src/frontends/lean/decl_cmds.cpp b/src/frontends/lean/decl_cmds.cpp index 47c0adcef..3246096e3 100644 --- a/src/frontends/lean/decl_cmds.cpp +++ b/src/frontends/lean/decl_cmds.cpp @@ -300,6 +300,7 @@ struct decl_attributes { bool m_is_class; bool m_is_parsing_only; bool m_has_multiple_instances; + bool m_unfold_f_hint; optional m_priority; optional m_unfold_c_hint; @@ -317,6 +318,7 @@ struct decl_attributes { m_is_class = false; m_is_parsing_only = false; m_has_multiple_instances = false; + m_unfold_f_hint = false; } struct elim_choice_fn : public replace_visitor { @@ -421,6 +423,9 @@ struct decl_attributes { "marked as '[parsing-only]'", pos); m_is_parsing_only = true; p.next(); + } else if (p.curr_is_token(get_unfold_f_tk())) { + p.next(); + m_unfold_f_hint = true; } else if (p.curr_is_token(get_unfold_c_tk())) { p.next(); unsigned r = p.parse_small_nat(); @@ -469,7 +474,9 @@ struct decl_attributes { if (m_is_quasireducible) env = set_reducible(env, d, reducible_status::Quasireducible, m_persistent); if (m_unfold_c_hint) - env = add_unfold_c_hint(env, d, m_unfold_c_hint, m_persistent); + env = add_unfold_c_hint(env, d, *m_unfold_c_hint, m_persistent); + if (m_unfold_f_hint) + env = add_unfold_f_hint(env, d, m_persistent); } if (m_is_class) env = add_class(env, d, m_persistent); diff --git a/src/frontends/lean/token_table.cpp b/src/frontends/lean/token_table.cpp index ff2559e8b..51f4a0e9c 100644 --- a/src/frontends/lean/token_table.cpp +++ b/src/frontends/lean/token_table.cpp @@ -93,7 +93,7 @@ void init_token_table(token_table & t) { "variables", "parameter", "parameters", "constant", "constants", "[persistent]", "[visible]", "[instance]", "[none]", "[class]", "[coercion]", "[reducible]", "[irreducible]", "[semireducible]", "[quasireducible]", "[parsing-only]", "[multiple-instances]", - "evaluate", "check", "eval", "[wf]", "[whnf]", "[all-transparent]", "[priority", "[unfold-c", "print", + "evaluate", "check", "eval", "[wf]", "[whnf]", "[all-transparent]", "[priority", "[unfold-f]", "[unfold-c", "print", "end", "namespace", "section", "prelude", "help", "import", "inductive", "record", "structure", "module", "universe", "universes", "local", "precedence", "reserve", "infixl", "infixr", "infix", "postfix", "prefix", "notation", "context", diff --git a/src/frontends/lean/tokens.cpp b/src/frontends/lean/tokens.cpp index 99a37f2e7..9b0a62995 100644 --- a/src/frontends/lean/tokens.cpp +++ b/src/frontends/lean/tokens.cpp @@ -102,6 +102,7 @@ static name * g_opaque = nullptr; static name * g_instance = nullptr; static name * g_priority = nullptr; static name * g_unfold_c = nullptr; +static name * g_unfold_f = nullptr; static name * g_coercion = nullptr; static name * g_reducible = nullptr; static name * g_quasireducible = nullptr; @@ -227,6 +228,7 @@ void initialize_tokens() { g_instance = new name("[instance]"); g_priority = new name("[priority"); g_unfold_c = new name("[unfold-c"); + g_unfold_f = new name("[unfold-f]"); g_coercion = new name("[coercion]"); g_reducible = new name("[reducible]"); g_quasireducible = new name("[quasireducible]"); @@ -290,6 +292,7 @@ void finalize_tokens() { delete g_instance; delete g_priority; delete g_unfold_c; + delete g_unfold_f; delete g_coercion; delete g_reducible; delete g_quasireducible; @@ -479,6 +482,7 @@ name const & get_opaque_tk() { return *g_opaque; } name const & get_instance_tk() { return *g_instance; } name const & get_priority_tk() { return *g_priority; } name const & get_unfold_c_tk() { return *g_unfold_c; } +name const & get_unfold_f_tk() { return *g_unfold_f; } name const & get_coercion_tk() { return *g_coercion; } name const & get_reducible_tk() { return *g_reducible; } name const & get_quasireducible_tk() { return *g_quasireducible; } diff --git a/src/frontends/lean/tokens.h b/src/frontends/lean/tokens.h index cc7e0e8b1..4a1c1ab44 100644 --- a/src/frontends/lean/tokens.h +++ b/src/frontends/lean/tokens.h @@ -104,6 +104,7 @@ name const & get_opaque_tk(); name const & get_instance_tk(); name const & get_priority_tk(); name const & get_unfold_c_tk(); +name const & get_unfold_f_tk(); name const & get_coercion_tk(); name const & get_reducible_tk(); name const & get_semireducible_tk(); diff --git a/src/library/normalize.cpp b/src/library/normalize.cpp index 3b13a41aa..a2f825d77 100644 --- a/src/library/normalize.cpp +++ b/src/library/normalize.cpp @@ -19,42 +19,64 @@ Author: Leonardo de Moura namespace lean { /** - \brief c_unfold hint instructs the normalizer (and simplifier) that - a function application (f a_1 ... a_i ... a_n) should be unfolded - when argument a_i is a constructor. + \brief unfold hints instruct the normalizer (and simplifier) that + a function application. We have two kinds of hints: + - unfold_c (f a_1 ... a_i ... a_n) should be unfolded + when argument a_i is a constructor. + - unfold_f (f a_1 ... a_i ... a_n) should be unfolded when it is fully applied. */ -struct unfold_c_hint_entry { - name m_decl_name; - optional m_arg_idx; - unfold_c_hint_entry() {} - unfold_c_hint_entry(name const & n, optional const & idx):m_decl_name(n), m_arg_idx(idx) {} +struct unfold_hint_entry { + bool m_unfold_c; //!< true if it is an unfold_c hint + bool m_add; //!< add/remove hint + name m_decl_name; + unsigned m_arg_idx; + unfold_hint_entry():m_unfold_c(false), m_add(false), m_arg_idx(0) {} + unfold_hint_entry(bool unfold_c, bool add, name const & n, unsigned idx): + m_unfold_c(unfold_c), m_add(add), m_decl_name(n), m_arg_idx(idx) {} }; -static name * g_unfold_c_hint_name = nullptr; +unfold_hint_entry mk_add_unfold_c_entry(name const & n, unsigned idx) { return unfold_hint_entry(true, true, n, idx); } +unfold_hint_entry mk_erase_unfold_c_entry(name const & n) { return unfold_hint_entry(true, false, n, 0); } +unfold_hint_entry mk_add_unfold_f_entry(name const & n) { return unfold_hint_entry(false, true, n, 0); } +unfold_hint_entry mk_erase_unfold_f_entry(name const & n) { return unfold_hint_entry(false, false, n, 0); } + +static name * g_unfold_hint_name = nullptr; static std::string * g_key = nullptr; -struct unfold_c_hint_config { - typedef name_map state; - typedef unfold_c_hint_entry entry; +struct unfold_hint_state { + name_map m_unfold_c; + name_set m_unfold_f; +}; + +struct unfold_hint_config { + typedef unfold_hint_state state; + typedef unfold_hint_entry entry; static void add_entry(environment const &, io_state const &, state & s, entry const & e) { - if (e.m_arg_idx) - s.insert(e.m_decl_name, *e.m_arg_idx); - else - s.erase(e.m_decl_name); + if (e.m_unfold_c) { + if (e.m_add) + s.m_unfold_c.insert(e.m_decl_name, e.m_arg_idx); + else + s.m_unfold_c.erase(e.m_decl_name); + } else { + if (e.m_add) + s.m_unfold_f.insert(e.m_decl_name); + else + s.m_unfold_f.erase(e.m_decl_name); + } } static name const & get_class_name() { - return *g_unfold_c_hint_name; + return *g_unfold_hint_name; } static std::string const & get_serialization_key() { return *g_key; } static void write_entry(serializer & s, entry const & e) { - s << e.m_decl_name << e.m_arg_idx; + s << e.m_unfold_c << e.m_add << e.m_decl_name << e.m_arg_idx; } static entry read_entry(deserializer & d) { entry e; - d >> e.m_decl_name >> e.m_arg_idx; + d >> e.m_unfold_c >> e.m_add >> e.m_decl_name >> e.m_arg_idx; return e; } static optional get_fingerprint(entry const & e) { @@ -62,33 +84,53 @@ struct unfold_c_hint_config { } }; -template class scoped_ext; -typedef scoped_ext unfold_c_hint_ext; +template class scoped_ext; +typedef scoped_ext unfold_hint_ext; -environment add_unfold_c_hint(environment const & env, name const & n, optional idx, bool persistent) { +environment add_unfold_c_hint(environment const & env, name const & n, unsigned idx, bool persistent) { declaration const & d = env.get(n); if (!d.is_definition() || d.is_opaque()) throw exception("invalid unfold-c hint, declaration must be a non-opaque definition"); - return unfold_c_hint_ext::add_entry(env, get_dummy_ios(), unfold_c_hint_entry(n, idx), persistent); + return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_add_unfold_c_entry(n, idx), persistent); } optional has_unfold_c_hint(environment const & env, name const & d) { - name_map const & s = unfold_c_hint_ext::get_state(env); - if (auto it = s.find(d)) + unfold_hint_state const & s = unfold_hint_ext::get_state(env); + if (auto it = s.m_unfold_c.find(d)) return optional(*it); else return optional(); } +environment erase_unfold_c_hint(environment const & env, name const & n, bool persistent) { + return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_erase_unfold_c_entry(n), persistent); +} + +environment add_unfold_f_hint(environment const & env, name const & n, bool persistent) { + declaration const & d = env.get(n); + if (!d.is_definition() || d.is_opaque()) + throw exception("invalid unfold-f hint, declaration must be a non-opaque definition"); + return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_add_unfold_f_entry(n), persistent); +} + +bool has_unfold_f_hint(environment const & env, name const & d) { + unfold_hint_state const & s = unfold_hint_ext::get_state(env); + return s.m_unfold_f.contains(d); +} + +environment erase_unfold_f_hint(environment const & env, name const & n, bool persistent) { + return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_erase_unfold_f_entry(n), persistent); +} + void initialize_normalize() { - g_unfold_c_hint_name = new name("c-unfold"); - g_key = new std::string("c-unfold"); - unfold_c_hint_ext::initialize(); + g_unfold_hint_name = new name("unfold-hints"); + g_key = new std::string("unfoldh"); + unfold_hint_ext::initialize(); } void finalize_normalize() { - unfold_c_hint_ext::finalize(); - delete g_unfold_c_hint_name; + unfold_hint_ext::finalize(); + delete g_unfold_hint_name; delete g_key; } @@ -113,6 +155,10 @@ class normalize_fn { return ::lean::has_unfold_c_hint(m_tc.env(), const_name(f)); } + bool has_unfold_f_hint(expr const & f) { + return is_constant(f) && ::lean::has_unfold_f_hint(m_tc.env(), const_name(f)); + } + expr normalize_app(expr const & e) { buffer args; bool modified = false; @@ -123,6 +169,12 @@ class normalize_fn { modified = true; a = new_a; } + if (has_unfold_f_hint(f)) { + if (!is_pi(m_tc.whnf(m_tc.infer(e).first).first)) { + if (optional r = unfold_app(m_tc.env(), mk_rev_app(f, args))) + return normalize(*r); + } + } if (auto idx = has_unfold_c_hint(f)) { if (*idx < args.size() && is_constructor_app(m_tc.env(), args[args.size() - *idx - 1])) { if (optional r = unfold_app(m_tc.env(), mk_rev_app(f, args))) diff --git a/src/library/normalize.h b/src/library/normalize.h index 216163943..7da2a0afa 100644 --- a/src/library/normalize.h +++ b/src/library/normalize.h @@ -32,27 +32,26 @@ expr normalize(type_checker & tc, expr const & e, constraint_seq & cs, bool eta expr normalize(type_checker & tc, expr const & e, std::function const & pred, // NOLINT constraint_seq & cs, bool eta = false); -/** \brief c_unfold hint instructs the normalizer (and simplifier) that +/** \brief unfold-c hint instructs the normalizer (and simplifier) that a function application (f a_1 ... a_i ... a_n) should be unfolded when argument a_i is a constructor. The constant will be unfolded even if it the whnf procedure did not unfolded it. Of course, kernel opaque constants are not unfolded. - - \remark If idx is none, then the hint is removed. */ -environment add_unfold_c_hint(environment const & env, name const & n, optional idx, bool persistent = true); -inline environment add_unfold_c_hint(environment const & env, name const & n, unsigned idx, bool persistent = true) { - return add_unfold_c_hint(env, n, optional(idx), persistent); -} -inline environment erase_unfold_c_hint(environment const & env, name const & n, bool persistent = true) { - return add_unfold_c_hint(env, n, optional(), persistent); -} - +environment add_unfold_c_hint(environment const & env, name const & n, unsigned idx, bool persistent = true); +environment erase_unfold_c_hint(environment const & env, name const & n, bool persistent = true); /** \brief Retrieve the hint added with the procedure add_unfold_c_hint. */ optional has_unfold_c_hint(environment const & env, name const & d); +/** \brief unfold-f hint instructs normalizer (and simplifier) that function application + (f a_1 ... a_n) should be unfolded when it is fully applied */ +environment add_unfold_f_hint(environment const & env, name const & n, bool persistent = true); +environment erase_unfold_f_hint(environment const & env, name const & n, bool persistent = true); +/** \brief Retrieve the hint added with the procedure add_unfold_f_hint. */ +optional has_unfold_f_hint(environment const & env, name const & d); + void initialize_normalize(); void finalize_normalize(); } diff --git a/tests/lean/unfoldf.lean b/tests/lean/unfoldf.lean new file mode 100644 index 000000000..8deb1c48c --- /dev/null +++ b/tests/lean/unfoldf.lean @@ -0,0 +1,28 @@ +open nat + +definition id [unfold-f] {A : Type} (a : A) := a +definition compose {A B C : Type} (g : B → C) (f : A → B) (a : A) := g (f a) +notation g ∘ f := compose g f + +example (a b : nat) (H : a = b) : id a = b := +begin + esimp, + state, + exact H +end + +example (a b : nat) (H : a = b) : (id ∘ id) a = b := +begin + esimp, + state, + exact H +end + +attribute compose [unfold-f] + +example (a b : nat) (H : a = b) : (id ∘ id) a = b := +begin + esimp, + state, + exact H +end diff --git a/tests/lean/unfoldf.lean.expected.out b/tests/lean/unfoldf.lean.expected.out new file mode 100644 index 000000000..68e8113e0 --- /dev/null +++ b/tests/lean/unfoldf.lean.expected.out @@ -0,0 +1,12 @@ +unfoldf.lean:10:2: proof state +a b : ℕ, +H : a = b +⊢ a = b +unfoldf.lean:17:2: proof state +a b : ℕ, +H : a = b +⊢ (id ∘ id) a = b +unfoldf.lean:26:2: proof state +a b : ℕ, +H : a = b +⊢ a = b