diff --git a/src/frontends/lean/parse_table.cpp b/src/frontends/lean/parse_table.cpp index 3358f19f1..0aad72210 100644 --- a/src/frontends/lean/parse_table.cpp +++ b/src/frontends/lean/parse_table.cpp @@ -359,10 +359,6 @@ static bool is_simple(unsigned num, transition const * ts) { return std::all_of(ts, ts+num, [](transition const & t) { return t.is_simple(); }); } -bool is_safe_ascii(unsigned num, transition const * ts) { - return std::all_of(ts, ts+num, [](transition const & t) { return t.is_safe_ascii(); }); -} - /** \brief Given \c a, an expression that is the denotation of an expression, if \c a is a variable, then use the actions in the transitions \c ts to expand \c a. The idea is to produce a head symbol we can use to decide whether the notation should be considered during pretty printing. diff --git a/src/frontends/lean/parse_table.h b/src/frontends/lean/parse_table.h index f55861d5c..c7d04f04d 100644 --- a/src/frontends/lean/parse_table.h +++ b/src/frontends/lean/parse_table.h @@ -83,6 +83,8 @@ public: /** \brief Return true iff the action is not Ext or LuaExt */ bool is_simple() const; }; +inline bool operator==(action const & a1, action const & a2) { return a1.is_equal(a2); } +inline bool operator!=(action const & a1, action const & a2) { return !a1.is_equal(a2); } action mk_skip_action(); action mk_expr_action(unsigned rbp = 0); @@ -108,8 +110,12 @@ public: bool is_simple() const { return m_action.is_simple(); } bool is_safe_ascii() const { return m_token.is_safe_ascii(); } }; - -bool is_safe_ascii(unsigned num, transition const * ts); +inline bool operator==(transition const & t1, transition const & t2) { + return t1.get_token() == t2.get_token() && t1.get_action() == t2.get_action(); +} +inline bool operator!=(transition const & t1, transition const & t2) { + return !(t1 == t2); +} /** \brief Apply \c f to expressions embedded in the given transition */ transition replace(transition const & t, std::function const & f); diff --git a/src/frontends/lean/parser_config.cpp b/src/frontends/lean/parser_config.cpp index 1944d3bd0..86de01edc 100644 --- a/src/frontends/lean/parser_config.cpp +++ b/src/frontends/lean/parser_config.cpp @@ -30,7 +30,7 @@ notation_entry replace(notation_entry const & e, std::function const & ts, expr co m_kind(is_nud ? notation_entry_kind::NuD : notation_entry_kind::LeD), m_expr(e), m_overload(overload) { new (&m_transitions) list(ts); + m_safe_ascii = std::all_of(ts.begin(), ts.end(), [](transition const & t) { return t.is_safe_ascii(); }); } notation_entry::notation_entry(notation_entry const & e, bool overload): notation_entry(e) { m_overload = overload; } notation_entry::notation_entry(mpz const & val, expr const & e, bool overload): - m_kind(notation_entry_kind::Numeral), m_expr(e), m_overload(overload) { + m_kind(notation_entry_kind::Numeral), m_expr(e), m_overload(overload), m_safe_ascii(true) { new (&m_num) mpz(val); } @@ -58,6 +59,15 @@ notation_entry::~notation_entry() { m_transitions.~list(); } +bool operator==(notation_entry const & e1, notation_entry const & e2) { + if (e1.kind() != e2.kind() || e1.overload() != e2.overload() || e1.get_expr() != e2.get_expr()) + return false; + if (e1.is_numeral()) + return e1.get_num() == e2.get_num(); + else + return e1.get_transitions() == e2.get_transitions(); +} + struct token_state { token_table m_table; token_state() { m_table = mk_default_token_table(); } @@ -183,9 +193,12 @@ transition read_transition(deserializer & d) { struct notation_state { typedef rb_map, mpz_cmp_fn> num_map; + typedef head_map head_to_entries; parse_table m_nud; parse_table m_led; num_map m_num_map; + head_to_entries m_inv_map; + notation_state() { m_nud = get_builtin_nud_table(); m_led = get_builtin_led_table(); @@ -198,18 +211,28 @@ struct notation_config { static name * g_class_name; static std::string * g_key; + static void updt_inv_map(state & s, head_index const & idx, entry const & e) { + s.m_inv_map.insert(idx, e); + } + static void add_entry(environment const &, io_state const &, state & s, entry const & e) { buffer ts; switch (e.kind()) { case notation_entry_kind::NuD: to_buffer(e.get_transitions(), ts); + if (auto idx = get_head_index(ts.size(), ts.data(), e.get_expr())) + updt_inv_map(s, *idx, e); s.m_nud = s.m_nud.add(ts.size(), ts.data(), e.get_expr(), e.overload()); break; case notation_entry_kind::LeD: to_buffer(e.get_transitions(), ts); + if (auto idx = get_head_index(ts.size(), ts.data(), e.get_expr())) + updt_inv_map(s, *idx, e); s.m_led = s.m_led.add(ts.size(), ts.data(), e.get_expr(), e.overload()); break; case notation_entry_kind::Numeral: + if (!is_var(e.get_expr())) + updt_inv_map(s, head_index(e.get_expr()), e); if (!e.overload()) { s.m_num_map.insert(e.get_num(), list(e.get_expr())); } else if (auto it = s.m_num_map.find(e.get_num())) { @@ -308,6 +331,13 @@ list get_mpz_notation(environment const & env, mpz const & n) { } } +list get_notation_entries(environment const & env, head_index const & idx) { + if (auto it = notation_ext::get_state(env).m_inv_map.find(idx)) + return *it; + else + return list(); +} + environment overwrite_notation(environment const & env, name const & n) { environment r = env; bool found = false; diff --git a/src/frontends/lean/parser_config.h b/src/frontends/lean/parser_config.h index 67bdf85e6..1bb13db63 100644 --- a/src/frontends/lean/parser_config.h +++ b/src/frontends/lean/parser_config.h @@ -29,6 +29,7 @@ class notation_entry { }; expr m_expr; bool m_overload; + bool m_safe_ascii; public: notation_entry(); notation_entry(notation_entry const & e); @@ -43,7 +44,12 @@ public: mpz const & get_num() const { lean_assert(is_numeral()); return m_num; } expr const & get_expr() const { return m_expr; } bool overload() const { return m_overload; } + bool is_safe_ascii() const { return m_safe_ascii; } }; +bool operator==(notation_entry const & e1, notation_entry const & e2); +inline bool operator!=(notation_entry const & e1, notation_entry const & e2) { + return !(e1 == e2); +} /** \brief Apply \c f to expressions embedded in the notation entry */ notation_entry replace(notation_entry const & e, std::function const & f); @@ -75,6 +81,13 @@ environment add_mpz_notation(environment const & env, mpz const & n, expr const */ list get_mpz_notation(environment const & env, mpz const & n); +/** \brief Return the notation declaration that start with a given head symbol. + + \remark Notation declarations that contain C++ and Lua actions are not indexed. + Thus, they are to included in the result. +*/ +list get_notation_entries(environment const & env, head_index const & idx); + void initialize_parser_config(); void finalize_parser_config(); }