feat(frontends/lean): add variation of the foldl/foldr notation where initial element is suppressed, closes #314

See tests/lean/fold.lean for examples
This commit is contained in:
Leonardo de Moura 2014-11-09 14:08:33 -08:00
parent ce889ddf60
commit eff3c6b774
8 changed files with 99 additions and 25 deletions

View file

@ -309,7 +309,9 @@ static action parse_action(parser & p, name const & prev_token, unsigned default
locals.pop_back(); locals.pop_back();
locals.pop_back(); locals.pop_back();
} }
expr ini = parse_notation_expr(p, locals); optional<expr> ini;
if (!p.curr_is_token(get_rparen_tk()) && !p.curr_is_quoted_symbol())
ini = parse_notation_expr(p, locals);
optional<name> terminator; optional<name> terminator;
if (!p.curr_is_token(get_rparen_tk())) if (!p.curr_is_token(get_rparen_tk()))
terminator = parse_quoted_symbol_or_token(p, new_tokens); terminator = parse_quoted_symbol_or_token(p, new_tokens);

View file

@ -83,10 +83,10 @@ struct expr_action_cell : public action_cell {
struct exprs_action_cell : public expr_action_cell { struct exprs_action_cell : public expr_action_cell {
name m_token_sep; name m_token_sep;
expr m_rec; expr m_rec;
expr m_ini; optional<expr> m_ini;
optional<name> m_terminator; optional<name> m_terminator;
bool m_fold_right; bool m_fold_right;
exprs_action_cell(name const & sep, expr const & rec, expr const & ini, exprs_action_cell(name const & sep, expr const & rec, optional<expr> const & ini,
optional<name> const & terminator, bool right, unsigned rbp): optional<name> const & terminator, bool right, unsigned rbp):
expr_action_cell(action_kind::Exprs, rbp), expr_action_cell(action_kind::Exprs, rbp),
m_token_sep(sep), m_rec(rec), m_ini(ini), m_terminator(terminator), m_fold_right(right) {} m_token_sep(sep), m_rec(rec), m_ini(ini), m_terminator(terminator), m_fold_right(right) {}
@ -151,7 +151,7 @@ expr const & action::get_rec() const {
return to_exprs_action(m_ptr)->m_rec; return to_exprs_action(m_ptr)->m_rec;
} }
bool action::use_lambda_abstraction() const { return to_scoped_expr_action(m_ptr)->m_lambda; } bool action::use_lambda_abstraction() const { return to_scoped_expr_action(m_ptr)->m_lambda; }
expr const & action::get_initial() const { return to_exprs_action(m_ptr)->m_ini; } optional<expr> const & action::get_initial() const { return to_exprs_action(m_ptr)->m_ini; }
bool action::is_fold_right() const { return to_exprs_action(m_ptr)->m_fold_right; } bool action::is_fold_right() const { return to_exprs_action(m_ptr)->m_fold_right; }
parse_fn const & action::get_parse_fn() const { return to_ext_action(m_ptr)->m_parse_fn; } parse_fn const & action::get_parse_fn() const { return to_ext_action(m_ptr)->m_parse_fn; }
std::string const & action::get_lua_fn() const { return to_ext_lua_action(m_ptr)->m_lua_fn; } std::string const & action::get_lua_fn() const { return to_ext_lua_action(m_ptr)->m_lua_fn; }
@ -193,7 +193,9 @@ void action::display(std::ostream & out) const {
out << "(fold" << (is_fold_right() ? "r" : "l"); out << "(fold" << (is_fold_right() ? "r" : "l");
if (get_terminator()) if (get_terminator())
out << "*"; out << "*";
out << " " << rbp() << " " << get_rec() << " " << get_initial(); out << " " << rbp() << " " << get_rec();
if (get_initial())
out << " " << *get_initial();
if (get_terminator()) if (get_terminator())
out << *get_terminator(); out << *get_terminator();
out << ")"; out << ")";
@ -239,13 +241,13 @@ void finalize_parse_table() {
} }
action mk_expr_action(unsigned rbp) { return action(new expr_action_cell(rbp)); } action mk_expr_action(unsigned rbp) { return action(new expr_action_cell(rbp)); }
action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, action mk_exprs_action(name const & sep, expr const & rec, optional<expr> const & ini,
optional<name> const & terminator, bool right, unsigned rbp) { optional<name> const & terminator, bool right, unsigned rbp) {
if (get_free_var_range(rec) > 2) if (get_free_var_range(rec) > 2)
throw exception("invalid notation, the expression used to combine a sequence of expressions " throw exception("invalid notation, the expression used to combine a sequence of expressions "
"must not contain free variables with de Bruijn indices greater than 1"); "must not contain free variables with de Bruijn indices greater than 1");
expr new_rec = annotate_macro_subterms(rec); expr new_rec = annotate_macro_subterms(rec);
expr new_ini = annotate_macro_subterms(ini); optional<expr> new_ini = ini ? some_expr(annotate_macro_subterms(*ini)) : none_expr();
return action(new exprs_action_cell(sep, new_rec, new_ini, terminator, right, rbp)); return action(new exprs_action_cell(sep, new_rec, new_ini, terminator, right, rbp));
} }
action mk_scoped_expr_action(expr const & rec, unsigned rb, bool lambda) { action mk_scoped_expr_action(expr const & rec, unsigned rb, bool lambda) {
@ -269,7 +271,7 @@ action replace(action const & a, std::function<expr(expr const &)> const & f) {
case action_kind::Ext: case action_kind::LuaExt: case action_kind::Expr: case action_kind::Ext: case action_kind::LuaExt: case action_kind::Expr:
return a; return a;
case action_kind::Exprs: case action_kind::Exprs:
return mk_exprs_action(a.get_sep(), f(a.get_rec()), f(a.get_initial()), a.get_terminator(), return mk_exprs_action(a.get_sep(), f(a.get_rec()), a.get_initial() ? some_expr(f(*a.get_initial())) : none_expr(), a.get_terminator(),
a.is_fold_right(), a.rbp()); a.is_fold_right(), a.rbp());
case action_kind::ScopedExpr: case action_kind::ScopedExpr:
return mk_scoped_expr_action(f(a.get_rec()), a.rbp(), a.use_lambda_abstraction()); return mk_scoped_expr_action(f(a.get_rec()), a.rbp(), a.use_lambda_abstraction());
@ -498,7 +500,7 @@ static int mk_exprs_action(lua_State * L) {
if (nargs >= 4) terminator = to_optional_name(L, 4); if (nargs >= 4) terminator = to_optional_name(L, 4);
return push_notation_action(L, mk_exprs_action(to_name_ext(L, 1), return push_notation_action(L, mk_exprs_action(to_name_ext(L, 1),
to_expr(L, 2), to_expr(L, 2),
to_expr(L, 3), lua_isnil(L, 3) ? none_expr() : some_expr(to_expr(L, 3)),
terminator, terminator,
lua_toboolean(L, 5), lua_toboolean(L, 5),
rbp)); rbp));
@ -540,7 +542,7 @@ static int rec(lua_State * L) {
} }
static int initial(lua_State * L) { static int initial(lua_State * L) {
check_action(L, 1, { action_kind::Exprs }); check_action(L, 1, { action_kind::Exprs });
return push_expr(L, to_notation_action(L, 1).get_initial()); return push_optional_expr(L, to_notation_action(L, 1).get_initial());
} }
static int is_fold_right(lua_State * L) { static int is_fold_right(lua_State * L) {
check_action(L, 1, { action_kind::Exprs }); check_action(L, 1, { action_kind::Exprs });

View file

@ -60,7 +60,7 @@ public:
friend void initialize_parse_table(); friend void initialize_parse_table();
friend action mk_expr_action(unsigned rbp); friend action mk_expr_action(unsigned rbp);
friend action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, friend action mk_exprs_action(name const & sep, expr const & rec, optional<expr> const & ini,
optional<name> const & terminator, bool right, unsigned rbp); optional<name> const & terminator, bool right, unsigned rbp);
friend action mk_scoped_expr_action(expr const & rec, unsigned rbp, bool lambda); friend action mk_scoped_expr_action(expr const & rec, unsigned rbp, bool lambda);
friend action mk_ext_action_core(parse_fn const & fn); friend action mk_ext_action_core(parse_fn const & fn);
@ -71,7 +71,7 @@ public:
unsigned rbp() const; unsigned rbp() const;
name const & get_sep() const; name const & get_sep() const;
expr const & get_rec() const; expr const & get_rec() const;
expr const & get_initial() const; optional<expr> const & get_initial() const;
optional<name> const & get_terminator() const; optional<name> const & get_terminator() const;
bool is_fold_right() const; bool is_fold_right() const;
bool use_lambda_abstraction() const; bool use_lambda_abstraction() const;
@ -89,7 +89,7 @@ inline bool operator!=(action const & a1, action const & a2) { return !a1.is_equ
action mk_skip_action(); action mk_skip_action();
action mk_expr_action(unsigned rbp = 0); action mk_expr_action(unsigned rbp = 0);
action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, optional<name> const & terminator, bool right, action mk_exprs_action(name const & sep, expr const & rec, optional<expr> const & ini, optional<name> const & terminator, bool right,
unsigned rbp = 0); unsigned rbp = 0);
action mk_binder_action(); action mk_binder_action();
action mk_binders_action(); action mk_binders_action();

View file

@ -903,9 +903,15 @@ expr parser::parse_notation(parse_table t, expr * left) {
throw parser_error(sstream() << "invalid composite expression, '" << *terminator << "' expected", pos()); throw parser_error(sstream() << "invalid composite expression, '" << *terminator << "' expected", pos());
} }
} }
expr r = instantiate_rev(copy_with_new_pos(a.get_initial(), p), args.size(), args.data());
expr rec = copy_with_new_pos(a.get_rec(), p); expr rec = copy_with_new_pos(a.get_rec(), p);
expr r;
if (a.is_fold_right()) { if (a.is_fold_right()) {
if (a.get_initial()) {
r = instantiate_rev(copy_with_new_pos(*a.get_initial(), p), args.size(), args.data());
} else {
r = r_args.back();
r_args.pop_back();
}
unsigned i = r_args.size(); unsigned i = r_args.size();
while (i > 0) { while (i > 0) {
--i; --i;
@ -915,7 +921,14 @@ expr parser::parse_notation(parse_table t, expr * left) {
args.pop_back(); args.pop_back(); args.pop_back(); args.pop_back();
} }
} else { } else {
for (unsigned i = 0; i < r_args.size(); i++) { unsigned fidx = 0;
if (a.get_initial()) {
r = instantiate_rev(copy_with_new_pos(*a.get_initial(), p), args.size(), args.data());
} else {
r = r_args[0];
fidx++;
}
for (unsigned i = fidx; i < r_args.size(); i++) {
args.push_back(r_args[i]); args.push_back(r_args[i]);
args.push_back(r); args.push_back(r);
r = instantiate_rev(rec, args.size(), args.data()); r = instantiate_rev(rec, args.size(), args.data());

View file

@ -128,7 +128,12 @@ serializer & operator<<(serializer & s, action const & a) {
s << a.rbp(); s << a.rbp();
break; break;
case action_kind::Exprs: case action_kind::Exprs:
s << a.get_sep() << a.get_rec() << a.get_initial() << a.is_fold_right() << a.rbp(); s << a.get_sep() << a.get_rec();
if (a.get_initial())
s << true << *a.get_initial();
else
s << false;
s << a.is_fold_right() << a.rbp();
if (auto t = a.get_terminator()) { if (auto t = a.get_terminator()) {
s << true << *t; s << true << *t;
} else { } else {
@ -161,8 +166,12 @@ action read_action(deserializer & d) {
d >> rbp; d >> rbp;
return notation::mk_expr_action(rbp); return notation::mk_expr_action(rbp);
case action_kind::Exprs: { case action_kind::Exprs: {
name sep; expr rec, ini; bool is_fold_right; name sep; expr rec; optional<expr> ini; bool is_fold_right;
d >> sep >> rec >> ini >> is_fold_right >> rbp; d >> sep >> rec;
if (d.read_bool()) {
ini = read_expr(d);
}
d >> is_fold_right >> rbp;
optional<name> terminator; optional<name> terminator;
if (d.read_bool()) if (d.read_bool())
terminator = read_name(d); terminator = read_name(d);

View file

@ -808,7 +808,7 @@ auto pretty_fn::pp_notation(notation_entry const & entry, buffer<optional<expr>>
expr e = *args.back(); expr e = *args.back();
args.pop_back(); args.pop_back();
expr const & rec = a.get_rec(); expr const & rec = a.get_rec();
expr const & ini = a.get_initial(); optional<expr> const & ini = a.get_initial();
buffer<expr> rec_args; buffer<expr> rec_args;
bool matched_once = false; bool matched_once = false;
while (true) { while (true) {
@ -818,7 +818,6 @@ auto pretty_fn::pp_notation(notation_entry const & entry, buffer<optional<expr>>
args.pop_back(); args.pop_back();
break; break;
} }
matched_once = true;
optional<expr> new_e = args.back(); optional<expr> new_e = args.back();
args.pop_back(); args.pop_back();
optional<expr> rec_arg = args.back(); optional<expr> rec_arg = args.back();
@ -827,11 +826,16 @@ auto pretty_fn::pp_notation(notation_entry const & entry, buffer<optional<expr>>
return optional<result>(); return optional<result>();
rec_args.push_back(*rec_arg); rec_args.push_back(*rec_arg);
e = *new_e; e = *new_e;
matched_once = true;
} }
if (!matched_once) if (!matched_once)
return optional<result>(); return optional<result>();
if (!match(ini, e, args)) if (ini) {
if (!match(*ini, e, args))
return optional<result>(); return optional<result>();
} else {
rec_args.push_back(e);
}
if (!a.is_fold_right()) if (!a.is_fold_right())
std::reverse(rec_args.begin(), rec_args.end()); std::reverse(rec_args.begin(), rec_args.end());
unsigned curr_lbp = token_lbp; unsigned curr_lbp = token_lbp;
@ -847,9 +851,14 @@ auto pretty_fn::pp_notation(notation_entry const & entry, buffer<optional<expr>>
while (i > 0) { while (i > 0) {
--i; --i;
result arg_res = pp_notation_child(rec_args[i], curr_lbp, a.rbp()); result arg_res = pp_notation_child(rec_args[i], curr_lbp, a.rbp());
if (i == 0) if (i == 0) {
sep_fmt = format(tk); if (add_extra_space(tk))
curr = format(tk) + space() + arg_res.fmt() + curr;
else
curr = format(tk) + arg_res.fmt() + curr;
} else {
curr = sep_fmt + space() + arg_res.fmt() + curr; curr = sep_fmt + space() + arg_res.fmt() + curr;
}
if (i > 0 && add_extra_space(a.get_sep())) if (i > 0 && add_extra_space(a.get_sep()))
curr = space() + curr; curr = space() + curr;
curr_lbp = sep_lbp; curr_lbp = sep_lbp;

31
tests/lean/fold.lean Normal file
View file

@ -0,0 +1,31 @@
import data.prod data.num
variables a b c : num
context
notation `(` t:(foldr `,` (e r, prod.mk e r)) `)` := t
check (a, false, b, true, c)
set_option pp.notation false
check (a, false, b, true, c)
end
context
notation `(` t:(foldr `,` (e r, prod.mk r e)) `)` := t
check (a, false, b, true, c)
set_option pp.notation false
check (a, false, b, true, c)
end
context
notation `(` t:(foldl `,` (e r, prod.mk r e)) `)` := t
check (a, false, b, true, c)
set_option pp.notation false
check (a, false, b, true, c)
end
context
notation `(` t:(foldl `,` (e r, prod.mk e r)) `)` := t
check (a, false, b, true, c)
set_option pp.notation false
check (a, false, b, true, c)
end

View file

@ -0,0 +1,8 @@
(a, false, b, true, c) : prod num (prod Prop (prod num (prod Prop num)))
prod.mk a (prod.mk false (prod.mk b (prod.mk true c))) : prod num (prod Prop (prod num (prod Prop num)))
(a, false, b, true, c) : prod (prod (prod (prod num Prop) num) Prop) num
prod.mk (prod.mk (prod.mk (prod.mk c true) b) false) a : prod (prod (prod (prod num Prop) num) Prop) num
(a, false, b, true, c) : prod (prod (prod (prod num Prop) num) Prop) num
prod.mk (prod.mk (prod.mk (prod.mk a false) b) true) c : prod (prod (prod (prod num Prop) num) Prop) num
(a, false, b, true, c) : prod num (prod Prop (prod num (prod Prop num)))
prod.mk c (prod.mk true (prod.mk b (prod.mk false a))) : prod num (prod Prop (prod num (prod Prop num)))