diff --git a/src/util/lua.cpp b/src/util/lua.cpp index 30da54929..80d920896 100644 --- a/src/util/lua.cpp +++ b/src/util/lua.cpp @@ -195,5 +195,13 @@ lua_migrate_fn get_migrate_fn(lua_State * L, int i) { } return nullptr; } + +void check_num_args(lua_State * L, int num) { + if (lua_gettop(L) != num) throw exception("incorrect number of arguments in function application"); +} + +void check_atmost_num_args(lua_State * L, int high) { + if (lua_gettop(L) > high) throw exception("too many arguments in function application"); +} } diff --git a/src/util/lua.h b/src/util/lua.h index a97081956..7eb60ed15 100644 --- a/src/util/lua.h +++ b/src/util/lua.h @@ -64,12 +64,9 @@ DECL_PUSH_CORE(T, T, T &&) #define DECL_GC(T) static int T ## _gc(lua_State * L) { static_cast(lua_touserdata(L, 1))->~T(); return 0; } -#define DECL_PRED(T) \ +#define DECL_PRED(T) \ bool is_ ## T(lua_State * L, int idx) { return testudata(L, idx, T ## _mt); } \ -static int T ## _pred(lua_State * L) { \ - lua_pushboolean(L, is_ ## T(L, 1)); \ - return 1; \ -} +static int T ## _pred(lua_State * L) { check_num_args(L, 1); return push_boolean(L, is_ ## T(L, 1)); } void throw_bad_arg_error(lua_State * L, int i, char const * expected_type); @@ -138,4 +135,13 @@ inline int push_integer(lua_State * L, lua_Integer v) { lua_pushinteger(L, v); r inline int push_number(lua_State * L, lua_Number v) { lua_pushnumber(L, v); return 1; } inline int push_nil(lua_State * L) { lua_pushnil(L); return 1; } // ======================================= + +// ======================================= +// Extra validation functions + +/** \brief Throw an exception if lua_gettop(L) != num */ +void check_num_args(lua_State * L, int num); +/** \brief Throw an exception if lua_gettop(L) > high */ +void check_atmost_num_args(lua_State * L, int high); +// ======================================= } diff --git a/src/util/name.cpp b/src/util/name.cpp index 1f6d96c9c..c9920f076 100644 --- a/src/util/name.cpp +++ b/src/util/name.cpp @@ -530,7 +530,7 @@ static int name_tostring(lua_State * L) { return push_string(L, to_name(L, 1).to static int name_eq(lua_State * L) { return push_boolean(L, to_name(L, 1) == to_name(L, 2)); } static int name_lt(lua_State * L) { return push_boolean(L, to_name(L, 1) < to_name(L, 2)); } static int name_hash(lua_State * L) { return push_integer(L, to_name(L, 1).hash()); } -#define NAME_PRED(P) static int name_ ## P(lua_State * L) { return push_boolean(L, to_name(L, 1).P()); } +#define NAME_PRED(P) static int name_ ## P(lua_State * L) { check_num_args(L, 1); return push_boolean(L, to_name(L, 1).P()); } NAME_PRED(is_atomic) NAME_PRED(is_anonymous) NAME_PRED(is_string) diff --git a/src/util/sexpr/sexpr.cpp b/src/util/sexpr/sexpr.cpp index d906481b2..af4c2b225 100644 --- a/src/util/sexpr/sexpr.cpp +++ b/src/util/sexpr/sexpr.cpp @@ -408,7 +408,7 @@ static int mk_sexpr(lua_State * L) { static int sexpr_eq(lua_State * L) { return push_boolean(L, to_sexpr(L, 1) == to_sexpr(L, 2)); } static int sexpr_lt(lua_State * L) { return push_boolean(L, to_sexpr(L, 1) < to_sexpr(L, 2)); } -#define SEXPR_PRED(P) static int sexpr_ ## P(lua_State * L) { return push_boolean(L, P(to_sexpr(L, 1))); } +#define SEXPR_PRED(P) static int sexpr_ ## P(lua_State * L) { check_num_args(L, 1); return push_boolean(L, P(to_sexpr(L, 1))); } SEXPR_PRED(is_nil) SEXPR_PRED(is_cons)