diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index d86197e58..299fc73a8 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -286,7 +286,7 @@ static std::pair get_expr_pair_from_table(lua_State * L, int t, int lua_pushinteger(L, i); lua_gettable(L, -2); // now table {ai, bi} is on the top if (!lua_istable(L, -1) || objlen(L, -1) != 2) - throw exception("arg #1 must be of the form '{{expr, expr}, ...}'"); + throw exception(sstream() << "arg #" << t << " must be of the form '{{expr, expr}, ...}'"); expr ai = get_expr_from_table(L, -1, 1); expr bi = get_expr_from_table(L, -1, 2); lua_pop(L, 2); // pop table {ai, bi} and t from stack @@ -334,6 +334,51 @@ static int expr_mk_metavar(lua_State * L) { return push_expr(L, mk_metavar(to_na static int expr_mk_local(lua_State * L) { return push_expr(L, mk_local(to_name_ext(L, 1), to_expr(L, 2))); } static int expr_get_kind(lua_State * L) { return push_integer(L, static_cast(to_expr(L, 1).kind())); } +// t is a table of pairs {{a1, b1, c1}, ..., {ak, bk, ck}} +// ai, bi and ci are expressions +static std::tuple get_expr_triple_from_table(lua_State * L, int t, int i) { + lua_pushvalue(L, t); // push table on the top + lua_pushinteger(L, i); + lua_gettable(L, -2); // now table {ai, bi, ci} is on the top + if (!lua_istable(L, -1) || objlen(L, -1) != 3) + throw exception(sstream() << "arg #" << t << " must be of the form '{{expr, expr, expr}, ...}'"); + expr ai = get_expr_from_table(L, -1, 1); + expr bi = get_expr_from_table(L, -1, 2); + expr ci = get_expr_from_table(L, -1, 3); + lua_pop(L, 2); // pop table {ai, bi, ci} and t from stack + return std::make_tuple(ai, bi, ci); +} + +static int expr_let(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs < 2) + throw exception("function must have at least 2 arguments"); + if (nargs == 2) { + if (!lua_istable(L, 1)) + throw exception("function expects arg #1 to be of the form '{{expr, expr, expr}, ...}'"); + int len = objlen(L, 1); + if (len == 0) + throw exception("function expects arg #1 to be a non-empty table"); + expr r = to_expr(L, 2); + for (int i = len; i >= 1; i--) { + auto p = get_expr_triple_from_table(L, 1, i); + r = Let(std::get<0>(p), std::get<1>(p), std::get<2>(p), r); + } + return push_expr(L, r); + } else { + if ((nargs - 1) % 3 != 0) + throw exception("function must have 3*n + 1 arguments"); + expr r = to_expr(L, nargs); + for (int i = nargs - 1; i >= 1; i-=3) { + if (is_expr(L, i - 2)) + r = Let(to_expr(L, i - 2), to_expr(L, i - 1), to_expr(L, i), r); + else + r = Let(to_name_ext(L, i - 2), to_expr(L, i - 1), to_expr(L, i), r); + } + return push_expr(L, r); + } +} + #define EXPR_PRED(P) static int expr_ ## P(lua_State * L) { return push_boolean(L, P(to_expr(L, 1))); } EXPR_PRED(is_constant) @@ -562,7 +607,7 @@ static void open_expr(lua_State * L) { SET_GLOBAL_FUN(expr_fun, "fun"); SET_GLOBAL_FUN(expr_fun, "Fun"); SET_GLOBAL_FUN(expr_pi, "Pi"); - SET_GLOBAL_FUN(expr_mk_let, "Let"); + SET_GLOBAL_FUN(expr_let, "Let"); SET_GLOBAL_FUN(expr_mk_sort, "mk_sort"); SET_GLOBAL_FUN(expr_mk_metavar, "mk_metavar"); SET_GLOBAL_FUN(expr_mk_local, "mk_local");