diff --git a/src/frontends/lean/decl_cmds.cpp b/src/frontends/lean/decl_cmds.cpp index f6001adbd..649d2a311 100644 --- a/src/frontends/lean/decl_cmds.cpp +++ b/src/frontends/lean/decl_cmds.cpp @@ -542,42 +542,45 @@ static void parse_equations_core(parser & p, buffer const & fns, buffer locals; - { - parser::undef_id_to_local_scope scope2(p); - buffer lhs_args; - auto lhs_pos = p.pos(); - if (p.curr_is_token(get_explicit_tk())) { - p.next(); - name fn_name = p.check_decl_id_next("invalid recursive equation, identifier expected"); - lhs_args.push_back(p.save_pos(mk_explicit(get_equation_fn(fns, fn_name, lhs_pos)), lhs_pos)); - } else { - expr first = p.parse_expr(get_max_prec()); - expr fn = first; - if (is_explicit(fn)) - fn = get_explicit_arg(fn); - if (is_local(fn) && is_equation_fn(fns, local_pp_name(fn))) { - lhs_args.push_back(first); - } else if (fns.size() == 1) { - lhs_args.push_back(p.save_pos(mk_explicit(fns[0]), lhs_pos)); - lhs_args.push_back(first); - } else { - throw parser_error("invalid recursive equation, head symbol in left-hand-side is not a constant", - lhs_pos); - } - } - while (!p.curr_is_token(get_assign_tk())) - lhs_args.push_back(p.parse_expr(get_max_prec())); - lean_assert(lhs_args.size() > 0); - lhs = lhs_args[0]; - for (unsigned i = 1; i < lhs_args.size(); i++) - lhs = copy_tag(lhs_args[i], mk_app(lhs, lhs_args[i])); + buffer lhs_args; - unsigned num_undef_ids = p.get_num_undef_ids(); - for (unsigned i = prev_num_undef_ids; i < num_undef_ids; i++) { - locals.push_back(p.get_undef_id(i)); + // check if lhs starts with an (explicit) equation function symbol (which is optional for fns.size() == 1) + optional fn; + auto lhs_pos = p.pos(); + if (p.curr_is_token(get_explicit_tk())) { + p.next(); + name fn_name = p.check_decl_id_next("invalid recursive equation, identifier expected"); + lhs_args.push_back(p.save_pos(mk_explicit(get_equation_fn(fns, fn_name, lhs_pos)), lhs_pos)); + } else if (p.curr_is_identifier() && (fn = is_equation_fn(fns, p.get_name_val()))) { + p.next(); + lhs_args.push_back(p.save_pos(*fn, lhs_pos)); + } else { + if (fns.size() == 1) { + lhs_args.push_back(p.save_pos(mk_explicit(fns[0]), lhs_pos)); + } else { + throw parser_error("invalid recursive equation, head symbol in left-hand-side is not an equation function", + lhs_pos); } } + + // parse the remaining left-hand side + { + parser::local_and_undef_id_to_local_scope scope2(p); + while (!p.curr_is_token(get_assign_tk())) + lhs_args.push_back(p.parse_expr(get_max_prec())); + } + + lean_assert(lhs_args.size() > 0); + lhs = lhs_args[0]; + for (unsigned i = 1; i < lhs_args.size(); i++) + lhs = copy_tag(lhs_args[i], mk_app(lhs, lhs_args[i])); + + buffer locals; + unsigned num_undef_ids = p.get_num_undef_ids(); + for (unsigned i = prev_num_undef_ids; i < num_undef_ids; i++) { + locals.push_back(p.get_undef_id(i)); + } + validate_equation_lhs(p, lhs, locals); lhs = merge_equation_lhs_vars(lhs, locals); auto assign_pos = p.pos(); @@ -677,7 +680,7 @@ expr parse_match(parser & p, unsigned, expr const *, pos_info const & pos) { unsigned prev_num_undef_ids = p.get_num_undef_ids(); buffer locals; { - parser::undef_id_to_local_scope scope2(p); + parser::local_and_undef_id_to_local_scope scope2(p); auto lhs_pos = p.pos(); lhs = p.parse_expr(); lhs = p.mk_app(fn, lhs, lhs_pos); diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index fc6575e2a..932947e60 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -107,6 +107,8 @@ parser::undef_id_to_const_scope::undef_id_to_const_scope(parser & p): flet(p.m_undef_id_behavior, undef_id_behavior::AssumeConstant) {} parser::undef_id_to_local_scope::undef_id_to_local_scope(parser & p): flet(p.m_undef_id_behavior, undef_id_behavior::AssumeLocal) {} +parser::local_and_undef_id_to_local_scope::local_and_undef_id_to_local_scope(parser & p): + flet(p.m_undef_id_behavior, undef_id_behavior::AssumeLocalAndAlsoDefinedLocals) {} static name * g_tmp_prefix = nullptr; @@ -1474,6 +1476,11 @@ expr parser::id_to_expr(name const & id, pos_info const & p) { if (ls && m_undef_id_behavior != undef_id_behavior::AssumeConstant) throw parser_error("invalid use of explicit universe parameter, identifier is a variable, " "parameter or a constant bound to parameters in a section", p); + if (m_undef_id_behavior == undef_id_behavior::AssumeLocalAndAlsoDefinedLocals) { + expr local = mk_local(id, mk_expr_placeholder()); + m_undef_ids.push_back(local); + return save_pos(local, p); + } auto r = copy_with_new_pos(*it1, p); save_type_info(r); save_identifier_info(p, id); @@ -1523,7 +1530,7 @@ expr parser::id_to_expr(name const & id, pos_info const & p) { if (!r) { if (m_undef_id_behavior == undef_id_behavior::AssumeConstant) { r = save_pos(mk_constant(get_namespace(m_env) + id, ls), p); - } else if (m_undef_id_behavior == undef_id_behavior::AssumeLocal) { + } else if (m_undef_id_behavior == undef_id_behavior::AssumeLocal || m_undef_id_behavior == undef_id_behavior::AssumeLocalAndAlsoDefinedLocals) { expr local = mk_local(id, mk_expr_placeholder()); m_undef_ids.push_back(local); r = save_pos(local, p); diff --git a/src/frontends/lean/parser.h b/src/frontends/lean/parser.h index 457832cab..a2fda608f 100644 --- a/src/frontends/lean/parser.h +++ b/src/frontends/lean/parser.h @@ -83,7 +83,7 @@ typedef std::vector snapshot_vector; enum class keep_theorem_mode { All, DiscardImported, DiscardAll }; -enum class undef_id_behavior { Error, AssumeConstant, AssumeLocal }; +enum class undef_id_behavior { Error, AssumeConstant, AssumeLocal, AssumeLocalAndAlsoDefinedLocals }; class parser { environment m_env; @@ -491,6 +491,7 @@ public: */ struct undef_id_to_const_scope : public flet { undef_id_to_const_scope(parser & p); }; struct undef_id_to_local_scope : public flet { undef_id_to_local_scope(parser &); }; + struct local_and_undef_id_to_local_scope : public flet { local_and_undef_id_to_local_scope(parser &); }; /** \brief Return the size of the stack of undefined local constants */ unsigned get_num_undef_ids() const { return m_undef_ids.size(); } diff --git a/tests/lean/run/eq23.lean b/tests/lean/run/eq23.lean index 75419afbe..e82fde354 100644 --- a/tests/lean/run/eq23.lean +++ b/tests/lean/run/eq23.lean @@ -10,7 +10,7 @@ with tree_list := namespace tree_list definition len {A : Type} : tree_list A → nat -| len (nil A) := 0 +| len (nil _) := 0 | len (cons t l) := len l + 1 theorem len_nil {A : Type} : len (nil A) = 0 :=