diff --git a/src/bindings/lua/context.cpp b/src/bindings/lua/context.cpp index d42191e15..051064930 100644 --- a/src/bindings/lua/context.cpp +++ b/src/bindings/lua/context.cpp @@ -39,9 +39,9 @@ static int context_entry_gc(lua_State * L) { static int mk_context_entry(lua_State * L) { int nargs = lua_gettop(L); if (nargs == 2) - return push_context_entry(L, context_entry(to_name_ext(L, 1), to_expr(L, 2))); + return push_context_entry(L, context_entry(to_name_ext(L, 1), to_nonnull_expr(L, 2))); else - return push_context_entry(L, context_entry(to_name_ext(L, 1), to_expr(L, 2), to_expr(L, 3))); + return push_context_entry(L, context_entry(to_name_ext(L, 1), to_nonnull_expr(L, 2), to_nonnull_expr(L, 3))); } static int context_entry_pred(lua_State * L) { @@ -100,9 +100,9 @@ static int mk_context(lua_State * L) { context_entry & e = to_context_entry(L, 2); return push_context(L, context(to_context(L, 1), e.get_name(), e.get_domain(), e.get_body())); } else if (nargs == 3) { - return push_context(L, context(to_context(L, 1), to_name_ext(L, 2), to_expr(L, 3))); + return push_context(L, context(to_context(L, 1), to_name_ext(L, 2), to_nonnull_expr(L, 3))); } else { - return push_context(L, context(to_context(L, 1), to_name_ext(L, 2), to_expr(L, 3), to_expr(L, 4))); + return push_context(L, context(to_context(L, 1), to_name_ext(L, 2), to_nonnull_expr(L, 3), to_nonnull_expr(L, 4))); } } diff --git a/src/bindings/lua/expr.cpp b/src/bindings/lua/expr.cpp index 4b4ebb24f..281441732 100644 --- a/src/bindings/lua/expr.cpp +++ b/src/bindings/lua/expr.cpp @@ -29,6 +29,13 @@ expr & to_expr(lua_State * L, int idx) { return *static_cast(luaL_checkudata(L, idx, expr_mt)); } +expr & to_nonnull_expr(lua_State * L, int idx) { + expr & r = to_expr(L, idx); + if (!r) + luaL_error(L, "non null Lean expression expected"); + return r; +} + int push_expr(lua_State * L, expr const & e) { void * mem = lua_newuserdata(L, sizeof(expr)); new (mem) expr(e); @@ -79,35 +86,35 @@ static int expr_mk_app(lua_State * L) { luaL_error(L, "application must have at least two arguments"); buffer args; for (int i = 1; i <= nargs; i++) - args.push_back(to_expr(L, i)); + args.push_back(to_nonnull_expr(L, i)); return push_expr(L, mk_app(args)); } static int expr_mk_eq(lua_State * L) { - return push_expr(L, mk_eq(to_expr(L, 1), to_expr(L, 2))); + return push_expr(L, mk_eq(to_nonnull_expr(L, 1), to_nonnull_expr(L, 2))); } static int expr_mk_lambda(lua_State * L) { - return push_expr(L, mk_lambda(to_name_ext(L, 1), to_expr(L, 2), to_expr(L, 3))); + return push_expr(L, mk_lambda(to_name_ext(L, 1), to_nonnull_expr(L, 2), to_nonnull_expr(L, 3))); } static int expr_mk_pi(lua_State * L) { - return push_expr(L, mk_lambda(to_name_ext(L, 1), to_expr(L, 2), to_expr(L, 3))); + return push_expr(L, mk_lambda(to_name_ext(L, 1), to_nonnull_expr(L, 2), to_nonnull_expr(L, 3))); } static int expr_mk_let(lua_State * L) { int nargs = lua_gettop(L); if (nargs == 3) - return push_expr(L, mk_let(to_name_ext(L, 1), expr(), to_expr(L, 2), to_expr(L, 3))); + return push_expr(L, mk_let(to_name_ext(L, 1), expr(), to_nonnull_expr(L, 2), to_nonnull_expr(L, 3))); else - return push_expr(L, mk_let(to_name_ext(L, 1), to_expr(L, 2), to_expr(L, 3), to_expr(L, 4))); + return push_expr(L, mk_let(to_name_ext(L, 1), to_nonnull_expr(L, 2), to_nonnull_expr(L, 3), to_nonnull_expr(L, 4))); } static expr get_expr_from_table(lua_State * L, int t, int i) { lua_pushvalue(L, t); // push table to the top lua_pushinteger(L, i); lua_gettable(L, -2); - expr r = to_expr(L, -1); + expr r = to_nonnull_expr(L, -1); lua_pop(L, 2); // remove table and value return r; } @@ -140,7 +147,7 @@ int expr_abst(lua_State * L, char const * fname) { int len = objlen(L, 1); if (len == 0) luaL_error(L, "Lean %s expects arg #1 to be non-empty table", fname); - expr r = to_expr(L, 2); + expr r = to_nonnull_expr(L, 2); for (int i = len; i >= 1; i--) { auto p = get_expr_pair_from_table(L, 1, i); r = F1(p.first, p.second, r); @@ -149,12 +156,12 @@ int expr_abst(lua_State * L, char const * fname) { } else { if (nargs % 2 == 0) luaL_error(L, "Lean %s must have an odd number of arguments", fname); - expr r = to_expr(L, nargs); + expr r = to_nonnull_expr(L, nargs); for (int i = nargs - 1; i >= 1; i-=2) { if (is_expr(L, i - 1)) - r = F1(to_expr(L, i - 1), to_expr(L, i), r); + r = F1(to_nonnull_expr(L, i - 1), to_nonnull_expr(L, i), r); else - r = F2(to_name_ext(L, i - 1), to_expr(L, i), r); + r = F2(to_name_ext(L, i - 1), to_nonnull_expr(L, i), r); } return push_expr(L, r); } diff --git a/src/bindings/lua/expr.h b/src/bindings/lua/expr.h index 23e281df8..7bd8beb78 100644 --- a/src/bindings/lua/expr.h +++ b/src/bindings/lua/expr.h @@ -11,5 +11,6 @@ class expr; void open_expr(lua_State * L); bool is_expr(lua_State * L, int idx); expr & to_expr(lua_State * L, int idx); +expr & to_nonnull_expr(lua_State * L, int idx); int push_expr(lua_State * L, expr const & o); } diff --git a/src/bindings/lua/local_context.cpp b/src/bindings/lua/local_context.cpp index 688a9f942..08d54a8f5 100644 --- a/src/bindings/lua/local_context.cpp +++ b/src/bindings/lua/local_context.cpp @@ -45,7 +45,7 @@ static int local_entry_mk_lift(lua_State * L) { } static int local_entry_mk_inst(lua_State * L) { - return push_local_entry(L, mk_inst(luaL_checkinteger(L, 1), to_expr(L, 2))); + return push_local_entry(L, mk_inst(luaL_checkinteger(L, 1), to_nonnull_expr(L, 2))); } static int local_entry_pred(lua_State * L) {