diff --git a/src/frontends/lean/notation_cmd.cpp b/src/frontends/lean/notation_cmd.cpp index b9778a3c8..4ab8f1b89 100644 --- a/src/frontends/lean/notation_cmd.cpp +++ b/src/frontends/lean/notation_cmd.cpp @@ -309,7 +309,9 @@ static action parse_action(parser & p, name const & prev_token, unsigned default locals.pop_back(); locals.pop_back(); } - expr ini = parse_notation_expr(p, locals); + optional ini; + if (!p.curr_is_token(get_rparen_tk()) && !p.curr_is_quoted_symbol()) + ini = parse_notation_expr(p, locals); optional terminator; if (!p.curr_is_token(get_rparen_tk())) terminator = parse_quoted_symbol_or_token(p, new_tokens); diff --git a/src/frontends/lean/parse_table.cpp b/src/frontends/lean/parse_table.cpp index 317196591..ebc50b4b0 100644 --- a/src/frontends/lean/parse_table.cpp +++ b/src/frontends/lean/parse_table.cpp @@ -83,10 +83,10 @@ struct expr_action_cell : public action_cell { struct exprs_action_cell : public expr_action_cell { name m_token_sep; expr m_rec; - expr m_ini; + optional m_ini; optional m_terminator; bool m_fold_right; - exprs_action_cell(name const & sep, expr const & rec, expr const & ini, + exprs_action_cell(name const & sep, expr const & rec, optional const & ini, optional const & terminator, bool right, unsigned rbp): expr_action_cell(action_kind::Exprs, rbp), m_token_sep(sep), m_rec(rec), m_ini(ini), m_terminator(terminator), m_fold_right(right) {} @@ -151,7 +151,7 @@ expr const & action::get_rec() const { return to_exprs_action(m_ptr)->m_rec; } bool action::use_lambda_abstraction() const { return to_scoped_expr_action(m_ptr)->m_lambda; } -expr const & action::get_initial() const { return to_exprs_action(m_ptr)->m_ini; } +optional const & action::get_initial() const { return to_exprs_action(m_ptr)->m_ini; } bool action::is_fold_right() const { return to_exprs_action(m_ptr)->m_fold_right; } parse_fn const & action::get_parse_fn() const { return to_ext_action(m_ptr)->m_parse_fn; } std::string const & action::get_lua_fn() const { return to_ext_lua_action(m_ptr)->m_lua_fn; } @@ -193,7 +193,9 @@ void action::display(std::ostream & out) const { out << "(fold" << (is_fold_right() ? "r" : "l"); if (get_terminator()) out << "*"; - out << " " << rbp() << " " << get_rec() << " " << get_initial(); + out << " " << rbp() << " " << get_rec(); + if (get_initial()) + out << " " << *get_initial(); if (get_terminator()) out << *get_terminator(); out << ")"; @@ -239,13 +241,13 @@ void finalize_parse_table() { } action mk_expr_action(unsigned rbp) { return action(new expr_action_cell(rbp)); } -action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, +action mk_exprs_action(name const & sep, expr const & rec, optional const & ini, optional const & terminator, bool right, unsigned rbp) { if (get_free_var_range(rec) > 2) throw exception("invalid notation, the expression used to combine a sequence of expressions " "must not contain free variables with de Bruijn indices greater than 1"); expr new_rec = annotate_macro_subterms(rec); - expr new_ini = annotate_macro_subterms(ini); + optional new_ini = ini ? some_expr(annotate_macro_subterms(*ini)) : none_expr(); return action(new exprs_action_cell(sep, new_rec, new_ini, terminator, right, rbp)); } action mk_scoped_expr_action(expr const & rec, unsigned rb, bool lambda) { @@ -269,7 +271,7 @@ action replace(action const & a, std::function const & f) { case action_kind::Ext: case action_kind::LuaExt: case action_kind::Expr: return a; case action_kind::Exprs: - return mk_exprs_action(a.get_sep(), f(a.get_rec()), f(a.get_initial()), a.get_terminator(), + return mk_exprs_action(a.get_sep(), f(a.get_rec()), a.get_initial() ? some_expr(f(*a.get_initial())) : none_expr(), a.get_terminator(), a.is_fold_right(), a.rbp()); case action_kind::ScopedExpr: return mk_scoped_expr_action(f(a.get_rec()), a.rbp(), a.use_lambda_abstraction()); @@ -498,7 +500,7 @@ static int mk_exprs_action(lua_State * L) { if (nargs >= 4) terminator = to_optional_name(L, 4); return push_notation_action(L, mk_exprs_action(to_name_ext(L, 1), to_expr(L, 2), - to_expr(L, 3), + lua_isnil(L, 3) ? none_expr() : some_expr(to_expr(L, 3)), terminator, lua_toboolean(L, 5), rbp)); @@ -540,7 +542,7 @@ static int rec(lua_State * L) { } static int initial(lua_State * L) { check_action(L, 1, { action_kind::Exprs }); - return push_expr(L, to_notation_action(L, 1).get_initial()); + return push_optional_expr(L, to_notation_action(L, 1).get_initial()); } static int is_fold_right(lua_State * L) { check_action(L, 1, { action_kind::Exprs }); diff --git a/src/frontends/lean/parse_table.h b/src/frontends/lean/parse_table.h index ffe210dd7..5ca90ae94 100644 --- a/src/frontends/lean/parse_table.h +++ b/src/frontends/lean/parse_table.h @@ -60,7 +60,7 @@ public: friend void initialize_parse_table(); friend action mk_expr_action(unsigned rbp); - friend action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, + friend action mk_exprs_action(name const & sep, expr const & rec, optional const & ini, optional const & terminator, bool right, unsigned rbp); friend action mk_scoped_expr_action(expr const & rec, unsigned rbp, bool lambda); friend action mk_ext_action_core(parse_fn const & fn); @@ -71,7 +71,7 @@ public: unsigned rbp() const; name const & get_sep() const; expr const & get_rec() const; - expr const & get_initial() const; + optional const & get_initial() const; optional const & get_terminator() const; bool is_fold_right() const; bool use_lambda_abstraction() const; @@ -89,7 +89,7 @@ inline bool operator!=(action const & a1, action const & a2) { return !a1.is_equ action mk_skip_action(); action mk_expr_action(unsigned rbp = 0); -action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, optional const & terminator, bool right, +action mk_exprs_action(name const & sep, expr const & rec, optional const & ini, optional const & terminator, bool right, unsigned rbp = 0); action mk_binder_action(); action mk_binders_action(); diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index 6074c73d1..406a83eb4 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -903,9 +903,15 @@ expr parser::parse_notation(parse_table t, expr * left) { throw parser_error(sstream() << "invalid composite expression, '" << *terminator << "' expected", pos()); } } - expr r = instantiate_rev(copy_with_new_pos(a.get_initial(), p), args.size(), args.data()); expr rec = copy_with_new_pos(a.get_rec(), p); + expr r; if (a.is_fold_right()) { + if (a.get_initial()) { + r = instantiate_rev(copy_with_new_pos(*a.get_initial(), p), args.size(), args.data()); + } else { + r = r_args.back(); + r_args.pop_back(); + } unsigned i = r_args.size(); while (i > 0) { --i; @@ -915,7 +921,14 @@ expr parser::parse_notation(parse_table t, expr * left) { args.pop_back(); args.pop_back(); } } else { - for (unsigned i = 0; i < r_args.size(); i++) { + unsigned fidx = 0; + if (a.get_initial()) { + r = instantiate_rev(copy_with_new_pos(*a.get_initial(), p), args.size(), args.data()); + } else { + r = r_args[0]; + fidx++; + } + for (unsigned i = fidx; i < r_args.size(); i++) { args.push_back(r_args[i]); args.push_back(r); r = instantiate_rev(rec, args.size(), args.data()); diff --git a/src/frontends/lean/parser_config.cpp b/src/frontends/lean/parser_config.cpp index f56b1e681..7ea23d0f5 100644 --- a/src/frontends/lean/parser_config.cpp +++ b/src/frontends/lean/parser_config.cpp @@ -128,7 +128,12 @@ serializer & operator<<(serializer & s, action const & a) { s << a.rbp(); break; case action_kind::Exprs: - s << a.get_sep() << a.get_rec() << a.get_initial() << a.is_fold_right() << a.rbp(); + s << a.get_sep() << a.get_rec(); + if (a.get_initial()) + s << true << *a.get_initial(); + else + s << false; + s << a.is_fold_right() << a.rbp(); if (auto t = a.get_terminator()) { s << true << *t; } else { @@ -161,8 +166,12 @@ action read_action(deserializer & d) { d >> rbp; return notation::mk_expr_action(rbp); case action_kind::Exprs: { - name sep; expr rec, ini; bool is_fold_right; - d >> sep >> rec >> ini >> is_fold_right >> rbp; + name sep; expr rec; optional ini; bool is_fold_right; + d >> sep >> rec; + if (d.read_bool()) { + ini = read_expr(d); + } + d >> is_fold_right >> rbp; optional terminator; if (d.read_bool()) terminator = read_name(d); diff --git a/src/frontends/lean/pp.cpp b/src/frontends/lean/pp.cpp index a8e9d39a6..286458718 100644 --- a/src/frontends/lean/pp.cpp +++ b/src/frontends/lean/pp.cpp @@ -808,7 +808,7 @@ auto pretty_fn::pp_notation(notation_entry const & entry, buffer> expr e = *args.back(); args.pop_back(); expr const & rec = a.get_rec(); - expr const & ini = a.get_initial(); + optional const & ini = a.get_initial(); buffer rec_args; bool matched_once = false; while (true) { @@ -818,7 +818,6 @@ auto pretty_fn::pp_notation(notation_entry const & entry, buffer> args.pop_back(); break; } - matched_once = true; optional new_e = args.back(); args.pop_back(); optional rec_arg = args.back(); @@ -827,11 +826,16 @@ auto pretty_fn::pp_notation(notation_entry const & entry, buffer> return optional(); rec_args.push_back(*rec_arg); e = *new_e; + matched_once = true; } if (!matched_once) return optional(); - if (!match(ini, e, args)) - return optional(); + if (ini) { + if (!match(*ini, e, args)) + return optional(); + } else { + rec_args.push_back(e); + } if (!a.is_fold_right()) std::reverse(rec_args.begin(), rec_args.end()); unsigned curr_lbp = token_lbp; @@ -847,9 +851,14 @@ auto pretty_fn::pp_notation(notation_entry const & entry, buffer> while (i > 0) { --i; result arg_res = pp_notation_child(rec_args[i], curr_lbp, a.rbp()); - if (i == 0) - sep_fmt = format(tk); - curr = sep_fmt + space() + arg_res.fmt() + curr; + if (i == 0) { + if (add_extra_space(tk)) + curr = format(tk) + space() + arg_res.fmt() + curr; + else + curr = format(tk) + arg_res.fmt() + curr; + } else { + curr = sep_fmt + space() + arg_res.fmt() + curr; + } if (i > 0 && add_extra_space(a.get_sep())) curr = space() + curr; curr_lbp = sep_lbp; diff --git a/tests/lean/fold.lean b/tests/lean/fold.lean new file mode 100644 index 000000000..1aa2f5039 --- /dev/null +++ b/tests/lean/fold.lean @@ -0,0 +1,31 @@ +import data.prod data.num + +variables a b c : num + +context + notation `(` t:(foldr `,` (e r, prod.mk e r)) `)` := t + check (a, false, b, true, c) + set_option pp.notation false + check (a, false, b, true, c) +end + +context + notation `(` t:(foldr `,` (e r, prod.mk r e)) `)` := t + check (a, false, b, true, c) + set_option pp.notation false + check (a, false, b, true, c) +end + +context + notation `(` t:(foldl `,` (e r, prod.mk r e)) `)` := t + check (a, false, b, true, c) + set_option pp.notation false + check (a, false, b, true, c) +end + +context + notation `(` t:(foldl `,` (e r, prod.mk e r)) `)` := t + check (a, false, b, true, c) + set_option pp.notation false + check (a, false, b, true, c) +end diff --git a/tests/lean/fold.lean.expected.out b/tests/lean/fold.lean.expected.out new file mode 100644 index 000000000..be434df08 --- /dev/null +++ b/tests/lean/fold.lean.expected.out @@ -0,0 +1,8 @@ +(a, false, b, true, c) : prod num (prod Prop (prod num (prod Prop num))) +prod.mk a (prod.mk false (prod.mk b (prod.mk true c))) : prod num (prod Prop (prod num (prod Prop num))) +(a, false, b, true, c) : prod (prod (prod (prod num Prop) num) Prop) num +prod.mk (prod.mk (prod.mk (prod.mk c true) b) false) a : prod (prod (prod (prod num Prop) num) Prop) num +(a, false, b, true, c) : prod (prod (prod (prod num Prop) num) Prop) num +prod.mk (prod.mk (prod.mk (prod.mk a false) b) true) c : prod (prod (prod (prod num Prop) num) Prop) num +(a, false, b, true, c) : prod num (prod Prop (prod num (prod Prop num))) +prod.mk c (prod.mk true (prod.mk b (prod.mk false a))) : prod num (prod Prop (prod num (prod Prop num)))