diff --git a/src/frontends/lean/notation_cmd.cpp b/src/frontends/lean/notation_cmd.cpp index 732abe8cb..2a2ccc60b 100644 --- a/src/frontends/lean/notation_cmd.cpp +++ b/src/frontends/lean/notation_cmd.cpp @@ -371,14 +371,27 @@ notation_entry parse_notation_core(parser & p, bool overload, buffer new_tokens; - auto ne = parse_notation_core(p, overload, new_tokens); - for (auto const & te : new_tokens) - env = add_token(env, te); - env = add_notation(env, ne); - return env; + if (p.curr_is_numeral()) { + return parse_num_notation(p, overload); + } else { + auto ne = parse_notation_core(p, overload, new_tokens); + for (auto const & te : new_tokens) + env = add_token(env, te); + env = add_notation(env, ne); + return env; + } } bool curr_is_notation_decl(parser & p) { diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index 225efd81d..f96069132 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -1058,10 +1058,22 @@ expr parser::parse_numeral_expr() { next(); if (!m_has_num) m_has_num = has_num_decls(m_env); - if (!*m_has_num) + list vals = get_mpz_notation(m_env, n); + if (!*m_has_num && !vals) { throw parser_error("numeral cannot be encoded as expression, environment does not contain the type 'num' " - "(solution: use 'import num')", p); - return from_num(n); + "nor notation was defined for the given numeral " + "(solution: use 'import data.num', or define notation for the given numeral)", p); + } + buffer cs; + for (expr const & c : vals) + cs.push_back(copy_with_new_pos(c, p)); + if (*m_has_num) + cs.push_back(save_pos(from_num(n), p)); + lean_assert(!cs.empty()); + if (cs.size() == 1) + return cs[0]; + else + return save_pos(mk_choice(cs.size(), cs.data()), p); } expr parser::parse_decimal_expr() { diff --git a/src/frontends/lean/parser_config.cpp b/src/frontends/lean/parser_config.cpp index 726055b40..69e370df0 100644 --- a/src/frontends/lean/parser_config.cpp +++ b/src/frontends/lean/parser_config.cpp @@ -217,27 +217,6 @@ environment add_led_notation(environment const & env, std::initializer_list, mpz_cmp_fn> map; + map m_map; +}; + +struct mpz_notation_config { + typedef mpz_notation_state state; + typedef mpz_notation_entry entry; + static name * g_class_name; + static std::string * g_key; + + static void add_entry(environment const &, io_state const &, state & s, entry const & e) { + if (!e.m_overload) { + s.m_map.insert(e.m_num, list(e.m_expr)); + } else if (auto it = s.m_map.find(e.m_num)) { + list new_exprs = cons(e.m_expr, filter(*it, [&](expr const & n) { return n != e.m_expr; })); + s.m_map.insert(e.m_num, new_exprs); + } else { + s.m_map.insert(e.m_num, list(e.m_expr)); + } + } + static name const & get_class_name() { + return *g_class_name; + } + static std::string const & get_serialization_key() { + return *g_key; + } + static void write_entry(serializer & s, entry const & e) { + s << e.m_num << e.m_expr << e.m_overload; + } + static entry read_entry(deserializer & d) { + entry e; + d >> e.m_num >> e.m_expr >> e.m_overload; + return e; + } +}; + +name * mpz_notation_config::g_class_name = nullptr; +std::string * mpz_notation_config::g_key = nullptr; + +template class scoped_ext; +typedef scoped_ext mpz_notation_ext; + +environment add_mpz_notation(environment const & env, mpz_notation_entry const & e) { + return mpz_notation_ext::add_entry(env, get_dummy_ios(), e); +} + +environment add_mpz_notation(environment const & env, mpz const & n, expr const & e, bool overload) { + return add_mpz_notation(env, mpz_notation_entry(n, e, overload)); +} + +list get_mpz_notation(environment const & env, mpz const & n) { + if (auto it = mpz_notation_ext::get_state(env).m_map.find(n)) { + return *it; + } else { + return list(); + } +} + +environment overwrite_notation(environment const & env, name const & n) { + environment r = env; + bool found = false; + if (auto it = token_ext::get_entries(r, n)) { + found = true; + for (token_entry e : *it) { + r = add_token(r, e); + } + } + if (auto it = notation_ext::get_entries(env, n)) { + found = true; + for (notation_entry e : *it) { + e.m_overload = false; + r = add_notation(r, e); + } + } + if (auto it = mpz_notation_ext::get_entries(env, n)) { + found = true; + for (mpz_notation_entry e : *it) { + e.m_overload = false; + r = add_mpz_notation(r, e); + } + } + if (!found) + throw exception(sstream() << "unknown namespace '" << n << "'"); + return r; +} + void initialize_parser_config() { token_config::g_class_name = new name("notation"); token_config::g_key = new std::string("tk"); @@ -273,8 +347,14 @@ void initialize_parser_config() { notation_config::g_key = new std::string("nota"); notation_ext::initialize(); g_ext = new cmd_ext_reg(); + mpz_notation_config::g_class_name = new name("notation"); + mpz_notation_config::g_key = new std::string("numnota"); + mpz_notation_ext::initialize(); } void finalize_parser_config() { + mpz_notation_ext::finalize(); + delete mpz_notation_config::g_key; + delete mpz_notation_config::g_class_name; delete g_ext; notation_ext::finalize(); delete notation_config::g_key; diff --git a/src/frontends/lean/parser_config.h b/src/frontends/lean/parser_config.h index 57a16f669..4b38c29b2 100644 --- a/src/frontends/lean/parser_config.h +++ b/src/frontends/lean/parser_config.h @@ -36,8 +36,10 @@ environment add_token(environment const & env, token_entry const & e); environment add_notation(environment const & env, notation_entry const & e); environment add_token(environment const & env, char const * val, unsigned prec); -environment add_nud_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a, bool overload = true); -environment add_led_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a, bool overload = true); +environment add_nud_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a, + bool overload = true); +environment add_led_notation(environment const & env, unsigned num, notation::transition const * ts, expr const & a, + bool overload = true); environment add_nud_notation(environment const & env, std::initializer_list const & ts, expr const & a, bool overload = true); environment add_led_notation(environment const & env, std::initializer_list const & ts, expr const & a, @@ -49,6 +51,14 @@ cmd_table const & get_cmd_table(environment const & env); /** \brief Force notation from namespace \c n to shadow any existing notation */ environment overwrite_notation(environment const & env, name const & n); +/** \brief Add \c n as notation for \c e */ +environment add_mpz_notation(environment const & env, mpz const & n, expr const & e, bool overload = true); +/** \brief Return the additional interpretations for \c n in the current environment. + + \remark It does not include the default one based on the \c num inductive datatype. +*/ +list get_mpz_notation(environment const & env, mpz const & n); + void initialize_parser_config(); void finalize_parser_config(); } diff --git a/src/util/numerics/mpz.h b/src/util/numerics/mpz.h index df0594c05..c872e6207 100644 --- a/src/util/numerics/mpz.h +++ b/src/util/numerics/mpz.h @@ -222,6 +222,10 @@ public: friend std::ostream & operator<<(std::ostream & out, mpz const & v); }; +struct mpz_cmp_fn { + int operator()(mpz const & v1, mpz const & v2) const { return cmp(v1, v2); } +}; + template<> class numeric_traits { public: diff --git a/tests/lean/num2.lean b/tests/lean/num2.lean new file mode 100644 index 000000000..873706673 --- /dev/null +++ b/tests/lean/num2.lean @@ -0,0 +1,24 @@ +set_option pp.notation false +definition Prop := Type.{0} +variable eq {A : Type} : A → A → Prop +infixl `=`:50 := eq + +variable N : Type.{1} +variable z : N +variable o : N +variable b : N + +notation 0 := z +notation 1 := o + +check 1 +check 0 + +variable G : Type.{1} +variable gz : G +variable a : G + +notation 0 := gz + +check 0 = a +check b = 0 diff --git a/tests/lean/num2.lean.expected.out b/tests/lean/num2.lean.expected.out new file mode 100644 index 000000000..d57ec759a --- /dev/null +++ b/tests/lean/num2.lean.expected.out @@ -0,0 +1,4 @@ +o : N +z : N +eq gz a : Prop +eq b z : Prop diff --git a/tests/lean/num3.lean b/tests/lean/num3.lean new file mode 100644 index 000000000..5241cbcc6 --- /dev/null +++ b/tests/lean/num3.lean @@ -0,0 +1,14 @@ +import data.num +set_option pp.notation false +set_option pp.implicit true + +variable N : Type.{1} +variable z : N +variable o : N +variable a : N + +notation 0 := z +notation 1 := o + +check a = 0 +check 2 = 1 diff --git a/tests/lean/num3.lean.expected.out b/tests/lean/num3.lean.expected.out new file mode 100644 index 000000000..77c4caf65 --- /dev/null +++ b/tests/lean/num3.lean.expected.out @@ -0,0 +1,2 @@ +@eq N a z : Prop +@eq num 2 1 : Prop diff --git a/tests/lean/num4.lean b/tests/lean/num4.lean new file mode 100644 index 000000000..9955cc8c3 --- /dev/null +++ b/tests/lean/num4.lean @@ -0,0 +1,20 @@ +import data.num +set_option pp.notation false +set_option pp.implicit true + +namespace foo + variable N : Type.{1} + variable z : N + variable o : N + variable a : N + notation 0 := z + notation 1 := o + + check a = 0 +end foo + +check 2 = 1 +check #foo foo.a = 1 + +open foo +check a = 1 diff --git a/tests/lean/num4.lean.expected.out b/tests/lean/num4.lean.expected.out new file mode 100644 index 000000000..b7030c5c1 --- /dev/null +++ b/tests/lean/num4.lean.expected.out @@ -0,0 +1,4 @@ +@eq N a z : Prop +@eq num 2 1 : Prop +@eq foo.N foo.a foo.o : Prop +@eq N a o : Prop