diff --git a/src/frontends/lean/parse_table.cpp b/src/frontends/lean/parse_table.cpp index d3ee22838..5877b9699 100644 --- a/src/frontends/lean/parse_table.cpp +++ b/src/frontends/lean/parse_table.cpp @@ -45,10 +45,11 @@ struct exprs_action_cell : public expr_action_cell { struct scoped_expr_action_cell : public expr_action_cell { expr m_rec; - - scoped_expr_action_cell(expr const & rec, unsigned rb): + bool m_lambda; + scoped_expr_action_cell(expr const & rec, unsigned rb, bool lambda): expr_action_cell(action_kind::ScopedExpr, rb), - m_rec(rec) {} + m_rec(rec), + m_lambda(lambda) {} }; struct ext_action_cell : public action_cell { @@ -89,6 +90,7 @@ expr const & action::get_rec() const { else return to_exprs_action(m_ptr)->m_rec; } +bool action::use_lambda_abstraction() const { return to_scoped_expr_action(m_ptr)->m_lambda; } expr const & action::get_initial() const { return to_exprs_action(m_ptr)->m_ini; } bool action::is_fold_right() const { return to_exprs_action(m_ptr)->m_fold_right; } parse_fn const & action::get_parse_fn() const { return to_ext_action(m_ptr)->m_parse_fn; } @@ -140,8 +142,8 @@ action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, boo "must not contain free variables with de Bruijn indices greater than 1"); return action(new exprs_action_cell(sep, rec, ini, right, rbp)); } -action mk_scoped_expr_action(expr const & rec, unsigned rb) { - return action(new scoped_expr_action_cell(rec, rb)); +action mk_scoped_expr_action(expr const & rec, unsigned rb, bool lambda) { + return action(new scoped_expr_action_cell(rec, rb, lambda)); } action mk_ext_parse_action(parse_fn const & fn) { return action(new ext_action_cell(fn)); } @@ -283,7 +285,8 @@ static int mk_exprs_action(lua_State * L) { static int mk_scoped_expr_action(lua_State * L) { int nargs = lua_gettop(L); unsigned rbp = nargs <= 1 ? 0 : lua_tonumber(L, 2); - return push_notation_action(L, mk_scoped_expr_action(to_expr(L, 1), rbp)); + bool lambda = (nargs <= 2) || lua_toboolean(L, 3); + return push_notation_action(L, mk_scoped_expr_action(to_expr(L, 1), rbp, lambda)); } static int is_compatible(lua_State * L) { return push_boolean(L, to_notation_action(L, 1).is_compatible(to_notation_action(L, 2))); @@ -314,17 +317,22 @@ static int is_fold_right(lua_State * L) { check_action(L, 1, { action_kind::Exprs }); return push_boolean(L, to_notation_action(L, 1).is_fold_right()); } +static int use_lambda_abstraction(lua_State * L) { + check_action(L, 1, { action_kind::ScopedExpr }); + return push_boolean(L, to_notation_action(L, 1).use_lambda_abstraction()); +} static const struct luaL_Reg notation_action_m[] = { - {"__gc", notation_action_gc}, - {"is_compatible", safe_function}, - {"kind", safe_function}, - {"rbp", safe_function}, - {"sep", safe_function}, - {"separator", safe_function}, - {"rec", safe_function}, - {"initial", safe_function}, - {"is_fold_right", safe_function}, + {"__gc", notation_action_gc}, + {"is_compatible", safe_function}, + {"kind", safe_function}, + {"rbp", safe_function}, + {"sep", safe_function}, + {"separator", safe_function}, + {"rec", safe_function}, + {"initial", safe_function}, + {"is_fold_right", safe_function}, + {"use_lambda_abstraction", safe_function}, {0, 0} }; diff --git a/src/frontends/lean/parse_table.h b/src/frontends/lean/parse_table.h index bc00227ba..682d4d10c 100644 --- a/src/frontends/lean/parse_table.h +++ b/src/frontends/lean/parse_table.h @@ -60,7 +60,7 @@ public: friend action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, bool right, unsigned rbp); friend action mk_binder_action(); friend action mk_binders_action(); - friend action mk_scoped_expr_action(expr const & rec, unsigned rbp); + friend action mk_scoped_expr_action(expr const & rec, unsigned rbp, bool lambda); friend action mk_ext_parse_action(parse_fn const & fn); action_kind kind() const; @@ -69,6 +69,7 @@ public: expr const & get_rec() const; expr const & get_initial() const; bool is_fold_right() const; + bool use_lambda_abstraction() const; parse_fn const & get_parse_fn() const; bool is_compatible(action const & a) const; @@ -79,7 +80,7 @@ action mk_expr_action(unsigned rbp = 0); action mk_exprs_action(name const & sep, expr const & rec, expr const & ini, bool right, unsigned rbp = 0); action mk_binder_action(); action mk_binders_action(); -action mk_scoped_expr_action(expr const & rec, unsigned rbp = 0); +action mk_scoped_expr_action(expr const & rec, unsigned rbp = 0, bool lambda = true); action mk_proc_action(parse_fn const & fn); class transition { diff --git a/tests/lua/parse_table.lua b/tests/lua/parse_table.lua index c5fde713d..676cccd97 100644 --- a/tests/lua/parse_table.lua +++ b/tests/lua/parse_table.lua @@ -109,3 +109,11 @@ assert(parse_table_size(p) == 4) local p3 = parse_table() check_error(function() p:merge(p3) end) + +local a = scoped_expr_notation_action(Var(0), 10) +assert(a:use_lambda_abstraction()) +local a = scoped_expr_notation_action(Var(0), 10, false) +assert(not a:use_lambda_abstraction()) +local a = scoped_expr_notation_action(Var(0), 10, true) +assert(a:use_lambda_abstraction()) +assert(a:rbp() == 10)