feat(frontends/lean/parse_config): add get_notation_entries auxiliary function for returning the list of notation declarations that start with a given head symbol

This API is need to take notation declarations into account when pretty
printing expressions.
This commit is contained in:
Leonardo de Moura 2014-10-18 11:49:27 -07:00
parent f17e67efcb
commit eb79af98ba
4 changed files with 53 additions and 8 deletions

View file

@ -359,10 +359,6 @@ static bool is_simple(unsigned num, transition const * ts) {
return std::all_of(ts, ts+num, [](transition const & t) { return t.is_simple(); });
}
bool is_safe_ascii(unsigned num, transition const * ts) {
return std::all_of(ts, ts+num, [](transition const & t) { return t.is_safe_ascii(); });
}
/** \brief Given \c a, an expression that is the denotation of an expression, if \c a is a variable,
then use the actions in the transitions \c ts to expand \c a. The idea is to produce a head symbol
we can use to decide whether the notation should be considered during pretty printing.

View file

@ -83,6 +83,8 @@ public:
/** \brief Return true iff the action is not Ext or LuaExt */
bool is_simple() const;
};
inline bool operator==(action const & a1, action const & a2) { return a1.is_equal(a2); }
inline bool operator!=(action const & a1, action const & a2) { return !a1.is_equal(a2); }
action mk_skip_action();
action mk_expr_action(unsigned rbp = 0);
@ -108,8 +110,12 @@ public:
bool is_simple() const { return m_action.is_simple(); }
bool is_safe_ascii() const { return m_token.is_safe_ascii(); }
};
bool is_safe_ascii(unsigned num, transition const * ts);
inline bool operator==(transition const & t1, transition const & t2) {
return t1.get_token() == t2.get_token() && t1.get_action() == t2.get_action();
}
inline bool operator!=(transition const & t1, transition const & t2) {
return !(t1 == t2);
}
/** \brief Apply \c f to expressions embedded in the given transition */
transition replace(transition const & t, std::function<expr(expr const &)> const & f);

View file

@ -30,7 +30,7 @@ notation_entry replace(notation_entry const & e, std::function<expr(expr const &
notation_entry::notation_entry():m_kind(notation_entry_kind::NuD) {}
notation_entry::notation_entry(notation_entry const & e):
m_kind(e.m_kind), m_expr(e.m_expr), m_overload(e.m_overload) {
m_kind(e.m_kind), m_expr(e.m_expr), m_overload(e.m_overload), m_safe_ascii(e.m_safe_ascii) {
if (is_numeral())
new (&m_num) mpz(e.m_num);
else
@ -41,13 +41,14 @@ notation_entry::notation_entry(bool is_nud, list<transition> const & ts, expr co
m_kind(is_nud ? notation_entry_kind::NuD : notation_entry_kind::LeD),
m_expr(e), m_overload(overload) {
new (&m_transitions) list<transition>(ts);
m_safe_ascii = std::all_of(ts.begin(), ts.end(), [](transition const & t) { return t.is_safe_ascii(); });
}
notation_entry::notation_entry(notation_entry const & e, bool overload):
notation_entry(e) {
m_overload = overload;
}
notation_entry::notation_entry(mpz const & val, expr const & e, bool overload):
m_kind(notation_entry_kind::Numeral), m_expr(e), m_overload(overload) {
m_kind(notation_entry_kind::Numeral), m_expr(e), m_overload(overload), m_safe_ascii(true) {
new (&m_num) mpz(val);
}
@ -58,6 +59,15 @@ notation_entry::~notation_entry() {
m_transitions.~list<transition>();
}
bool operator==(notation_entry const & e1, notation_entry const & e2) {
if (e1.kind() != e2.kind() || e1.overload() != e2.overload() || e1.get_expr() != e2.get_expr())
return false;
if (e1.is_numeral())
return e1.get_num() == e2.get_num();
else
return e1.get_transitions() == e2.get_transitions();
}
struct token_state {
token_table m_table;
token_state() { m_table = mk_default_token_table(); }
@ -183,9 +193,12 @@ transition read_transition(deserializer & d) {
struct notation_state {
typedef rb_map<mpz, list<expr>, mpz_cmp_fn> num_map;
typedef head_map<notation_entry> head_to_entries;
parse_table m_nud;
parse_table m_led;
num_map m_num_map;
head_to_entries m_inv_map;
notation_state() {
m_nud = get_builtin_nud_table();
m_led = get_builtin_led_table();
@ -198,18 +211,28 @@ struct notation_config {
static name * g_class_name;
static std::string * g_key;
static void updt_inv_map(state & s, head_index const & idx, entry const & e) {
s.m_inv_map.insert(idx, e);
}
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
buffer<transition> ts;
switch (e.kind()) {
case notation_entry_kind::NuD:
to_buffer(e.get_transitions(), ts);
if (auto idx = get_head_index(ts.size(), ts.data(), e.get_expr()))
updt_inv_map(s, *idx, e);
s.m_nud = s.m_nud.add(ts.size(), ts.data(), e.get_expr(), e.overload());
break;
case notation_entry_kind::LeD:
to_buffer(e.get_transitions(), ts);
if (auto idx = get_head_index(ts.size(), ts.data(), e.get_expr()))
updt_inv_map(s, *idx, e);
s.m_led = s.m_led.add(ts.size(), ts.data(), e.get_expr(), e.overload());
break;
case notation_entry_kind::Numeral:
if (!is_var(e.get_expr()))
updt_inv_map(s, head_index(e.get_expr()), e);
if (!e.overload()) {
s.m_num_map.insert(e.get_num(), list<expr>(e.get_expr()));
} else if (auto it = s.m_num_map.find(e.get_num())) {
@ -308,6 +331,13 @@ list<expr> get_mpz_notation(environment const & env, mpz const & n) {
}
}
list<notation_entry> get_notation_entries(environment const & env, head_index const & idx) {
if (auto it = notation_ext::get_state(env).m_inv_map.find(idx))
return *it;
else
return list<notation_entry>();
}
environment overwrite_notation(environment const & env, name const & n) {
environment r = env;
bool found = false;

View file

@ -29,6 +29,7 @@ class notation_entry {
};
expr m_expr;
bool m_overload;
bool m_safe_ascii;
public:
notation_entry();
notation_entry(notation_entry const & e);
@ -43,7 +44,12 @@ public:
mpz const & get_num() const { lean_assert(is_numeral()); return m_num; }
expr const & get_expr() const { return m_expr; }
bool overload() const { return m_overload; }
bool is_safe_ascii() const { return m_safe_ascii; }
};
bool operator==(notation_entry const & e1, notation_entry const & e2);
inline bool operator!=(notation_entry const & e1, notation_entry const & e2) {
return !(e1 == e2);
}
/** \brief Apply \c f to expressions embedded in the notation entry */
notation_entry replace(notation_entry const & e, std::function<expr(expr const &)> const & f);
@ -75,6 +81,13 @@ environment add_mpz_notation(environment const & env, mpz const & n, expr const
*/
list<expr> get_mpz_notation(environment const & env, mpz const & n);
/** \brief Return the notation declaration that start with a given head symbol.
\remark Notation declarations that contain C++ and Lua actions are not indexed.
Thus, they are to included in the result.
*/
list<notation_entry> get_notation_entries(environment const & env, head_index const & idx);
void initialize_parser_config();
void finalize_parser_config();
}