diff --git a/src/bindings/lua/numerics.cpp b/src/bindings/lua/numerics.cpp index 866717a23..bd9c7fa08 100644 --- a/src/bindings/lua/numerics.cpp +++ b/src/bindings/lua/numerics.cpp @@ -7,6 +7,7 @@ Author: Leonardo de Moura #include #include #include "util/debug.h" +#include "util/sstream.h" #include "util/numerics/mpz.h" #include "util/numerics/mpq.h" #include "bindings/lua/util.h" @@ -17,24 +18,20 @@ DECL_UDATA(mpz) template static mpz const & to_mpz(lua_State * L) { static thread_local mpz arg; - if (lua_isuserdata(L, idx)) { - return *static_cast(luaL_checkudata(L, idx, mpz_mt)); - } else if (lua_isstring(L, idx)) { - arg = mpz(luaL_checkstring(L, idx)); - return arg; - } else { - arg = static_cast(luaL_checkinteger(L, 1)); - return arg; + switch (lua_type(L, idx)) { + case LUA_TNUMBER: arg = static_cast(lua_tointeger(L, idx)); return arg; + case LUA_TSTRING: arg = mpz(lua_tostring(L, idx)); return arg; + case LUA_TUSERDATA: return *static_cast(luaL_checkudata(L, idx, mpz_mt)); + default: throw exception(sstream() << "arg #" << idx << " must be a number, string or mpz"); } } mpz to_mpz_ext(lua_State * L, int idx) { - if (lua_isuserdata(L, idx)) { - return *static_cast(luaL_checkudata(L, idx, mpz_mt)); - } else if (lua_isstring(L, idx)) { - return mpz(luaL_checkstring(L, idx)); - } else { - return mpz(static_cast(luaL_checkinteger(L, 1))); + switch (lua_type(L, idx)) { + case LUA_TNUMBER: return mpz(static_cast(lua_tointeger(L, idx))); + case LUA_TSTRING: return mpz(lua_tostring(L, idx)); + case LUA_TUSERDATA: return *static_cast(luaL_checkudata(L, idx, mpz_mt)); + default: throw exception(sstream() << "arg #" << idx << " must be a number, string or mpz"); } } @@ -95,7 +92,7 @@ static const struct luaL_Reg mpz_m[] = { {"__eq", safe_function}, {"__lt", safe_function}, {"__add", safe_function}, - {"__add", safe_function}, + {"__sub", safe_function}, {"__mul", safe_function}, {"__div", safe_function}, {"__pow", safe_function}, @@ -116,31 +113,31 @@ DECL_UDATA(mpq) template static mpq const & to_mpq(lua_State * L) { static thread_local mpq arg; - if (lua_isuserdata(L, idx)) { + switch (lua_type(L, idx)) { + case LUA_TNUMBER: arg = lua_tonumber(L, idx); return arg; + case LUA_TSTRING: arg = mpq(lua_tostring(L, idx)); return arg; + case LUA_TUSERDATA: if (is_mpz(L, idx)) { arg = mpq(to_mpz(L)); + return arg; } else { return *static_cast(luaL_checkudata(L, idx, mpq_mt)); } - } else if (lua_isstring(L, idx)) { - arg = mpq(luaL_checkstring(L, idx)); - } else { - arg = static_cast(luaL_checkinteger(L, 1)); + default: throw exception(sstream() << "arg #" << idx << " must be a number, string, mpz or mpq"); } - return arg; } mpq to_mpq_ext(lua_State * L, int idx) { - if (lua_isuserdata(L, idx)) { + switch (lua_type(L, idx)) { + case LUA_TNUMBER: return mpq(lua_tonumber(L, idx)); + case LUA_TSTRING: return mpq(lua_tostring(L, idx)); + case LUA_TUSERDATA: if (is_mpz(L, idx)) { - return mpq(to_mpz(L, idx)); + return mpq(to_mpz<1>(L)); } else { return *static_cast(luaL_checkudata(L, idx, mpq_mt)); } - } else if (lua_isstring(L, idx)) { - return mpq(luaL_checkstring(L, idx)); - } else { - return mpq(static_cast(luaL_checkinteger(L, 1))); + default: throw exception(sstream() << "arg #" << idx << " must be a number, string, mpz or mpq"); } } @@ -201,7 +198,7 @@ static const struct luaL_Reg mpq_m[] = { {"__eq", safe_function}, {"__lt", safe_function}, {"__add", safe_function}, - {"__add", safe_function}, + {"__sub", safe_function}, {"__mul", safe_function}, {"__div", safe_function}, {"__pow", safe_function}, diff --git a/tests/lua/num1.lua b/tests/lua/num1.lua new file mode 100644 index 000000000..83119a0c6 --- /dev/null +++ b/tests/lua/num1.lua @@ -0,0 +1,47 @@ +assert(mpz(10) == mpz("10")) +assert(mpz(mpz(3)) == mpz("3")) +print(mpz(10) + mpz(3)) +assert(mpz(10) + mpz(3) == mpz(13)) +assert(mpz(10) + 2 == mpz(12)) +assert(3 + mpz(15) == mpz(18)) +assert(mpz(10) - mpz(3) == mpz(7)) +assert(mpz(10) - 2 == mpz(8)) +assert(3 - mpz(15) == -mpz(12)) +assert(- mpz(15) == mpz("-15")) +assert(- mpz(15) == -mpz("15")) +assert(mpz(10) * mpz(3) == mpz(30)) +assert(mpz(10) * 2 == mpz(20)) +assert(3 * mpz(15) == mpz(45)) +assert(mpz(3)^2 == mpz(9)) +local a = -2 +assert(not pcall(function() print(mpz(3)^a) end)) +assert(mpz(3) < mpz(5)) +assert(mpz(3) > mpz(1)) +assert(mpq(3) == mpq("3")) +assert(mpq(3) == mpq(mpq(3))) +assert(mpq(3) == mpq(mpz(3))) +assert(mpq(0.5) == mpq(1)/2) +assert(mpq(0.4) ~= mpq(1)/2) +assert(mpq(10) + mpq(3) == mpq(13)) +assert(mpq(10) + 2 == mpq(12)) +assert(3 + mpq(15) == mpq(18)) +assert(mpq(3) + mpz(15) == mpq(18)) +assert(mpq(10) - mpq(3) == mpq(7)) +assert(mpq(10) - 2 == mpq(8)) +assert(3 - mpq(15) == -mpq(12)) +assert(mpq(3) - mpz(15) == -mpq(12)) +assert(mpq(10) * mpq(3) == mpq(30)) +assert(mpq(10) * 2 == mpq(20)) +assert(3 * mpq(15) == mpq(45)) +assert(mpq(3) * mpz(15) == mpq(45)) +assert(mpq(3)^2 == mpq(9)) +print(mpq(0.5)^2) +assert(mpq(0.5)^2 == mpq(1)/4) +local a = -2 +assert(not pcall(function() print(mpq(3)^a) end)) +assert(mpq(10) / mpq(3) == mpq("10/3")) +assert(mpq(10) / 2 == mpq(5)) +assert(3 / mpq(15) == mpq(1)/5) +assert(mpq(3) / mpz(15) == mpq("1/5")) +assert(mpq(3) < mpq(5)) +assert(mpq(3) > mpq(1))