diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index 9f3d2e788..e0c182e43 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -17,6 +17,7 @@ Author: Leonardo de Moura #include #include #include +#include "util/luaref.h" #include "util/scoped_map.h" #include "util/exception.h" #include "util/sstream.h" @@ -128,6 +129,11 @@ static unsigned g_level_cup_prec = 5; // are syntax sugar for (Pi (_ : A), B) static name g_unused = name::mk_internal_unique_name(); +enum class macro_arg_kind { Expr, Exprs, Bindings, Id, Comma, Assign }; +typedef std::pair, luaref> macro; +typedef name_map macros; +macros & get_macros(lua_State * L); + /** \brief Actual implementation for the parser functional object @@ -150,6 +156,7 @@ class parser::imp { scanner m_scanner; frontend_elaborator m_elaborator; type_inferer m_type_inferer; + macros const * m_macros; scanner::token m_curr; bool m_use_exceptions; bool m_interactive; @@ -839,6 +846,97 @@ class parser::imp { } } + bool is_curr_begin_expr() const { + switch (curr()) { + case scanner::token::RightParen: + case scanner::token::RightCurlyBracket: + case scanner::token::Colon: + case scanner::token::Comma: + case scanner::token::Period: + case scanner::token::CommandId: + case scanner::token::Eof: + case scanner::token::ScriptBlock: + return false; + default: + return true; + } + } + + /** + \brief Parse a macro implemented in Lua + */ + expr parse_macro(lua_State * L, list const & args, unsigned num_args, pos_info const & p) { + if (args) { + auto k = head(args); + switch (k) { + case macro_arg_kind::Expr: + push_expr(L, parse_expr()); + return parse_macro(L, tail(args), num_args + 1, p); + case macro_arg_kind::Exprs: { + lua_newtable(L); + int i = 1; + while (is_curr_begin_expr()) { + push_expr(L, parse_expr(g_app_precedence)); + lua_rawseti(L, -2, i); + i = i + 1; + } + return parse_macro(L, tail(args), num_args + 1, p); + } + case macro_arg_kind::Bindings: { + mk_scope scope(*this); + bindings_buffer bindings; + parse_expr_bindings(bindings); + lua_newtable(L); + int i = 1; + for (auto const & b : bindings) { + lua_newtable(L); + push_name(L, std::get<1>(b)); + lua_rawseti(L, -2, 1); + push_expr(L, std::get<2>(b)); + lua_rawseti(L, -2, 2); + lua_rawseti(L, -2, i); + i = i + 1; + } + return parse_macro(L, tail(args), num_args + 1, p); + } + case macro_arg_kind::Comma: + check_comma_next("invalid macro, ',' expected"); + return parse_macro(L, tail(args), num_args, p); + case macro_arg_kind::Assign: + check_comma_next("invalid macro, ':=' expected"); + return parse_macro(L, tail(args), num_args, p); + case macro_arg_kind::Id: + push_name(L, curr_name()); + next(); + return parse_macro(L, tail(args), num_args + 1, p); + } + lean_unreachable(); + } else { + // All arguments have been parsed, then call Lua procedure proc. + m_last_script_pos = p; + pcall(L, num_args, 1, 0); + if (is_expr(L, -1)) { + expr r = to_expr(L, -1); + lua_pop(L, 1); + return save(r, p); + } else { + lua_pop(L, 1); + throw parser_error("failed to execute macro", p); + } + } + } + + expr parse_macro(name const & id, pos_info const & p) { + lean_assert(m_macros && m_macros->find(id) != m_macros->end()); + auto m = m_macros->find(id)->second; + list args = m.first; + luaref proc = m.second; + return m_script_state->apply([&](lua_State * L) { + proc.push(); + return parse_macro(L, args, 0, p); + }); + } + /** \brief Parse an identifier that has a "null denotation" (See paper: "Top down operator precedence"). A nud identifier is a @@ -854,6 +952,8 @@ class parser::imp { auto it = m_local_decls.find(id); if (it != m_local_decls.end()) { return save(mk_var(m_num_local_decls - it->second - 1), p); + } else if (m_macros && m_macros->find(id) != m_macros->end()) { + return parse_macro(id, p); } else { operator_info op = find_nud(m_env, id); if (op) { @@ -2226,6 +2326,13 @@ public: m_use_exceptions(use_exceptions), m_interactive(interactive) { m_script_state = S; + if (m_script_state) { + m_script_state->apply([&](lua_State * L) { + m_macros = &get_macros(L); + }); + } else { + m_macros = nullptr; + } updt_options(); m_found_errors = false; m_num_local_decls = 0; @@ -2357,4 +2464,61 @@ expr parse_expr(environment const & env, io_state & ios, std::istream & in, scri ios = p.get_io_state(); return r; } + +static char g_parser_macros_key; +DECL_UDATA(macros) + +void init_macros(lua_State * L) { + lua_pushlightuserdata(L, static_cast(&g_parser_macros_key)); + push_macros(L, macros()); + lua_settable(L, LUA_REGISTRYINDEX); +} + +macros & get_macros(lua_State * L) { + lua_pushlightuserdata(L, static_cast(&g_parser_macros_key)); + lua_gettable(L, LUA_REGISTRYINDEX); + lean_assert(is_macros(L, -1)); + macros & r = to_macros(L, -1); + lua_pop(L, 1); + return r; +} + +int mk_macro(lua_State * L) { + name macro_name = to_name_ext(L, 1); + luaL_checktype(L, 3, LUA_TFUNCTION); // user-fun + buffer arg_kind_buffer; + int n = objlen(L, 2); + for (int i = 1; i <= n; i++) { + lua_rawgeti(L, 2, i); + arg_kind_buffer.push_back(static_cast(luaL_checkinteger(L, -1))); + lua_pop(L, 1); + } + list arg_kinds = to_list(arg_kind_buffer.begin(), arg_kind_buffer.end()); + get_macros(L).insert(mk_pair(macro_name, macro(arg_kinds, luaref(L, 3)))); + return 0; +} + +static const struct luaL_Reg macros_m[] = { + {"__gc", macros_gc}, // never throws + {0, 0} +}; + +void open_macros(lua_State * L) { + luaL_newmetatable(L, macros_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, macros_m, 0); + SET_GLOBAL_FUN(macros_pred, "is_macros"); + init_macros(L); + SET_GLOBAL_FUN(mk_macro, "macro"); + + lua_newtable(L); + SET_ENUM("Expr", macro_arg_kind::Expr); + SET_ENUM("Exprs", macro_arg_kind::Exprs); + SET_ENUM("Bindings", macro_arg_kind::Bindings); + SET_ENUM("Id", macro_arg_kind::Id); + SET_ENUM("Comma", macro_arg_kind::Comma); + SET_ENUM("Assign", macro_arg_kind::Assign); + lua_setglobal(L, "macro_arg"); +} } diff --git a/src/frontends/lean/parser.h b/src/frontends/lean/parser.h index f60920801..97037eb7b 100644 --- a/src/frontends/lean/parser.h +++ b/src/frontends/lean/parser.h @@ -6,6 +6,7 @@ Author: Leonardo de Moura */ #pragma once #include +#include "util/lua.h" #include "kernel/environment.h" #include "library/io_state.h" @@ -45,4 +46,6 @@ public: bool parse_commands(environment const & env, io_state & st, std::istream & in, script_state * S = nullptr, bool use_exceptions = true, bool interactive = false); expr parse_expr(environment const & env, io_state & st, std::istream & in, script_state * S = nullptr, bool use_exceptions = true); + +void open_macros(lua_State * L); } diff --git a/src/frontends/lean/register_module.cpp b/src/frontends/lean/register_module.cpp index 54ea199bc..270718120 100644 --- a/src/frontends/lean/register_module.cpp +++ b/src/frontends/lean/register_module.cpp @@ -136,6 +136,7 @@ static int mk_lean_formatter(lua_State * L) { } void open_frontend_lean(lua_State * L) { + open_macros(L); SET_GLOBAL_FUN(mk_environment, "environment"); SET_GLOBAL_FUN(mk_lean_formatter, "lean_formatter"); SET_GLOBAL_FUN(parse_lean_expr, "parse_lean"); diff --git a/tests/lean/lua18.lean b/tests/lean/lua18.lean new file mode 100644 index 000000000..e11aa271c --- /dev/null +++ b/tests/lean/lua18.lean @@ -0,0 +1,24 @@ +(** +macro("MyMacro", { macro_arg.Expr, macro_arg.Comma, macro_arg.Expr }, + function (e1, e2) + return Const({"Int", "add"})(e1, e2) + end) +macro("Sum", { macro_arg.Exprs }, + function (es) + if #es == 0 then + return iVal(0) + end + local r = es[1] + local add = Const({"Int", "add"}) + for i = 2, #es do + r = add(r, es[i]) + end + return r + end) +**) + +Show (MyMacro 10, 20) + 20 +Show (Sum) +Show Sum 10 20 30 40 +Show fun x, Sum x 10 x 20 +Eval (fun x, Sum x 10 x 20) 100 diff --git a/tests/lean/lua18.lean.expected.out b/tests/lean/lua18.lean.expected.out new file mode 100644 index 000000000..5c9da949e --- /dev/null +++ b/tests/lean/lua18.lean.expected.out @@ -0,0 +1,7 @@ + Set: pp::colors + Set: pp::unicode +10 + 20 + 20 +0 +10 + 20 + 30 + 40 +λ x : ℤ, x + 10 + x + 20 +230