feat(frontends/lean): allow users to define "numeral notation"

This commit is contained in:
Leonardo de Moura 2014-09-26 14:54:39 -07:00
parent 9928390605
commit a3e38dc8a0
11 changed files with 218 additions and 31 deletions

View file

@ -371,14 +371,27 @@ notation_entry parse_notation_core(parser & p, bool overload, buffer<token_entry
return notation_entry(is_nud, to_list(ts.begin(), ts.end()), n, overload);
}
environment parse_num_notation(parser & p, bool overload) {
lean_assert(p.curr_is_numeral());
mpz num = p.get_num_val().get_numerator();
p.next();
p.check_token_next(get_assign_tk(), "invalid numeral notation, `:=` expected");
expr e = cleanup_section_notation(p, p.parse_expr());
return add_mpz_notation(p.env(), num, e, overload);
}
environment notation_cmd_core(parser & p, bool overload) {
environment env = p.env();
buffer<token_entry> new_tokens;
auto ne = parse_notation_core(p, overload, new_tokens);
for (auto const & te : new_tokens)
env = add_token(env, te);
env = add_notation(env, ne);
return env;
if (p.curr_is_numeral()) {
return parse_num_notation(p, overload);
} else {
auto ne = parse_notation_core(p, overload, new_tokens);
for (auto const & te : new_tokens)
env = add_token(env, te);
env = add_notation(env, ne);
return env;
}
}
bool curr_is_notation_decl(parser & p) {

View file

@ -1058,10 +1058,22 @@ expr parser::parse_numeral_expr() {
next();
if (!m_has_num)
m_has_num = has_num_decls(m_env);
if (!*m_has_num)
list<expr> vals = get_mpz_notation(m_env, n);
if (!*m_has_num && !vals) {
throw parser_error("numeral cannot be encoded as expression, environment does not contain the type 'num' "
"(solution: use 'import num')", p);
return from_num(n);
"nor notation was defined for the given numeral "
"(solution: use 'import data.num', or define notation for the given numeral)", p);
}
buffer<expr> cs;
for (expr const & c : vals)
cs.push_back(copy_with_new_pos(c, p));
if (*m_has_num)
cs.push_back(save_pos(from_num(n), p));
lean_assert(!cs.empty());
if (cs.size() == 1)
return cs[0];
else
return save_pos(mk_choice(cs.size(), cs.data()), p);
}
expr parser::parse_decimal_expr() {

View file

@ -217,27 +217,6 @@ environment add_led_notation(environment const & env, std::initializer_list<nota
return add_led_notation(env, ts.size(), ts.begin(), a, overload);
}
environment overwrite_notation(environment const & env, name const & n) {
environment r = env;
bool found = false;
if (auto it = token_ext::get_entries(r, n)) {
found = true;
for (token_entry e : *it) {
r = add_token(r, e);
}
}
if (auto it = notation_ext::get_entries(env, n)) {
found = true;
for (notation_entry e : *it) {
e.m_overload = false;
r = add_notation(r, e);
}
}
if (!found)
throw exception(sstream() << "unknown namespace '" << n << "'");
return r;
}
parse_table const & get_nud_table(environment const & env) {
return notation_ext::get_state(env).m_nud;
}
@ -265,6 +244,101 @@ cmd_table const & get_cmd_table(environment const & env) {
return get_extension(env).m_cmds;
}
struct mpz_notation_entry {
mpz m_num;
expr m_expr;
bool m_overload;
mpz_notation_entry():m_overload(false) {}
mpz_notation_entry(mpz const & n, expr const & e, bool o):m_num(n), m_expr(e), m_overload(o) {}
};
struct mpz_notation_state {
typedef rb_map<mpz, list<expr>, mpz_cmp_fn> map;
map m_map;
};
struct mpz_notation_config {
typedef mpz_notation_state state;
typedef mpz_notation_entry entry;
static name * g_class_name;
static std::string * g_key;
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
if (!e.m_overload) {
s.m_map.insert(e.m_num, list<expr>(e.m_expr));
} else if (auto it = s.m_map.find(e.m_num)) {
list<expr> new_exprs = cons(e.m_expr, filter(*it, [&](expr const & n) { return n != e.m_expr; }));
s.m_map.insert(e.m_num, new_exprs);
} else {
s.m_map.insert(e.m_num, list<expr>(e.m_expr));
}
}
static name const & get_class_name() {
return *g_class_name;
}
static std::string const & get_serialization_key() {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << e.m_num << e.m_expr << e.m_overload;
}
static entry read_entry(deserializer & d) {
entry e;
d >> e.m_num >> e.m_expr >> e.m_overload;
return e;
}
};
name * mpz_notation_config::g_class_name = nullptr;
std::string * mpz_notation_config::g_key = nullptr;
template class scoped_ext<mpz_notation_config>;
typedef scoped_ext<mpz_notation_config> mpz_notation_ext;
environment add_mpz_notation(environment const & env, mpz_notation_entry const & e) {
return mpz_notation_ext::add_entry(env, get_dummy_ios(), e);
}
environment add_mpz_notation(environment const & env, mpz const & n, expr const & e, bool overload) {
return add_mpz_notation(env, mpz_notation_entry(n, e, overload));
}
list<expr> get_mpz_notation(environment const & env, mpz const & n) {
if (auto it = mpz_notation_ext::get_state(env).m_map.find(n)) {
return *it;
} else {
return list<expr>();
}
}
environment overwrite_notation(environment const & env, name const & n) {
environment r = env;
bool found = false;
if (auto it = token_ext::get_entries(r, n)) {
found = true;
for (token_entry e : *it) {
r = add_token(r, e);
}
}
if (auto it = notation_ext::get_entries(env, n)) {
found = true;
for (notation_entry e : *it) {
e.m_overload = false;
r = add_notation(r, e);
}
}
if (auto it = mpz_notation_ext::get_entries(env, n)) {
found = true;
for (mpz_notation_entry e : *it) {
e.m_overload = false;
r = add_mpz_notation(r, e);
}
}
if (!found)
throw exception(sstream() << "unknown namespace '" << n << "'");
return r;
}
void initialize_parser_config() {
token_config::g_class_name = new name("notation");
token_config::g_key = new std::string("tk");
@ -273,8 +347,14 @@ void initialize_parser_config() {
notation_config::g_key = new std::string("nota");
notation_ext::initialize();
g_ext = new cmd_ext_reg();
mpz_notation_config::g_class_name = new name("notation");
mpz_notation_config::g_key = new std::string("numnota");
mpz_notation_ext::initialize();
}
void finalize_parser_config() {
mpz_notation_ext::finalize();
delete mpz_notation_config::g_key;
delete mpz_notation_config::g_class_name;
delete g_ext;
notation_ext::finalize();
delete notation_config::g_key;

View file

@ -36,8 +36,10 @@ environment add_token(environment const & env, token_entry const & e);
environment add_notation(environment const & env, notation_entry const & e);
environment add_token(environment const & env, char const * val, unsigned prec);
environment add_nud_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a, bool overload = true);
environment add_led_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a, bool overload = true);
environment add_nud_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a,
bool overload = true);
environment add_led_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a,
bool overload = true);
environment add_nud_notation(environment const & env, std::initializer_list<notation::transition> const & ts, expr const & a,
bool overload = true);
environment add_led_notation(environment const & env, std::initializer_list<notation::transition> const & ts, expr const & a,
@ -49,6 +51,14 @@ cmd_table const & get_cmd_table(environment const & env);
/** \brief Force notation from namespace \c n to shadow any existing notation */
environment overwrite_notation(environment const & env, name const & n);
/** \brief Add \c n as notation for \c e */
environment add_mpz_notation(environment const & env, mpz const & n, expr const & e, bool overload = true);
/** \brief Return the additional interpretations for \c n in the current environment.
\remark It does not include the default one based on the \c num inductive datatype.
*/
list<expr> get_mpz_notation(environment const & env, mpz const & n);
void initialize_parser_config();
void finalize_parser_config();
}

View file

@ -222,6 +222,10 @@ public:
friend std::ostream & operator<<(std::ostream & out, mpz const & v);
};
struct mpz_cmp_fn {
int operator()(mpz const & v1, mpz const & v2) const { return cmp(v1, v2); }
};
template<>
class numeric_traits<mpz> {
public:

24
tests/lean/num2.lean Normal file
View file

@ -0,0 +1,24 @@
set_option pp.notation false
definition Prop := Type.{0}
variable eq {A : Type} : A → A → Prop
infixl `=`:50 := eq
variable N : Type.{1}
variable z : N
variable o : N
variable b : N
notation 0 := z
notation 1 := o
check 1
check 0
variable G : Type.{1}
variable gz : G
variable a : G
notation 0 := gz
check 0 = a
check b = 0

View file

@ -0,0 +1,4 @@
o : N
z : N
eq gz a : Prop
eq b z : Prop

14
tests/lean/num3.lean Normal file
View file

@ -0,0 +1,14 @@
import data.num
set_option pp.notation false
set_option pp.implicit true
variable N : Type.{1}
variable z : N
variable o : N
variable a : N
notation 0 := z
notation 1 := o
check a = 0
check 2 = 1

View file

@ -0,0 +1,2 @@
@eq N a z : Prop
@eq num 2 1 : Prop

20
tests/lean/num4.lean Normal file
View file

@ -0,0 +1,20 @@
import data.num
set_option pp.notation false
set_option pp.implicit true
namespace foo
variable N : Type.{1}
variable z : N
variable o : N
variable a : N
notation 0 := z
notation 1 := o
check a = 0
end foo
check 2 = 1
check #foo foo.a = 1
open foo
check a = 1

View file

@ -0,0 +1,4 @@
@eq N a z : Prop
@eq num 2 1 : Prop
@eq foo.N foo.a foo.o : Prop
@eq N a o : Prop