diff --git a/src/frontends/lean/notation_cmd.cpp b/src/frontends/lean/notation_cmd.cpp index 1eea2520f..aa0978bd1 100644 --- a/src/frontends/lean/notation_cmd.cpp +++ b/src/frontends/lean/notation_cmd.cpp @@ -47,59 +47,11 @@ static unsigned parse_precedence(parser & p, char const * msg) { return *r; } -environment precedence_cmd(parser & p) { - std::string tk = parse_symbol(p, "invalid precedence declaration, quoted symbol or identifier expected"); - p.check_token_next(get_colon_tk(), "invalid precedence declaration, ':' expected"); - unsigned prec = parse_precedence(p, "invalid precedence declaration, numeral or 'max' expected"); - return add_token(p.env(), tk.c_str(), prec); -} +LEAN_THREAD_VALUE(bool, g_allow_local, false); -/** \brief Auxiliary function for #cleanup_section_notation. */ -expr cleanup_section_notation_core(parser & p, expr const & e) { - if (is_as_atomic(e)) { - return cleanup_section_notation_core(p, get_as_atomic_arg(e)); - } else if (is_explicit(e)) { - return cleanup_section_notation_core(p, get_explicit_arg(e)); - } else if (is_app(e)) { - buffer args; - expr const & f = get_app_args(e, args); - for (expr arg : args) { - if (!is_explicit(arg)) - throw parser_error("unexpected expression in (section) notation", p.pos_of(arg)); - arg = get_explicit_arg(arg); - if (!is_local(arg)) - throw parser_error("unexpected expression in (section) notation", p.pos_of(arg)); - binder_info bi = local_info(arg); - if (!bi.is_strict_implicit() && !bi.is_implicit()) - throw parser_error(sstream() << "invalid occurrence of local parameter '" << local_pp_name(arg) - << "' in (section) notation that is not implicit", p.pos_of(e)); - } - return cleanup_section_notation_core(p, f); - } else if (is_constant(e)) { - return p.save_pos(mk_constant(const_name(e)), p.pos_of(e)); - } else if (is_local(e)) { - throw parser_error(sstream() << "invalid occurrence of local parameter '" << local_pp_name(e) - << "' in (section) notation", p.pos_of(e)); - } else { - throw parser_error("unexpected expression in (section) notation", p.pos_of(e)); - } -} - -/** \brief Replace reference to implicit section local constants and universes with placeholders. - - \remark Throws an exception if \c e contains a local constant that is not implicit. -*/ -expr cleanup_section_notation(parser & p, expr const & e) { - if (!in_section(p.env())) - return e; - return replace(e, [&](expr const & e) { - if (is_local(e)) - throw parser_error(sstream() << "invalid occurrence of local parameter '" << local_pp_name(e) - << "' in (section) notation", p.pos_of(e)); - if (is_as_atomic(e)) - return some_expr(cleanup_section_notation_core(p, e)); - return none_expr(); - }); +static void check_notation_expr(parser & p, expr const & e, pos_info const & pos) { + if (!g_allow_local && (has_local(e) || has_param_univ(e))) + throw parser_error("invalid notation declaration, contains reference to local variables", pos); } enum class mixfix_kind { infixl, infixr, postfix, prefix }; @@ -135,17 +87,23 @@ static pair> parse_mixfix_notation(parser if (k == mixfix_kind::infixr && *prec == 0) throw parser_error("invalid infixr declaration, precedence must be greater than zero", p.pos()); p.check_token_next(get_assign_tk(), "invalid notation declaration, ':=' expected"); - expr f = cleanup_section_notation(p, p.parse_expr()); + auto f_pos = p.pos(); + expr f = p.parse_expr(); + check_notation_expr(p, f, f_pos); char const * tks = tk.c_str(); switch (k) { case mixfix_kind::infixl: - return mk_pair(notation_entry(false, to_list(transition(tks, mk_expr_action(*prec))), mk_app(f, Var(1), Var(0)), overload), new_token); + return mk_pair(notation_entry(false, to_list(transition(tks, mk_expr_action(*prec))), + mk_app(f, Var(1), Var(0)), overload), new_token); case mixfix_kind::infixr: - return mk_pair(notation_entry(false, to_list(transition(tks, mk_expr_action(*prec-1))), mk_app(f, Var(1), Var(0)), overload), new_token); + return mk_pair(notation_entry(false, to_list(transition(tks, mk_expr_action(*prec-1))), + mk_app(f, Var(1), Var(0)), overload), new_token); case mixfix_kind::postfix: - return mk_pair(notation_entry(false, to_list(transition(tks, mk_skip_action())), mk_app(f, Var(0)), overload), new_token); + return mk_pair(notation_entry(false, to_list(transition(tks, mk_skip_action())), + mk_app(f, Var(0)), overload), new_token); case mixfix_kind::prefix: - return mk_pair(notation_entry(true, to_list(transition(tks, mk_expr_action(*prec))), mk_app(f, Var(0)), overload), new_token); + return mk_pair(notation_entry(true, to_list(transition(tks, mk_expr_action(*prec))), + mk_app(f, Var(0)), overload), new_token); } lean_unreachable(); // LCOV_EXCL_LINE } @@ -157,20 +115,6 @@ static notation_entry parse_mixfix_notation(parser & p, mixfix_kind k, bool over return nt.first; } -static environment mixfix_cmd(parser & p, mixfix_kind k, bool overload) { - auto nt = parse_mixfix_notation(p, k, overload); - environment env = p.env(); - if (nt.second) - env = add_token(env, *nt.second); - env = add_notation(env, nt.first); - return env; -} - -environment infixl_cmd_core(parser & p, bool overload) { return mixfix_cmd(p, mixfix_kind::infixl, overload); } -environment infixr_cmd_core(parser & p, bool overload) { return mixfix_cmd(p, mixfix_kind::infixr, overload); } -environment postfix_cmd_core(parser & p, bool overload) { return mixfix_cmd(p, mixfix_kind::postfix, overload); } -environment prefix_cmd_core(parser & p, bool overload) { return mixfix_cmd(p, mixfix_kind::prefix, overload); } - static name parse_quoted_symbol_or_token(parser & p, buffer & new_tokens) { if (p.curr_is_quoted_symbol()) { environment const & env = p.env(); @@ -198,8 +142,11 @@ static name parse_quoted_symbol_or_token(parser & p, buffer & new_t } static expr parse_notation_expr(parser & p, buffer const & locals) { + auto pos = p.pos(); expr r = p.parse_expr(); - return cleanup_section_notation(p, abstract(r, locals.size(), locals.data())); + r = abstract(r, locals.size(), locals.data()); + check_notation_expr(p, r, pos); + return r; } static void parse_notation_local(parser & p, buffer & locals) { @@ -215,7 +162,7 @@ static void parse_notation_local(parser & p, buffer & locals) { } } -unsigned get_precedence(environment const & env, buffer const & new_tokens, name const & token) { +static unsigned get_precedence(environment const & env, buffer const & new_tokens, name const & token) { std::string token_str = token.to_string(); for (auto const & e : new_tokens) { if (e.m_token == token_str) @@ -228,7 +175,8 @@ unsigned get_precedence(environment const & env, buffer const & new return 0; } -static action parse_action(parser & p, name const & prev_token, unsigned default_prec, buffer & locals, buffer & new_tokens) { +static action parse_action(parser & p, name const & prev_token, unsigned default_prec, + buffer & locals, buffer & new_tokens) { if (p.curr_is_token(get_colon_tk())) { p.next(); if (p.curr_is_numeral() || p.curr_is_token_or_id(get_max_tk())) { @@ -302,7 +250,7 @@ static action parse_action(parser & p, name const & prev_token, unsigned default when the user does not provide it. The idea is to minimize conflict with existing notation. */ -unsigned get_default_prec(optional const & pt, name const & tk) { +static unsigned get_default_prec(optional const & pt, name const & tk) { if (!pt) return 0; if (auto at = pt->find(tk)) { @@ -312,7 +260,7 @@ unsigned get_default_prec(optional const & pt, name const & tk) { return 0; } -notation_entry parse_notation_core(parser & p, bool overload, buffer & new_tokens) { +static notation_entry parse_notation_core(parser & p, bool overload, buffer & new_tokens) { buffer locals; buffer ts; parser::local_scope scope(p); @@ -323,7 +271,9 @@ notation_entry parse_notation_core(parser & p, bool overload, buffer new_tokens; - auto ne = parse_notation_core(p, overload, new_tokens); - for (auto const & te : new_tokens) - env = add_token(env, te); - env = add_notation(env, ne); - return env; -} - bool curr_is_notation_decl(parser & p) { - return p.curr_is_token(get_infix_tk()) || p.curr_is_token(get_infixl_tk()) || p.curr_is_token(get_infixr_tk()) || + return + p.curr_is_token(get_infix_tk()) || p.curr_is_token(get_infixl_tk()) || p.curr_is_token(get_infixr_tk()) || p.curr_is_token(get_postfix_tk()) || p.curr_is_token(get_prefix_tk()) || p.curr_is_token(get_notation_tk()); } -notation_entry parse_notation(parser & p, bool overload, buffer & new_tokens) { +notation_entry parse_notation(parser & p, bool overload, buffer & new_tokens, bool allow_local) { + flet set_allow_local(g_allow_local, allow_local); if (p.curr_is_token(get_infix_tk()) || p.curr_is_token(get_infixl_tk())) { p.next(); return parse_mixfix_notation(p, mixfix_kind::infixl, overload, new_tokens); @@ -414,11 +356,41 @@ notation_entry parse_notation(parser & p, bool overload, buffer & n } } -environment notation_cmd(parser & p) { return notation_cmd_core(p, !in_context(p.env())); } -environment infixl_cmd(parser & p) { return infixl_cmd_core(p, !in_context(p.env())); } -environment infixr_cmd(parser & p) { return infixr_cmd_core(p, !in_context(p.env())); } -environment postfix_cmd(parser & p) { return postfix_cmd_core(p, !in_context(p.env())); } -environment prefix_cmd(parser & p) { return prefix_cmd_core(p, !in_context(p.env())); } +static environment notation_cmd_core(parser & p, bool overload) { + flet set_allow_local(g_allow_local, in_context(p.env())); + environment env = p.env(); + buffer new_tokens; + auto ne = parse_notation_core(p, overload, new_tokens); + for (auto const & te : new_tokens) + env = add_token(env, te); + env = add_notation(env, ne); + return env; +} +static environment mixfix_cmd(parser & p, mixfix_kind k, bool overload) { + flet set_allow_local(g_allow_local, in_context(p.env())); + auto nt = parse_mixfix_notation(p, k, overload); + environment env = p.env(); + if (nt.second) + env = add_token(env, *nt.second); + env = add_notation(env, nt.first); + return env; +} +static environment infixl_cmd_core(parser & p, bool overload) { return mixfix_cmd(p, mixfix_kind::infixl, overload); } +static environment infixr_cmd_core(parser & p, bool overload) { return mixfix_cmd(p, mixfix_kind::infixr, overload); } +static environment postfix_cmd_core(parser & p, bool overload) { return mixfix_cmd(p, mixfix_kind::postfix, overload); } +static environment prefix_cmd_core(parser & p, bool overload) { return mixfix_cmd(p, mixfix_kind::prefix, overload); } +static environment notation_cmd(parser & p) { return notation_cmd_core(p, !in_context(p.env())); } +static environment infixl_cmd(parser & p) { return infixl_cmd_core(p, !in_context(p.env())); } +static environment infixr_cmd(parser & p) { return infixr_cmd_core(p, !in_context(p.env())); } +static environment postfix_cmd(parser & p) { return postfix_cmd_core(p, !in_context(p.env())); } +static environment prefix_cmd(parser & p) { return prefix_cmd_core(p, !in_context(p.env())); } + +static environment precedence_cmd(parser & p) { + std::string tk = parse_symbol(p, "invalid precedence declaration, quoted symbol or identifier expected"); + p.check_token_next(get_colon_tk(), "invalid precedence declaration, ':' expected"); + unsigned prec = parse_precedence(p, "invalid precedence declaration, numeral or 'max' expected"); + return add_token(p.env(), tk.c_str(), prec); +} void register_notation_cmds(cmd_table & r) { add_cmd(r, cmd_info("precedence", "set token left binding power", precedence_cmd)); diff --git a/src/frontends/lean/notation_cmd.h b/src/frontends/lean/notation_cmd.h index 4d895c958..a0d31bfc1 100644 --- a/src/frontends/lean/notation_cmd.h +++ b/src/frontends/lean/notation_cmd.h @@ -12,8 +12,10 @@ namespace lean { class parser; /** \brief Return true iff the current token is a notation declaration */ bool curr_is_notation_decl(parser & p); -/** \brief Parse a notation declaration, throws an error if the current token is not a "notation declaration". */ -notation_entry parse_notation(parser & p, bool overload, buffer & new_tokens); +/** \brief Parse a notation declaration, throws an error if the current token is not a "notation declaration". + If allow_local is true, then notation may contain reference to local constants. +*/ +notation_entry parse_notation(parser & p, bool overload, buffer & new_tokens, bool allow_local); void register_notation_cmds(cmd_table & r); } diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index df10dbcad..72b4cfd15 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -826,7 +826,9 @@ local_environment parser::parse_binders(buffer & r, buffer bool parser::parse_local_notation_decl(buffer * nentries) { if (curr_is_notation_decl(*this)) { buffer new_tokens; - auto ne = ::lean::parse_notation(*this, false, new_tokens); + bool overload = false; + bool allow_local = true; + auto ne = ::lean::parse_notation(*this, overload, new_tokens, allow_local); for (auto const & te : new_tokens) m_env = add_token(m_env, te); if (nentries) nentries->push_back(ne); diff --git a/tests/lean/bad_notation.lean b/tests/lean/bad_notation.lean new file mode 100644 index 000000000..acb2ed317 --- /dev/null +++ b/tests/lean/bad_notation.lean @@ -0,0 +1,9 @@ +import logic data.nat.basic +open nat + +section + variable a : nat + notation `a1`:max := a + 1 +end + +definition foo := a1 diff --git a/tests/lean/bad_notation.lean.expected.out b/tests/lean/bad_notation.lean.expected.out new file mode 100644 index 000000000..4a5c87283 --- /dev/null +++ b/tests/lean/bad_notation.lean.expected.out @@ -0,0 +1,2 @@ +bad_notation.lean:6:23: error: invalid notation declaration, contains reference to local variables +bad_notation.lean:9:18: error: unknown identifier 'a1'