feat(frontends/lean): allow user to provide a terminator for 'foldr' and 'foldl'

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-07-30 15:04:44 -07:00
parent 5238da9ac7
commit 9637ceb86e
6 changed files with 81 additions and 27 deletions

View file

@ -240,8 +240,11 @@ static action parse_action(parser & p, name const & prev_token, unsigned default
locals.pop_back();
}
expr ini = parse_notation_expr(p, locals);
optional<name> terminator;
if (!p.curr_is_token(g_rparen))
terminator = parse_quoted_symbol_or_token(p, new_tokens);
p.check_token_next(g_rparen, "invalid fold notation argument, ')' expected");
return mk_exprs_action(sep, rec, ini, is_fold_right, prec ? *prec : 0);
return mk_exprs_action(sep, rec, ini, terminator, is_fold_right, prec ? *prec : 0);
} else if (p.curr_is_token_or_id(g_scoped)) {
p.next();
auto prec = parse_optional_precedence(p);

View file

@ -33,15 +33,15 @@ struct expr_action_cell : public action_cell {
};
struct exprs_action_cell : public expr_action_cell {
name m_token_sep;
expr m_rec;
expr m_ini;
bool m_fold_right;
exprs_action_cell(name const & sep, expr const & rec, expr const & ini, bool right,
unsigned rbp):
name m_token_sep;
expr m_rec;
expr m_ini;
optional<name> m_terminator;
bool m_fold_right;
exprs_action_cell(name const & sep, expr const & rec, expr const & ini,
optional<name> const & terminator, bool right, unsigned rbp):
expr_action_cell(action_kind::Exprs, rbp),
m_token_sep(sep), m_rec(rec), m_ini(ini), m_fold_right(right) {}
m_token_sep(sep), m_rec(rec), m_ini(ini), m_terminator(terminator), m_fold_right(right) {}
};
struct scoped_expr_action_cell : public expr_action_cell {
@ -95,6 +95,7 @@ ext_lua_action_cell * to_ext_lua_action(action_cell * c) {
}
unsigned action::rbp() const { return to_expr_action(m_ptr)->m_rbp; }
name const & action::get_sep() const { return to_exprs_action(m_ptr)->m_token_sep; }
optional<name> const & action::get_terminator() const { return to_exprs_action(m_ptr)->m_terminator; }
expr const & action::get_rec() const {
if (kind() == action_kind::ScopedExpr)
return to_scoped_expr_action(m_ptr)->m_rec;
@ -123,6 +124,7 @@ bool action::is_equal(action const & a) const {
rbp() == a.rbp() &&
get_rec() == a.get_rec() &&
get_initial() == a.get_initial() &&
get_terminator() == a.get_terminator() &&
is_fold_right() == a.is_fold_right();
case action_kind::ScopedExpr:
return
@ -140,8 +142,13 @@ void action::display(std::ostream & out) const {
case action_kind::LuaExt: out << "luaext"; break;
case action_kind::Expr: out << rbp(); break;
case action_kind::Exprs:
out << "(fold" << (is_fold_right() ? "r" : "l") << " "
<< rbp() << " " << get_rec() << " " << get_initial() << ")";
out << "(fold" << (is_fold_right() ? "r" : "l");
if (get_terminator())
out << "*";
out << " " << rbp() << " " << get_rec() << " " << get_initial();
if (get_terminator())
out << *get_terminator();
out << ")";
break;
case action_kind::ScopedExpr:
out << "(scoped " << rbp() << " " << get_rec() << ")";
@ -149,7 +156,6 @@ void action::display(std::ostream & out) const {
}
}
void action_cell::dealloc() {
switch (m_kind) {
case action_kind::Expr: delete(to_expr_action(this)); break;
@ -177,11 +183,11 @@ action mk_binders_action() {
return *r;
}
action mk_expr_action(unsigned rbp) { return action(new expr_action_cell(rbp)); }
action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, bool right, unsigned rbp) {
action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, optional<name> const & terminator, bool right, unsigned rbp) {
if (get_free_var_range(rec) > 2)
throw exception("invalid notation, the expression used to combine a sequence of expressions "
"must not contain free variables with de Bruijn indices greater than 1");
return action(new exprs_action_cell(sep, rec, ini, right, rbp));
return action(new exprs_action_cell(sep, rec, ini, terminator, right, rbp));
}
action mk_scoped_expr_action(expr const & rec, unsigned rb, bool lambda) {
return action(new scoped_expr_action_cell(rec, rb, lambda));
@ -195,7 +201,8 @@ action replace(action const & a, std::function<expr(expr const &)> const & f) {
case action_kind::Ext: case action_kind::LuaExt: case action_kind::Expr:
return a;
case action_kind::Exprs:
return mk_exprs_action(a.get_sep(), f(a.get_rec()), f(a.get_initial()), a.is_fold_right(), a.rbp());
return mk_exprs_action(a.get_sep(), f(a.get_rec()), f(a.get_initial()), a.get_terminator(),
a.is_fold_right(), a.rbp());
case action_kind::ScopedExpr:
return mk_scoped_expr_action(f(a.get_rec()), a.rbp(), a.use_lambda_abstraction());
}
@ -354,11 +361,15 @@ static int mk_expr_action(lua_State * L) {
}
static int mk_exprs_action(lua_State * L) {
int nargs = lua_gettop(L);
unsigned rbp = nargs <= 4 ? 0 : lua_tonumber(L, 5);
unsigned rbp = nargs <= 5 ? 0 : lua_tonumber(L, 6);
optional<name> terminator;
if (nargs >= 4) terminator = to_optional_name(L, 4);
return push_notation_action(L, mk_exprs_action(to_name_ext(L, 1),
to_expr(L, 2),
to_expr(L, 3),
lua_toboolean(L, 4), rbp));
to_expr(L, 2),
to_expr(L, 3),
terminator,
lua_toboolean(L, 5),
rbp));
}
static int mk_scoped_expr_action(lua_State * L) {
int nargs = lua_gettop(L);

View file

@ -59,7 +59,7 @@ public:
friend action mk_skip_action();
friend action mk_expr_action(unsigned rbp);
friend action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, bool right, unsigned rbp);
friend action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, optional<name> const & terminator, bool right, unsigned rbp);
friend action mk_binder_action();
friend action mk_binders_action();
friend action mk_scoped_expr_action(expr const & rec, unsigned rbp, bool lambda);
@ -71,6 +71,7 @@ public:
name const & get_sep() const;
expr const & get_rec() const;
expr const & get_initial() const;
optional<name> const & get_terminator() const;
bool is_fold_right() const;
bool use_lambda_abstraction() const;
parse_fn const & get_parse_fn() const;
@ -82,7 +83,7 @@ public:
action mk_skip_action();
action mk_expr_action(unsigned rbp = 0);
action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, bool right, unsigned rbp = 0);
action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, optional<name> const & terminator, bool right, unsigned rbp = 0);
action mk_binder_action();
action mk_binders_action();
action mk_scoped_expr_action(expr const & rec, unsigned rbp = 0, bool lambda = true);

View file

@ -708,10 +708,20 @@ expr parser::parse_notation(parse_table t, expr * left) {
break;
case notation::action_kind::Exprs: {
buffer<expr> r_args;
r_args.push_back(parse_expr(a.rbp()));
while (curr_is_token(a.get_sep())) {
next();
auto terminator = a.get_terminator();
if (!terminator || !curr_is_token(*terminator)) {
r_args.push_back(parse_expr(a.rbp()));
while (curr_is_token(a.get_sep())) {
next();
r_args.push_back(parse_expr(a.rbp()));
}
}
if (terminator) {
if (curr_is_token(*terminator)) {
next();
} else {
throw parser_error(sstream() << "invalid composite expression, '" << *terminator << "' expected", pos());
}
}
expr r = instantiate_rev(copy_with_new_pos(a.get_initial(), p), args.size(), args.data());
expr rec = copy_with_new_pos(a.get_rec(), p);
@ -955,11 +965,10 @@ unsigned parser::curr_lbp() const {
case scanner::token_kind::Keyword:
return get_token_info().precedence();
case scanner::token_kind::CommandKeyword: case scanner::token_kind::Eof:
case scanner::token_kind::ScriptBlock:
case scanner::token_kind::ScriptBlock: case scanner::token_kind::QuotedSymbol:
return 0;
case scanner::token_kind::Identifier: case scanner::token_kind::Numeral:
case scanner::token_kind::Decimal: case scanner::token_kind::String:
case scanner::token_kind::QuotedSymbol:
return std::numeric_limits<unsigned>::max();
}
lean_unreachable(); // LCOV_EXCL_LINE

View file

@ -79,6 +79,11 @@ serializer & operator<<(serializer & s, action const & a) {
break;
case action_kind::Exprs:
s << a.get_sep() << a.get_rec() << a.get_initial() << a.is_fold_right() << a.rbp();
if (auto t = a.get_terminator()) {
s << true << *t;
} else {
s << false;
}
break;
case action_kind::ScopedExpr:
s << a.get_rec() << a.rbp() << a.use_lambda_abstraction();
@ -108,7 +113,10 @@ action read_action(deserializer & d) {
case action_kind::Exprs: {
name sep; expr rec, ini; bool is_fold_right;
d >> sep >> rec >> ini >> is_fold_right >> rbp;
return notation::mk_exprs_action(sep, rec, ini, is_fold_right, rbp);
optional<name> terminator;
if (d.read_bool())
terminator = read_name(d);
return notation::mk_exprs_action(sep, rec, ini, terminator, is_fold_right, rbp);
}
case action_kind::ScopedExpr: {
expr rec; bool use_lambda_abstraction;

View file

@ -0,0 +1,22 @@
import bool
using bool
variable list : Type.{1}
variable nil : list
variable cons : bool → list → list
infixr `::`:65 := cons
notation `[` l:(foldr `,` (h t, cons h t) nil `]`) := l
check []
check [tt]
check [tt, ff]
check [tt, ff, ff]
check tt :: ff :: [tt, ff, ff]
check tt :: []
variables a b c : bool
check a :: b :: nil
check [a, b]
check [a, b, c]
check []