feat(frontends/lean): allow users to define "numeral notation"
This commit is contained in:
parent
9928390605
commit
a3e38dc8a0
11 changed files with 218 additions and 31 deletions
|
@ -371,15 +371,28 @@ notation_entry parse_notation_core(parser & p, bool overload, buffer<token_entry
|
|||
return notation_entry(is_nud, to_list(ts.begin(), ts.end()), n, overload);
|
||||
}
|
||||
|
||||
environment parse_num_notation(parser & p, bool overload) {
|
||||
lean_assert(p.curr_is_numeral());
|
||||
mpz num = p.get_num_val().get_numerator();
|
||||
p.next();
|
||||
p.check_token_next(get_assign_tk(), "invalid numeral notation, `:=` expected");
|
||||
expr e = cleanup_section_notation(p, p.parse_expr());
|
||||
return add_mpz_notation(p.env(), num, e, overload);
|
||||
}
|
||||
|
||||
environment notation_cmd_core(parser & p, bool overload) {
|
||||
environment env = p.env();
|
||||
buffer<token_entry> new_tokens;
|
||||
if (p.curr_is_numeral()) {
|
||||
return parse_num_notation(p, overload);
|
||||
} else {
|
||||
auto ne = parse_notation_core(p, overload, new_tokens);
|
||||
for (auto const & te : new_tokens)
|
||||
env = add_token(env, te);
|
||||
env = add_notation(env, ne);
|
||||
return env;
|
||||
}
|
||||
}
|
||||
|
||||
bool curr_is_notation_decl(parser & p) {
|
||||
return p.curr_is_token(get_infix_tk()) || p.curr_is_token(get_infixl_tk()) || p.curr_is_token(get_infixr_tk()) ||
|
||||
|
|
|
@ -1058,10 +1058,22 @@ expr parser::parse_numeral_expr() {
|
|||
next();
|
||||
if (!m_has_num)
|
||||
m_has_num = has_num_decls(m_env);
|
||||
if (!*m_has_num)
|
||||
list<expr> vals = get_mpz_notation(m_env, n);
|
||||
if (!*m_has_num && !vals) {
|
||||
throw parser_error("numeral cannot be encoded as expression, environment does not contain the type 'num' "
|
||||
"(solution: use 'import num')", p);
|
||||
return from_num(n);
|
||||
"nor notation was defined for the given numeral "
|
||||
"(solution: use 'import data.num', or define notation for the given numeral)", p);
|
||||
}
|
||||
buffer<expr> cs;
|
||||
for (expr const & c : vals)
|
||||
cs.push_back(copy_with_new_pos(c, p));
|
||||
if (*m_has_num)
|
||||
cs.push_back(save_pos(from_num(n), p));
|
||||
lean_assert(!cs.empty());
|
||||
if (cs.size() == 1)
|
||||
return cs[0];
|
||||
else
|
||||
return save_pos(mk_choice(cs.size(), cs.data()), p);
|
||||
}
|
||||
|
||||
expr parser::parse_decimal_expr() {
|
||||
|
|
|
@ -217,27 +217,6 @@ environment add_led_notation(environment const & env, std::initializer_list<nota
|
|||
return add_led_notation(env, ts.size(), ts.begin(), a, overload);
|
||||
}
|
||||
|
||||
environment overwrite_notation(environment const & env, name const & n) {
|
||||
environment r = env;
|
||||
bool found = false;
|
||||
if (auto it = token_ext::get_entries(r, n)) {
|
||||
found = true;
|
||||
for (token_entry e : *it) {
|
||||
r = add_token(r, e);
|
||||
}
|
||||
}
|
||||
if (auto it = notation_ext::get_entries(env, n)) {
|
||||
found = true;
|
||||
for (notation_entry e : *it) {
|
||||
e.m_overload = false;
|
||||
r = add_notation(r, e);
|
||||
}
|
||||
}
|
||||
if (!found)
|
||||
throw exception(sstream() << "unknown namespace '" << n << "'");
|
||||
return r;
|
||||
}
|
||||
|
||||
parse_table const & get_nud_table(environment const & env) {
|
||||
return notation_ext::get_state(env).m_nud;
|
||||
}
|
||||
|
@ -265,6 +244,101 @@ cmd_table const & get_cmd_table(environment const & env) {
|
|||
return get_extension(env).m_cmds;
|
||||
}
|
||||
|
||||
struct mpz_notation_entry {
|
||||
mpz m_num;
|
||||
expr m_expr;
|
||||
bool m_overload;
|
||||
mpz_notation_entry():m_overload(false) {}
|
||||
mpz_notation_entry(mpz const & n, expr const & e, bool o):m_num(n), m_expr(e), m_overload(o) {}
|
||||
};
|
||||
|
||||
struct mpz_notation_state {
|
||||
typedef rb_map<mpz, list<expr>, mpz_cmp_fn> map;
|
||||
map m_map;
|
||||
};
|
||||
|
||||
struct mpz_notation_config {
|
||||
typedef mpz_notation_state state;
|
||||
typedef mpz_notation_entry entry;
|
||||
static name * g_class_name;
|
||||
static std::string * g_key;
|
||||
|
||||
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
|
||||
if (!e.m_overload) {
|
||||
s.m_map.insert(e.m_num, list<expr>(e.m_expr));
|
||||
} else if (auto it = s.m_map.find(e.m_num)) {
|
||||
list<expr> new_exprs = cons(e.m_expr, filter(*it, [&](expr const & n) { return n != e.m_expr; }));
|
||||
s.m_map.insert(e.m_num, new_exprs);
|
||||
} else {
|
||||
s.m_map.insert(e.m_num, list<expr>(e.m_expr));
|
||||
}
|
||||
}
|
||||
static name const & get_class_name() {
|
||||
return *g_class_name;
|
||||
}
|
||||
static std::string const & get_serialization_key() {
|
||||
return *g_key;
|
||||
}
|
||||
static void write_entry(serializer & s, entry const & e) {
|
||||
s << e.m_num << e.m_expr << e.m_overload;
|
||||
}
|
||||
static entry read_entry(deserializer & d) {
|
||||
entry e;
|
||||
d >> e.m_num >> e.m_expr >> e.m_overload;
|
||||
return e;
|
||||
}
|
||||
};
|
||||
|
||||
name * mpz_notation_config::g_class_name = nullptr;
|
||||
std::string * mpz_notation_config::g_key = nullptr;
|
||||
|
||||
template class scoped_ext<mpz_notation_config>;
|
||||
typedef scoped_ext<mpz_notation_config> mpz_notation_ext;
|
||||
|
||||
environment add_mpz_notation(environment const & env, mpz_notation_entry const & e) {
|
||||
return mpz_notation_ext::add_entry(env, get_dummy_ios(), e);
|
||||
}
|
||||
|
||||
environment add_mpz_notation(environment const & env, mpz const & n, expr const & e, bool overload) {
|
||||
return add_mpz_notation(env, mpz_notation_entry(n, e, overload));
|
||||
}
|
||||
|
||||
list<expr> get_mpz_notation(environment const & env, mpz const & n) {
|
||||
if (auto it = mpz_notation_ext::get_state(env).m_map.find(n)) {
|
||||
return *it;
|
||||
} else {
|
||||
return list<expr>();
|
||||
}
|
||||
}
|
||||
|
||||
environment overwrite_notation(environment const & env, name const & n) {
|
||||
environment r = env;
|
||||
bool found = false;
|
||||
if (auto it = token_ext::get_entries(r, n)) {
|
||||
found = true;
|
||||
for (token_entry e : *it) {
|
||||
r = add_token(r, e);
|
||||
}
|
||||
}
|
||||
if (auto it = notation_ext::get_entries(env, n)) {
|
||||
found = true;
|
||||
for (notation_entry e : *it) {
|
||||
e.m_overload = false;
|
||||
r = add_notation(r, e);
|
||||
}
|
||||
}
|
||||
if (auto it = mpz_notation_ext::get_entries(env, n)) {
|
||||
found = true;
|
||||
for (mpz_notation_entry e : *it) {
|
||||
e.m_overload = false;
|
||||
r = add_mpz_notation(r, e);
|
||||
}
|
||||
}
|
||||
if (!found)
|
||||
throw exception(sstream() << "unknown namespace '" << n << "'");
|
||||
return r;
|
||||
}
|
||||
|
||||
void initialize_parser_config() {
|
||||
token_config::g_class_name = new name("notation");
|
||||
token_config::g_key = new std::string("tk");
|
||||
|
@ -273,8 +347,14 @@ void initialize_parser_config() {
|
|||
notation_config::g_key = new std::string("nota");
|
||||
notation_ext::initialize();
|
||||
g_ext = new cmd_ext_reg();
|
||||
mpz_notation_config::g_class_name = new name("notation");
|
||||
mpz_notation_config::g_key = new std::string("numnota");
|
||||
mpz_notation_ext::initialize();
|
||||
}
|
||||
void finalize_parser_config() {
|
||||
mpz_notation_ext::finalize();
|
||||
delete mpz_notation_config::g_key;
|
||||
delete mpz_notation_config::g_class_name;
|
||||
delete g_ext;
|
||||
notation_ext::finalize();
|
||||
delete notation_config::g_key;
|
||||
|
|
|
@ -36,8 +36,10 @@ environment add_token(environment const & env, token_entry const & e);
|
|||
environment add_notation(environment const & env, notation_entry const & e);
|
||||
|
||||
environment add_token(environment const & env, char const * val, unsigned prec);
|
||||
environment add_nud_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a, bool overload = true);
|
||||
environment add_led_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a, bool overload = true);
|
||||
environment add_nud_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a,
|
||||
bool overload = true);
|
||||
environment add_led_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a,
|
||||
bool overload = true);
|
||||
environment add_nud_notation(environment const & env, std::initializer_list<notation::transition> const & ts, expr const & a,
|
||||
bool overload = true);
|
||||
environment add_led_notation(environment const & env, std::initializer_list<notation::transition> const & ts, expr const & a,
|
||||
|
@ -49,6 +51,14 @@ cmd_table const & get_cmd_table(environment const & env);
|
|||
/** \brief Force notation from namespace \c n to shadow any existing notation */
|
||||
environment overwrite_notation(environment const & env, name const & n);
|
||||
|
||||
/** \brief Add \c n as notation for \c e */
|
||||
environment add_mpz_notation(environment const & env, mpz const & n, expr const & e, bool overload = true);
|
||||
/** \brief Return the additional interpretations for \c n in the current environment.
|
||||
|
||||
\remark It does not include the default one based on the \c num inductive datatype.
|
||||
*/
|
||||
list<expr> get_mpz_notation(environment const & env, mpz const & n);
|
||||
|
||||
void initialize_parser_config();
|
||||
void finalize_parser_config();
|
||||
}
|
||||
|
|
|
@ -222,6 +222,10 @@ public:
|
|||
friend std::ostream & operator<<(std::ostream & out, mpz const & v);
|
||||
};
|
||||
|
||||
struct mpz_cmp_fn {
|
||||
int operator()(mpz const & v1, mpz const & v2) const { return cmp(v1, v2); }
|
||||
};
|
||||
|
||||
template<>
|
||||
class numeric_traits<mpz> {
|
||||
public:
|
||||
|
|
24
tests/lean/num2.lean
Normal file
24
tests/lean/num2.lean
Normal file
|
@ -0,0 +1,24 @@
|
|||
set_option pp.notation false
|
||||
definition Prop := Type.{0}
|
||||
variable eq {A : Type} : A → A → Prop
|
||||
infixl `=`:50 := eq
|
||||
|
||||
variable N : Type.{1}
|
||||
variable z : N
|
||||
variable o : N
|
||||
variable b : N
|
||||
|
||||
notation 0 := z
|
||||
notation 1 := o
|
||||
|
||||
check 1
|
||||
check 0
|
||||
|
||||
variable G : Type.{1}
|
||||
variable gz : G
|
||||
variable a : G
|
||||
|
||||
notation 0 := gz
|
||||
|
||||
check 0 = a
|
||||
check b = 0
|
4
tests/lean/num2.lean.expected.out
Normal file
4
tests/lean/num2.lean.expected.out
Normal file
|
@ -0,0 +1,4 @@
|
|||
o : N
|
||||
z : N
|
||||
eq gz a : Prop
|
||||
eq b z : Prop
|
14
tests/lean/num3.lean
Normal file
14
tests/lean/num3.lean
Normal file
|
@ -0,0 +1,14 @@
|
|||
import data.num
|
||||
set_option pp.notation false
|
||||
set_option pp.implicit true
|
||||
|
||||
variable N : Type.{1}
|
||||
variable z : N
|
||||
variable o : N
|
||||
variable a : N
|
||||
|
||||
notation 0 := z
|
||||
notation 1 := o
|
||||
|
||||
check a = 0
|
||||
check 2 = 1
|
2
tests/lean/num3.lean.expected.out
Normal file
2
tests/lean/num3.lean.expected.out
Normal file
|
@ -0,0 +1,2 @@
|
|||
@eq N a z : Prop
|
||||
@eq num 2 1 : Prop
|
20
tests/lean/num4.lean
Normal file
20
tests/lean/num4.lean
Normal file
|
@ -0,0 +1,20 @@
|
|||
import data.num
|
||||
set_option pp.notation false
|
||||
set_option pp.implicit true
|
||||
|
||||
namespace foo
|
||||
variable N : Type.{1}
|
||||
variable z : N
|
||||
variable o : N
|
||||
variable a : N
|
||||
notation 0 := z
|
||||
notation 1 := o
|
||||
|
||||
check a = 0
|
||||
end foo
|
||||
|
||||
check 2 = 1
|
||||
check #foo foo.a = 1
|
||||
|
||||
open foo
|
||||
check a = 1
|
4
tests/lean/num4.lean.expected.out
Normal file
4
tests/lean/num4.lean.expected.out
Normal file
|
@ -0,0 +1,4 @@
|
|||
@eq N a z : Prop
|
||||
@eq num 2 1 : Prop
|
||||
@eq foo.N foo.a foo.o : Prop
|
||||
@eq N a o : Prop
|
Loading…
Reference in a new issue