feat(frontends/lean): allow users to define "numeral notation"
This commit is contained in:
parent
9928390605
commit
a3e38dc8a0
11 changed files with 218 additions and 31 deletions
|
@ -371,14 +371,27 @@ notation_entry parse_notation_core(parser & p, bool overload, buffer<token_entry
|
||||||
return notation_entry(is_nud, to_list(ts.begin(), ts.end()), n, overload);
|
return notation_entry(is_nud, to_list(ts.begin(), ts.end()), n, overload);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
environment parse_num_notation(parser & p, bool overload) {
|
||||||
|
lean_assert(p.curr_is_numeral());
|
||||||
|
mpz num = p.get_num_val().get_numerator();
|
||||||
|
p.next();
|
||||||
|
p.check_token_next(get_assign_tk(), "invalid numeral notation, `:=` expected");
|
||||||
|
expr e = cleanup_section_notation(p, p.parse_expr());
|
||||||
|
return add_mpz_notation(p.env(), num, e, overload);
|
||||||
|
}
|
||||||
|
|
||||||
environment notation_cmd_core(parser & p, bool overload) {
|
environment notation_cmd_core(parser & p, bool overload) {
|
||||||
environment env = p.env();
|
environment env = p.env();
|
||||||
buffer<token_entry> new_tokens;
|
buffer<token_entry> new_tokens;
|
||||||
auto ne = parse_notation_core(p, overload, new_tokens);
|
if (p.curr_is_numeral()) {
|
||||||
for (auto const & te : new_tokens)
|
return parse_num_notation(p, overload);
|
||||||
env = add_token(env, te);
|
} else {
|
||||||
env = add_notation(env, ne);
|
auto ne = parse_notation_core(p, overload, new_tokens);
|
||||||
return env;
|
for (auto const & te : new_tokens)
|
||||||
|
env = add_token(env, te);
|
||||||
|
env = add_notation(env, ne);
|
||||||
|
return env;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool curr_is_notation_decl(parser & p) {
|
bool curr_is_notation_decl(parser & p) {
|
||||||
|
|
|
@ -1058,10 +1058,22 @@ expr parser::parse_numeral_expr() {
|
||||||
next();
|
next();
|
||||||
if (!m_has_num)
|
if (!m_has_num)
|
||||||
m_has_num = has_num_decls(m_env);
|
m_has_num = has_num_decls(m_env);
|
||||||
if (!*m_has_num)
|
list<expr> vals = get_mpz_notation(m_env, n);
|
||||||
|
if (!*m_has_num && !vals) {
|
||||||
throw parser_error("numeral cannot be encoded as expression, environment does not contain the type 'num' "
|
throw parser_error("numeral cannot be encoded as expression, environment does not contain the type 'num' "
|
||||||
"(solution: use 'import num')", p);
|
"nor notation was defined for the given numeral "
|
||||||
return from_num(n);
|
"(solution: use 'import data.num', or define notation for the given numeral)", p);
|
||||||
|
}
|
||||||
|
buffer<expr> cs;
|
||||||
|
for (expr const & c : vals)
|
||||||
|
cs.push_back(copy_with_new_pos(c, p));
|
||||||
|
if (*m_has_num)
|
||||||
|
cs.push_back(save_pos(from_num(n), p));
|
||||||
|
lean_assert(!cs.empty());
|
||||||
|
if (cs.size() == 1)
|
||||||
|
return cs[0];
|
||||||
|
else
|
||||||
|
return save_pos(mk_choice(cs.size(), cs.data()), p);
|
||||||
}
|
}
|
||||||
|
|
||||||
expr parser::parse_decimal_expr() {
|
expr parser::parse_decimal_expr() {
|
||||||
|
|
|
@ -217,27 +217,6 @@ environment add_led_notation(environment const & env, std::initializer_list<nota
|
||||||
return add_led_notation(env, ts.size(), ts.begin(), a, overload);
|
return add_led_notation(env, ts.size(), ts.begin(), a, overload);
|
||||||
}
|
}
|
||||||
|
|
||||||
environment overwrite_notation(environment const & env, name const & n) {
|
|
||||||
environment r = env;
|
|
||||||
bool found = false;
|
|
||||||
if (auto it = token_ext::get_entries(r, n)) {
|
|
||||||
found = true;
|
|
||||||
for (token_entry e : *it) {
|
|
||||||
r = add_token(r, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (auto it = notation_ext::get_entries(env, n)) {
|
|
||||||
found = true;
|
|
||||||
for (notation_entry e : *it) {
|
|
||||||
e.m_overload = false;
|
|
||||||
r = add_notation(r, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!found)
|
|
||||||
throw exception(sstream() << "unknown namespace '" << n << "'");
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
parse_table const & get_nud_table(environment const & env) {
|
parse_table const & get_nud_table(environment const & env) {
|
||||||
return notation_ext::get_state(env).m_nud;
|
return notation_ext::get_state(env).m_nud;
|
||||||
}
|
}
|
||||||
|
@ -265,6 +244,101 @@ cmd_table const & get_cmd_table(environment const & env) {
|
||||||
return get_extension(env).m_cmds;
|
return get_extension(env).m_cmds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct mpz_notation_entry {
|
||||||
|
mpz m_num;
|
||||||
|
expr m_expr;
|
||||||
|
bool m_overload;
|
||||||
|
mpz_notation_entry():m_overload(false) {}
|
||||||
|
mpz_notation_entry(mpz const & n, expr const & e, bool o):m_num(n), m_expr(e), m_overload(o) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct mpz_notation_state {
|
||||||
|
typedef rb_map<mpz, list<expr>, mpz_cmp_fn> map;
|
||||||
|
map m_map;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct mpz_notation_config {
|
||||||
|
typedef mpz_notation_state state;
|
||||||
|
typedef mpz_notation_entry entry;
|
||||||
|
static name * g_class_name;
|
||||||
|
static std::string * g_key;
|
||||||
|
|
||||||
|
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
|
||||||
|
if (!e.m_overload) {
|
||||||
|
s.m_map.insert(e.m_num, list<expr>(e.m_expr));
|
||||||
|
} else if (auto it = s.m_map.find(e.m_num)) {
|
||||||
|
list<expr> new_exprs = cons(e.m_expr, filter(*it, [&](expr const & n) { return n != e.m_expr; }));
|
||||||
|
s.m_map.insert(e.m_num, new_exprs);
|
||||||
|
} else {
|
||||||
|
s.m_map.insert(e.m_num, list<expr>(e.m_expr));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
static name const & get_class_name() {
|
||||||
|
return *g_class_name;
|
||||||
|
}
|
||||||
|
static std::string const & get_serialization_key() {
|
||||||
|
return *g_key;
|
||||||
|
}
|
||||||
|
static void write_entry(serializer & s, entry const & e) {
|
||||||
|
s << e.m_num << e.m_expr << e.m_overload;
|
||||||
|
}
|
||||||
|
static entry read_entry(deserializer & d) {
|
||||||
|
entry e;
|
||||||
|
d >> e.m_num >> e.m_expr >> e.m_overload;
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
name * mpz_notation_config::g_class_name = nullptr;
|
||||||
|
std::string * mpz_notation_config::g_key = nullptr;
|
||||||
|
|
||||||
|
template class scoped_ext<mpz_notation_config>;
|
||||||
|
typedef scoped_ext<mpz_notation_config> mpz_notation_ext;
|
||||||
|
|
||||||
|
environment add_mpz_notation(environment const & env, mpz_notation_entry const & e) {
|
||||||
|
return mpz_notation_ext::add_entry(env, get_dummy_ios(), e);
|
||||||
|
}
|
||||||
|
|
||||||
|
environment add_mpz_notation(environment const & env, mpz const & n, expr const & e, bool overload) {
|
||||||
|
return add_mpz_notation(env, mpz_notation_entry(n, e, overload));
|
||||||
|
}
|
||||||
|
|
||||||
|
list<expr> get_mpz_notation(environment const & env, mpz const & n) {
|
||||||
|
if (auto it = mpz_notation_ext::get_state(env).m_map.find(n)) {
|
||||||
|
return *it;
|
||||||
|
} else {
|
||||||
|
return list<expr>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
environment overwrite_notation(environment const & env, name const & n) {
|
||||||
|
environment r = env;
|
||||||
|
bool found = false;
|
||||||
|
if (auto it = token_ext::get_entries(r, n)) {
|
||||||
|
found = true;
|
||||||
|
for (token_entry e : *it) {
|
||||||
|
r = add_token(r, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (auto it = notation_ext::get_entries(env, n)) {
|
||||||
|
found = true;
|
||||||
|
for (notation_entry e : *it) {
|
||||||
|
e.m_overload = false;
|
||||||
|
r = add_notation(r, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (auto it = mpz_notation_ext::get_entries(env, n)) {
|
||||||
|
found = true;
|
||||||
|
for (mpz_notation_entry e : *it) {
|
||||||
|
e.m_overload = false;
|
||||||
|
r = add_mpz_notation(r, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!found)
|
||||||
|
throw exception(sstream() << "unknown namespace '" << n << "'");
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
void initialize_parser_config() {
|
void initialize_parser_config() {
|
||||||
token_config::g_class_name = new name("notation");
|
token_config::g_class_name = new name("notation");
|
||||||
token_config::g_key = new std::string("tk");
|
token_config::g_key = new std::string("tk");
|
||||||
|
@ -273,8 +347,14 @@ void initialize_parser_config() {
|
||||||
notation_config::g_key = new std::string("nota");
|
notation_config::g_key = new std::string("nota");
|
||||||
notation_ext::initialize();
|
notation_ext::initialize();
|
||||||
g_ext = new cmd_ext_reg();
|
g_ext = new cmd_ext_reg();
|
||||||
|
mpz_notation_config::g_class_name = new name("notation");
|
||||||
|
mpz_notation_config::g_key = new std::string("numnota");
|
||||||
|
mpz_notation_ext::initialize();
|
||||||
}
|
}
|
||||||
void finalize_parser_config() {
|
void finalize_parser_config() {
|
||||||
|
mpz_notation_ext::finalize();
|
||||||
|
delete mpz_notation_config::g_key;
|
||||||
|
delete mpz_notation_config::g_class_name;
|
||||||
delete g_ext;
|
delete g_ext;
|
||||||
notation_ext::finalize();
|
notation_ext::finalize();
|
||||||
delete notation_config::g_key;
|
delete notation_config::g_key;
|
||||||
|
|
|
@ -36,8 +36,10 @@ environment add_token(environment const & env, token_entry const & e);
|
||||||
environment add_notation(environment const & env, notation_entry const & e);
|
environment add_notation(environment const & env, notation_entry const & e);
|
||||||
|
|
||||||
environment add_token(environment const & env, char const * val, unsigned prec);
|
environment add_token(environment const & env, char const * val, unsigned prec);
|
||||||
environment add_nud_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a, bool overload = true);
|
environment add_nud_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a,
|
||||||
environment add_led_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a, bool overload = true);
|
bool overload = true);
|
||||||
|
environment add_led_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a,
|
||||||
|
bool overload = true);
|
||||||
environment add_nud_notation(environment const & env, std::initializer_list<notation::transition> const & ts, expr const & a,
|
environment add_nud_notation(environment const & env, std::initializer_list<notation::transition> const & ts, expr const & a,
|
||||||
bool overload = true);
|
bool overload = true);
|
||||||
environment add_led_notation(environment const & env, std::initializer_list<notation::transition> const & ts, expr const & a,
|
environment add_led_notation(environment const & env, std::initializer_list<notation::transition> const & ts, expr const & a,
|
||||||
|
@ -49,6 +51,14 @@ cmd_table const & get_cmd_table(environment const & env);
|
||||||
/** \brief Force notation from namespace \c n to shadow any existing notation */
|
/** \brief Force notation from namespace \c n to shadow any existing notation */
|
||||||
environment overwrite_notation(environment const & env, name const & n);
|
environment overwrite_notation(environment const & env, name const & n);
|
||||||
|
|
||||||
|
/** \brief Add \c n as notation for \c e */
|
||||||
|
environment add_mpz_notation(environment const & env, mpz const & n, expr const & e, bool overload = true);
|
||||||
|
/** \brief Return the additional interpretations for \c n in the current environment.
|
||||||
|
|
||||||
|
\remark It does not include the default one based on the \c num inductive datatype.
|
||||||
|
*/
|
||||||
|
list<expr> get_mpz_notation(environment const & env, mpz const & n);
|
||||||
|
|
||||||
void initialize_parser_config();
|
void initialize_parser_config();
|
||||||
void finalize_parser_config();
|
void finalize_parser_config();
|
||||||
}
|
}
|
||||||
|
|
|
@ -222,6 +222,10 @@ public:
|
||||||
friend std::ostream & operator<<(std::ostream & out, mpz const & v);
|
friend std::ostream & operator<<(std::ostream & out, mpz const & v);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct mpz_cmp_fn {
|
||||||
|
int operator()(mpz const & v1, mpz const & v2) const { return cmp(v1, v2); }
|
||||||
|
};
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
class numeric_traits<mpz> {
|
class numeric_traits<mpz> {
|
||||||
public:
|
public:
|
||||||
|
|
24
tests/lean/num2.lean
Normal file
24
tests/lean/num2.lean
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
set_option pp.notation false
|
||||||
|
definition Prop := Type.{0}
|
||||||
|
variable eq {A : Type} : A → A → Prop
|
||||||
|
infixl `=`:50 := eq
|
||||||
|
|
||||||
|
variable N : Type.{1}
|
||||||
|
variable z : N
|
||||||
|
variable o : N
|
||||||
|
variable b : N
|
||||||
|
|
||||||
|
notation 0 := z
|
||||||
|
notation 1 := o
|
||||||
|
|
||||||
|
check 1
|
||||||
|
check 0
|
||||||
|
|
||||||
|
variable G : Type.{1}
|
||||||
|
variable gz : G
|
||||||
|
variable a : G
|
||||||
|
|
||||||
|
notation 0 := gz
|
||||||
|
|
||||||
|
check 0 = a
|
||||||
|
check b = 0
|
4
tests/lean/num2.lean.expected.out
Normal file
4
tests/lean/num2.lean.expected.out
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
o : N
|
||||||
|
z : N
|
||||||
|
eq gz a : Prop
|
||||||
|
eq b z : Prop
|
14
tests/lean/num3.lean
Normal file
14
tests/lean/num3.lean
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
import data.num
|
||||||
|
set_option pp.notation false
|
||||||
|
set_option pp.implicit true
|
||||||
|
|
||||||
|
variable N : Type.{1}
|
||||||
|
variable z : N
|
||||||
|
variable o : N
|
||||||
|
variable a : N
|
||||||
|
|
||||||
|
notation 0 := z
|
||||||
|
notation 1 := o
|
||||||
|
|
||||||
|
check a = 0
|
||||||
|
check 2 = 1
|
2
tests/lean/num3.lean.expected.out
Normal file
2
tests/lean/num3.lean.expected.out
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
@eq N a z : Prop
|
||||||
|
@eq num 2 1 : Prop
|
20
tests/lean/num4.lean
Normal file
20
tests/lean/num4.lean
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
import data.num
|
||||||
|
set_option pp.notation false
|
||||||
|
set_option pp.implicit true
|
||||||
|
|
||||||
|
namespace foo
|
||||||
|
variable N : Type.{1}
|
||||||
|
variable z : N
|
||||||
|
variable o : N
|
||||||
|
variable a : N
|
||||||
|
notation 0 := z
|
||||||
|
notation 1 := o
|
||||||
|
|
||||||
|
check a = 0
|
||||||
|
end foo
|
||||||
|
|
||||||
|
check 2 = 1
|
||||||
|
check #foo foo.a = 1
|
||||||
|
|
||||||
|
open foo
|
||||||
|
check a = 1
|
4
tests/lean/num4.lean.expected.out
Normal file
4
tests/lean/num4.lean.expected.out
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
@eq N a z : Prop
|
||||||
|
@eq num 2 1 : Prop
|
||||||
|
@eq foo.N foo.a foo.o : Prop
|
||||||
|
@eq N a o : Prop
|
Loading…
Reference in a new issue