feat(frontends/lean/parse_config): add get_notation_entries auxiliary function for returning the list of notation declarations that start with a given head symbol
This API is need to take notation declarations into account when pretty printing expressions.
This commit is contained in:
parent
f17e67efcb
commit
eb79af98ba
4 changed files with 53 additions and 8 deletions
|
@ -359,10 +359,6 @@ static bool is_simple(unsigned num, transition const * ts) {
|
|||
return std::all_of(ts, ts+num, [](transition const & t) { return t.is_simple(); });
|
||||
}
|
||||
|
||||
bool is_safe_ascii(unsigned num, transition const * ts) {
|
||||
return std::all_of(ts, ts+num, [](transition const & t) { return t.is_safe_ascii(); });
|
||||
}
|
||||
|
||||
/** \brief Given \c a, an expression that is the denotation of an expression, if \c a is a variable,
|
||||
then use the actions in the transitions \c ts to expand \c a. The idea is to produce a head symbol
|
||||
we can use to decide whether the notation should be considered during pretty printing.
|
||||
|
|
|
@ -83,6 +83,8 @@ public:
|
|||
/** \brief Return true iff the action is not Ext or LuaExt */
|
||||
bool is_simple() const;
|
||||
};
|
||||
inline bool operator==(action const & a1, action const & a2) { return a1.is_equal(a2); }
|
||||
inline bool operator!=(action const & a1, action const & a2) { return !a1.is_equal(a2); }
|
||||
|
||||
action mk_skip_action();
|
||||
action mk_expr_action(unsigned rbp = 0);
|
||||
|
@ -108,8 +110,12 @@ public:
|
|||
bool is_simple() const { return m_action.is_simple(); }
|
||||
bool is_safe_ascii() const { return m_token.is_safe_ascii(); }
|
||||
};
|
||||
|
||||
bool is_safe_ascii(unsigned num, transition const * ts);
|
||||
inline bool operator==(transition const & t1, transition const & t2) {
|
||||
return t1.get_token() == t2.get_token() && t1.get_action() == t2.get_action();
|
||||
}
|
||||
inline bool operator!=(transition const & t1, transition const & t2) {
|
||||
return !(t1 == t2);
|
||||
}
|
||||
|
||||
/** \brief Apply \c f to expressions embedded in the given transition */
|
||||
transition replace(transition const & t, std::function<expr(expr const &)> const & f);
|
||||
|
|
|
@ -30,7 +30,7 @@ notation_entry replace(notation_entry const & e, std::function<expr(expr const &
|
|||
|
||||
notation_entry::notation_entry():m_kind(notation_entry_kind::NuD) {}
|
||||
notation_entry::notation_entry(notation_entry const & e):
|
||||
m_kind(e.m_kind), m_expr(e.m_expr), m_overload(e.m_overload) {
|
||||
m_kind(e.m_kind), m_expr(e.m_expr), m_overload(e.m_overload), m_safe_ascii(e.m_safe_ascii) {
|
||||
if (is_numeral())
|
||||
new (&m_num) mpz(e.m_num);
|
||||
else
|
||||
|
@ -41,13 +41,14 @@ notation_entry::notation_entry(bool is_nud, list<transition> const & ts, expr co
|
|||
m_kind(is_nud ? notation_entry_kind::NuD : notation_entry_kind::LeD),
|
||||
m_expr(e), m_overload(overload) {
|
||||
new (&m_transitions) list<transition>(ts);
|
||||
m_safe_ascii = std::all_of(ts.begin(), ts.end(), [](transition const & t) { return t.is_safe_ascii(); });
|
||||
}
|
||||
notation_entry::notation_entry(notation_entry const & e, bool overload):
|
||||
notation_entry(e) {
|
||||
m_overload = overload;
|
||||
}
|
||||
notation_entry::notation_entry(mpz const & val, expr const & e, bool overload):
|
||||
m_kind(notation_entry_kind::Numeral), m_expr(e), m_overload(overload) {
|
||||
m_kind(notation_entry_kind::Numeral), m_expr(e), m_overload(overload), m_safe_ascii(true) {
|
||||
new (&m_num) mpz(val);
|
||||
}
|
||||
|
||||
|
@ -58,6 +59,15 @@ notation_entry::~notation_entry() {
|
|||
m_transitions.~list<transition>();
|
||||
}
|
||||
|
||||
bool operator==(notation_entry const & e1, notation_entry const & e2) {
|
||||
if (e1.kind() != e2.kind() || e1.overload() != e2.overload() || e1.get_expr() != e2.get_expr())
|
||||
return false;
|
||||
if (e1.is_numeral())
|
||||
return e1.get_num() == e2.get_num();
|
||||
else
|
||||
return e1.get_transitions() == e2.get_transitions();
|
||||
}
|
||||
|
||||
struct token_state {
|
||||
token_table m_table;
|
||||
token_state() { m_table = mk_default_token_table(); }
|
||||
|
@ -183,9 +193,12 @@ transition read_transition(deserializer & d) {
|
|||
|
||||
struct notation_state {
|
||||
typedef rb_map<mpz, list<expr>, mpz_cmp_fn> num_map;
|
||||
typedef head_map<notation_entry> head_to_entries;
|
||||
parse_table m_nud;
|
||||
parse_table m_led;
|
||||
num_map m_num_map;
|
||||
head_to_entries m_inv_map;
|
||||
|
||||
notation_state() {
|
||||
m_nud = get_builtin_nud_table();
|
||||
m_led = get_builtin_led_table();
|
||||
|
@ -198,18 +211,28 @@ struct notation_config {
|
|||
static name * g_class_name;
|
||||
static std::string * g_key;
|
||||
|
||||
static void updt_inv_map(state & s, head_index const & idx, entry const & e) {
|
||||
s.m_inv_map.insert(idx, e);
|
||||
}
|
||||
|
||||
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
|
||||
buffer<transition> ts;
|
||||
switch (e.kind()) {
|
||||
case notation_entry_kind::NuD:
|
||||
to_buffer(e.get_transitions(), ts);
|
||||
if (auto idx = get_head_index(ts.size(), ts.data(), e.get_expr()))
|
||||
updt_inv_map(s, *idx, e);
|
||||
s.m_nud = s.m_nud.add(ts.size(), ts.data(), e.get_expr(), e.overload());
|
||||
break;
|
||||
case notation_entry_kind::LeD:
|
||||
to_buffer(e.get_transitions(), ts);
|
||||
if (auto idx = get_head_index(ts.size(), ts.data(), e.get_expr()))
|
||||
updt_inv_map(s, *idx, e);
|
||||
s.m_led = s.m_led.add(ts.size(), ts.data(), e.get_expr(), e.overload());
|
||||
break;
|
||||
case notation_entry_kind::Numeral:
|
||||
if (!is_var(e.get_expr()))
|
||||
updt_inv_map(s, head_index(e.get_expr()), e);
|
||||
if (!e.overload()) {
|
||||
s.m_num_map.insert(e.get_num(), list<expr>(e.get_expr()));
|
||||
} else if (auto it = s.m_num_map.find(e.get_num())) {
|
||||
|
@ -308,6 +331,13 @@ list<expr> get_mpz_notation(environment const & env, mpz const & n) {
|
|||
}
|
||||
}
|
||||
|
||||
list<notation_entry> get_notation_entries(environment const & env, head_index const & idx) {
|
||||
if (auto it = notation_ext::get_state(env).m_inv_map.find(idx))
|
||||
return *it;
|
||||
else
|
||||
return list<notation_entry>();
|
||||
}
|
||||
|
||||
environment overwrite_notation(environment const & env, name const & n) {
|
||||
environment r = env;
|
||||
bool found = false;
|
||||
|
|
|
@ -29,6 +29,7 @@ class notation_entry {
|
|||
};
|
||||
expr m_expr;
|
||||
bool m_overload;
|
||||
bool m_safe_ascii;
|
||||
public:
|
||||
notation_entry();
|
||||
notation_entry(notation_entry const & e);
|
||||
|
@ -43,7 +44,12 @@ public:
|
|||
mpz const & get_num() const { lean_assert(is_numeral()); return m_num; }
|
||||
expr const & get_expr() const { return m_expr; }
|
||||
bool overload() const { return m_overload; }
|
||||
bool is_safe_ascii() const { return m_safe_ascii; }
|
||||
};
|
||||
bool operator==(notation_entry const & e1, notation_entry const & e2);
|
||||
inline bool operator!=(notation_entry const & e1, notation_entry const & e2) {
|
||||
return !(e1 == e2);
|
||||
}
|
||||
|
||||
/** \brief Apply \c f to expressions embedded in the notation entry */
|
||||
notation_entry replace(notation_entry const & e, std::function<expr(expr const &)> const & f);
|
||||
|
@ -75,6 +81,13 @@ environment add_mpz_notation(environment const & env, mpz const & n, expr const
|
|||
*/
|
||||
list<expr> get_mpz_notation(environment const & env, mpz const & n);
|
||||
|
||||
/** \brief Return the notation declaration that start with a given head symbol.
|
||||
|
||||
\remark Notation declarations that contain C++ and Lua actions are not indexed.
|
||||
Thus, they are to included in the result.
|
||||
*/
|
||||
list<notation_entry> get_notation_entries(environment const & env, head_index const & idx);
|
||||
|
||||
void initialize_parser_config();
|
||||
void finalize_parser_config();
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue