diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index fb0a943f1..ded166101 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -202,6 +202,13 @@ expr & to_binder(lua_State * L, int idx) { return r; } +expr & to_macro_app(lua_State * L, int idx) { + expr & r = to_expr(L, idx); + if (!is_macro(r)) + throw exception(sstream() << "arg #" << idx << " must be a macro application"); + return r; +} + static int expr_tostring(lua_State * L) { std::ostringstream out; formatter fmt = get_global_formatter(L); @@ -334,8 +341,7 @@ static int expr_fields(lua_State * L) { case expr_kind::Sort: return push_level(L, sort_level(e)); case expr_kind::Macro: - // TODO(Leo) - return 0; + return push_macro_definition(L, macro_def(e)); case expr_kind::App: push_expr(L, app_fn(e)); push_expr(L, app_arg(e)); return 2; case expr_kind::Lambda: @@ -441,6 +447,15 @@ static int expr_hash(lua_State * L) { return pushinteger(L, to_expr(L, 1).hash() static int expr_depth(lua_State * L) { return pushinteger(L, get_depth(to_expr(L, 1))); } static int expr_is_lt(lua_State * L) { return pushboolean(L, is_lt(to_expr(L, 1), to_expr(L, 2), false)); } +static int expr_mk_macro(lua_State * L) { + buffer args; + copy_lua_array(L, 2, args); + return push_expr(L, mk_macro(to_macro_definition(L, 1), args.size(), args.data())); +} + +static int macro_def(lua_State * L) { return push_macro_definition(L, macro_def(to_macro_app(L, 1))); } +static int macro_num_args(lua_State * L) { return pushinteger(L, macro_num_args(to_macro_app(L, 1))); } +static int macro_arg(lua_State * L) { return push_expr(L, macro_arg(to_macro_app(L, 1), pushinteger(L, 2))); } static const struct luaL_Reg expr_m[] = { {"__gc", expr_gc}, // never throws @@ -475,6 +490,9 @@ static const struct luaL_Reg expr_m[] = { {"binder_domain", safe_function}, {"binder_body", safe_function}, {"binder_info", safe_function}, + {"macro_def", safe_function}, + {"macro_num_args", safe_function}, + {"macro_arg", safe_function}, {"for_each", safe_function}, {"has_free_var", safe_function}, {"lift_free_vars", safe_function}, @@ -508,6 +526,7 @@ static void open_expr(lua_State * L) { SET_GLOBAL_FUN(expr_mk_pi, "mk_pi"); SET_GLOBAL_FUN(expr_mk_arrow, "mk_arrow"); SET_GLOBAL_FUN(expr_mk_let, "mk_let"); + SET_GLOBAL_FUN(expr_mk_macro, "mk_macro"); SET_GLOBAL_FUN(expr_fun, "fun"); SET_GLOBAL_FUN(expr_fun, "Fun"); SET_GLOBAL_FUN(expr_pi, "Pi"); @@ -536,6 +555,39 @@ static void open_expr(lua_State * L) { SET_ENUM("Macro", expr_kind::Macro); lua_setglobal(L, "expr_kind"); } +// macro_definition +DECL_UDATA(macro_definition) + +int macro_get_name(lua_State * L) { return push_name(L, to_macro_definition(L, 1).get_name()); } +int macro_trust_level(lua_State * L) { return pushinteger(L, to_macro_definition(L, 1).trust_level()); } +int macro_eq(lua_State * L) { return pushboolean(L, to_macro_definition(L, 1) == to_macro_definition(L, 2)); } +int macro_hash(lua_State * L) { return pushinteger(L, to_macro_definition(L, 1).hash()); } +static int macro_tostring(lua_State * L) { + std::ostringstream out; + formatter fmt = get_global_formatter(L); + options opts = get_global_options(L); + out << mk_pair(to_macro_definition(L, 1).pp(fmt, opts), opts); + return pushstring(L, out.str().c_str()); +} + +static const struct luaL_Reg macro_definition_m[] = { + {"__gc", macro_definition_gc}, // never throws + {"__tostring", safe_function}, + {"__eq", safe_function}, + {"hash", safe_function}, + {"trust_level", safe_function}, + {"name", safe_function}, + {0, 0} +}; + +void open_macro_definition(lua_State * L) { + luaL_newmetatable(L, macro_definition_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, macro_definition_m, 0); + + SET_GLOBAL_FUN(macro_definition_pred, "is_macro_definition"); +} // Formatter DECL_UDATA(formatter) @@ -1051,6 +1103,7 @@ void open_kernel_module(lua_State * L) { open_binder_info(L); open_expr(L); open_list_expr(L); + open_macro_definition(L); open_formatter(L); open_environment(L); open_io_state(L); diff --git a/src/library/kernel_bindings.h b/src/library/kernel_bindings.h index 2f4346987..90b820c0b 100644 --- a/src/library/kernel_bindings.h +++ b/src/library/kernel_bindings.h @@ -14,6 +14,7 @@ UDATA_DEFS(level) UDATA_DEFS(expr) UDATA_DEFS(formatter) UDATA_DEFS(definition) +UDATA_DEFS(macro_definition) UDATA_DEFS(environment) UDATA_DEFS(substitution) UDATA_DEFS(justification)