feat(frontends/lean/parse_config): add get_notation_entries auxiliary function for returning the list of notation declarations that start with a given head symbol

This API is need to take notation declarations into account when pretty
printing expressions.
This commit is contained in:
Leonardo de Moura 2014-10-18 11:49:27 -07:00
parent f17e67efcb
commit eb79af98ba
4 changed files with 53 additions and 8 deletions

View file

@ -359,10 +359,6 @@ static bool is_simple(unsigned num, transition const * ts) {
return std::all_of(ts, ts+num, [](transition const & t) { return t.is_simple(); }); return std::all_of(ts, ts+num, [](transition const & t) { return t.is_simple(); });
} }
bool is_safe_ascii(unsigned num, transition const * ts) {
return std::all_of(ts, ts+num, [](transition const & t) { return t.is_safe_ascii(); });
}
/** \brief Given \c a, an expression that is the denotation of an expression, if \c a is a variable, /** \brief Given \c a, an expression that is the denotation of an expression, if \c a is a variable,
then use the actions in the transitions \c ts to expand \c a. The idea is to produce a head symbol then use the actions in the transitions \c ts to expand \c a. The idea is to produce a head symbol
we can use to decide whether the notation should be considered during pretty printing. we can use to decide whether the notation should be considered during pretty printing.

View file

@ -83,6 +83,8 @@ public:
/** \brief Return true iff the action is not Ext or LuaExt */ /** \brief Return true iff the action is not Ext or LuaExt */
bool is_simple() const; bool is_simple() const;
}; };
inline bool operator==(action const & a1, action const & a2) { return a1.is_equal(a2); }
inline bool operator!=(action const & a1, action const & a2) { return !a1.is_equal(a2); }
action mk_skip_action(); action mk_skip_action();
action mk_expr_action(unsigned rbp = 0); action mk_expr_action(unsigned rbp = 0);
@ -108,8 +110,12 @@ public:
bool is_simple() const { return m_action.is_simple(); } bool is_simple() const { return m_action.is_simple(); }
bool is_safe_ascii() const { return m_token.is_safe_ascii(); } bool is_safe_ascii() const { return m_token.is_safe_ascii(); }
}; };
inline bool operator==(transition const & t1, transition const & t2) {
bool is_safe_ascii(unsigned num, transition const * ts); return t1.get_token() == t2.get_token() && t1.get_action() == t2.get_action();
}
inline bool operator!=(transition const & t1, transition const & t2) {
return !(t1 == t2);
}
/** \brief Apply \c f to expressions embedded in the given transition */ /** \brief Apply \c f to expressions embedded in the given transition */
transition replace(transition const & t, std::function<expr(expr const &)> const & f); transition replace(transition const & t, std::function<expr(expr const &)> const & f);

View file

@ -30,7 +30,7 @@ notation_entry replace(notation_entry const & e, std::function<expr(expr const &
notation_entry::notation_entry():m_kind(notation_entry_kind::NuD) {} notation_entry::notation_entry():m_kind(notation_entry_kind::NuD) {}
notation_entry::notation_entry(notation_entry const & e): notation_entry::notation_entry(notation_entry const & e):
m_kind(e.m_kind), m_expr(e.m_expr), m_overload(e.m_overload) { m_kind(e.m_kind), m_expr(e.m_expr), m_overload(e.m_overload), m_safe_ascii(e.m_safe_ascii) {
if (is_numeral()) if (is_numeral())
new (&m_num) mpz(e.m_num); new (&m_num) mpz(e.m_num);
else else
@ -41,13 +41,14 @@ notation_entry::notation_entry(bool is_nud, list<transition> const & ts, expr co
m_kind(is_nud ? notation_entry_kind::NuD : notation_entry_kind::LeD), m_kind(is_nud ? notation_entry_kind::NuD : notation_entry_kind::LeD),
m_expr(e), m_overload(overload) { m_expr(e), m_overload(overload) {
new (&m_transitions) list<transition>(ts); new (&m_transitions) list<transition>(ts);
m_safe_ascii = std::all_of(ts.begin(), ts.end(), [](transition const & t) { return t.is_safe_ascii(); });
} }
notation_entry::notation_entry(notation_entry const & e, bool overload): notation_entry::notation_entry(notation_entry const & e, bool overload):
notation_entry(e) { notation_entry(e) {
m_overload = overload; m_overload = overload;
} }
notation_entry::notation_entry(mpz const & val, expr const & e, bool overload): notation_entry::notation_entry(mpz const & val, expr const & e, bool overload):
m_kind(notation_entry_kind::Numeral), m_expr(e), m_overload(overload) { m_kind(notation_entry_kind::Numeral), m_expr(e), m_overload(overload), m_safe_ascii(true) {
new (&m_num) mpz(val); new (&m_num) mpz(val);
} }
@ -58,6 +59,15 @@ notation_entry::~notation_entry() {
m_transitions.~list<transition>(); m_transitions.~list<transition>();
} }
bool operator==(notation_entry const & e1, notation_entry const & e2) {
if (e1.kind() != e2.kind() || e1.overload() != e2.overload() || e1.get_expr() != e2.get_expr())
return false;
if (e1.is_numeral())
return e1.get_num() == e2.get_num();
else
return e1.get_transitions() == e2.get_transitions();
}
struct token_state { struct token_state {
token_table m_table; token_table m_table;
token_state() { m_table = mk_default_token_table(); } token_state() { m_table = mk_default_token_table(); }
@ -183,9 +193,12 @@ transition read_transition(deserializer & d) {
struct notation_state { struct notation_state {
typedef rb_map<mpz, list<expr>, mpz_cmp_fn> num_map; typedef rb_map<mpz, list<expr>, mpz_cmp_fn> num_map;
typedef head_map<notation_entry> head_to_entries;
parse_table m_nud; parse_table m_nud;
parse_table m_led; parse_table m_led;
num_map m_num_map; num_map m_num_map;
head_to_entries m_inv_map;
notation_state() { notation_state() {
m_nud = get_builtin_nud_table(); m_nud = get_builtin_nud_table();
m_led = get_builtin_led_table(); m_led = get_builtin_led_table();
@ -198,18 +211,28 @@ struct notation_config {
static name * g_class_name; static name * g_class_name;
static std::string * g_key; static std::string * g_key;
static void updt_inv_map(state & s, head_index const & idx, entry const & e) {
s.m_inv_map.insert(idx, e);
}
static void add_entry(environment const &, io_state const &, state & s, entry const & e) { static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
buffer<transition> ts; buffer<transition> ts;
switch (e.kind()) { switch (e.kind()) {
case notation_entry_kind::NuD: case notation_entry_kind::NuD:
to_buffer(e.get_transitions(), ts); to_buffer(e.get_transitions(), ts);
if (auto idx = get_head_index(ts.size(), ts.data(), e.get_expr()))
updt_inv_map(s, *idx, e);
s.m_nud = s.m_nud.add(ts.size(), ts.data(), e.get_expr(), e.overload()); s.m_nud = s.m_nud.add(ts.size(), ts.data(), e.get_expr(), e.overload());
break; break;
case notation_entry_kind::LeD: case notation_entry_kind::LeD:
to_buffer(e.get_transitions(), ts); to_buffer(e.get_transitions(), ts);
if (auto idx = get_head_index(ts.size(), ts.data(), e.get_expr()))
updt_inv_map(s, *idx, e);
s.m_led = s.m_led.add(ts.size(), ts.data(), e.get_expr(), e.overload()); s.m_led = s.m_led.add(ts.size(), ts.data(), e.get_expr(), e.overload());
break; break;
case notation_entry_kind::Numeral: case notation_entry_kind::Numeral:
if (!is_var(e.get_expr()))
updt_inv_map(s, head_index(e.get_expr()), e);
if (!e.overload()) { if (!e.overload()) {
s.m_num_map.insert(e.get_num(), list<expr>(e.get_expr())); s.m_num_map.insert(e.get_num(), list<expr>(e.get_expr()));
} else if (auto it = s.m_num_map.find(e.get_num())) { } else if (auto it = s.m_num_map.find(e.get_num())) {
@ -308,6 +331,13 @@ list<expr> get_mpz_notation(environment const & env, mpz const & n) {
} }
} }
list<notation_entry> get_notation_entries(environment const & env, head_index const & idx) {
if (auto it = notation_ext::get_state(env).m_inv_map.find(idx))
return *it;
else
return list<notation_entry>();
}
environment overwrite_notation(environment const & env, name const & n) { environment overwrite_notation(environment const & env, name const & n) {
environment r = env; environment r = env;
bool found = false; bool found = false;

View file

@ -29,6 +29,7 @@ class notation_entry {
}; };
expr m_expr; expr m_expr;
bool m_overload; bool m_overload;
bool m_safe_ascii;
public: public:
notation_entry(); notation_entry();
notation_entry(notation_entry const & e); notation_entry(notation_entry const & e);
@ -43,7 +44,12 @@ public:
mpz const & get_num() const { lean_assert(is_numeral()); return m_num; } mpz const & get_num() const { lean_assert(is_numeral()); return m_num; }
expr const & get_expr() const { return m_expr; } expr const & get_expr() const { return m_expr; }
bool overload() const { return m_overload; } bool overload() const { return m_overload; }
bool is_safe_ascii() const { return m_safe_ascii; }
}; };
bool operator==(notation_entry const & e1, notation_entry const & e2);
inline bool operator!=(notation_entry const & e1, notation_entry const & e2) {
return !(e1 == e2);
}
/** \brief Apply \c f to expressions embedded in the notation entry */ /** \brief Apply \c f to expressions embedded in the notation entry */
notation_entry replace(notation_entry const & e, std::function<expr(expr const &)> const & f); notation_entry replace(notation_entry const & e, std::function<expr(expr const &)> const & f);
@ -75,6 +81,13 @@ environment add_mpz_notation(environment const & env, mpz const & n, expr const
*/ */
list<expr> get_mpz_notation(environment const & env, mpz const & n); list<expr> get_mpz_notation(environment const & env, mpz const & n);
/** \brief Return the notation declaration that start with a given head symbol.
\remark Notation declarations that contain C++ and Lua actions are not indexed.
Thus, they are to included in the result.
*/
list<notation_entry> get_notation_entries(environment const & env, head_index const & idx);
void initialize_parser_config(); void initialize_parser_config();
void finalize_parser_config(); void finalize_parser_config();
} }