fix(lua/numerics): bug in bindings, add more tests

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-11-17 11:02:44 -08:00
parent 7359f360db
commit d0bac61e74
2 changed files with 72 additions and 28 deletions

View file

@ -7,6 +7,7 @@ Author: Leonardo de Moura
#include <sstream> #include <sstream>
#include <lua.hpp> #include <lua.hpp>
#include "util/debug.h" #include "util/debug.h"
#include "util/sstream.h"
#include "util/numerics/mpz.h" #include "util/numerics/mpz.h"
#include "util/numerics/mpq.h" #include "util/numerics/mpq.h"
#include "bindings/lua/util.h" #include "bindings/lua/util.h"
@ -17,24 +18,20 @@ DECL_UDATA(mpz)
template<unsigned idx> template<unsigned idx>
static mpz const & to_mpz(lua_State * L) { static mpz const & to_mpz(lua_State * L) {
static thread_local mpz arg; static thread_local mpz arg;
if (lua_isuserdata(L, idx)) { switch (lua_type(L, idx)) {
return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt)); case LUA_TNUMBER: arg = static_cast<long>(lua_tointeger(L, idx)); return arg;
} else if (lua_isstring(L, idx)) { case LUA_TSTRING: arg = mpz(lua_tostring(L, idx)); return arg;
arg = mpz(luaL_checkstring(L, idx)); case LUA_TUSERDATA: return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
return arg; default: throw exception(sstream() << "arg #" << idx << " must be a number, string or mpz");
} else {
arg = static_cast<long int>(luaL_checkinteger(L, 1));
return arg;
} }
} }
mpz to_mpz_ext(lua_State * L, int idx) { mpz to_mpz_ext(lua_State * L, int idx) {
if (lua_isuserdata(L, idx)) { switch (lua_type(L, idx)) {
return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt)); case LUA_TNUMBER: return mpz(static_cast<long>(lua_tointeger(L, idx)));
} else if (lua_isstring(L, idx)) { case LUA_TSTRING: return mpz(lua_tostring(L, idx));
return mpz(luaL_checkstring(L, idx)); case LUA_TUSERDATA: return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
} else { default: throw exception(sstream() << "arg #" << idx << " must be a number, string or mpz");
return mpz(static_cast<long int>(luaL_checkinteger(L, 1)));
} }
} }
@ -95,7 +92,7 @@ static const struct luaL_Reg mpz_m[] = {
{"__eq", safe_function<mpz_eq>}, {"__eq", safe_function<mpz_eq>},
{"__lt", safe_function<mpz_lt>}, {"__lt", safe_function<mpz_lt>},
{"__add", safe_function<mpz_add>}, {"__add", safe_function<mpz_add>},
{"__add", safe_function<mpz_sub>}, {"__sub", safe_function<mpz_sub>},
{"__mul", safe_function<mpz_mul>}, {"__mul", safe_function<mpz_mul>},
{"__div", safe_function<mpz_div>}, {"__div", safe_function<mpz_div>},
{"__pow", safe_function<mpz_power>}, {"__pow", safe_function<mpz_power>},
@ -116,31 +113,31 @@ DECL_UDATA(mpq)
template<unsigned idx> template<unsigned idx>
static mpq const & to_mpq(lua_State * L) { static mpq const & to_mpq(lua_State * L) {
static thread_local mpq arg; static thread_local mpq arg;
if (lua_isuserdata(L, idx)) { switch (lua_type(L, idx)) {
case LUA_TNUMBER: arg = lua_tonumber(L, idx); return arg;
case LUA_TSTRING: arg = mpq(lua_tostring(L, idx)); return arg;
case LUA_TUSERDATA:
if (is_mpz(L, idx)) { if (is_mpz(L, idx)) {
arg = mpq(to_mpz<idx>(L)); arg = mpq(to_mpz<idx>(L));
return arg;
} else { } else {
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt)); return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
} }
} else if (lua_isstring(L, idx)) { default: throw exception(sstream() << "arg #" << idx << " must be a number, string, mpz or mpq");
arg = mpq(luaL_checkstring(L, idx));
} else {
arg = static_cast<long int>(luaL_checkinteger(L, 1));
} }
return arg;
} }
mpq to_mpq_ext(lua_State * L, int idx) { mpq to_mpq_ext(lua_State * L, int idx) {
if (lua_isuserdata(L, idx)) { switch (lua_type(L, idx)) {
case LUA_TNUMBER: return mpq(lua_tonumber(L, idx));
case LUA_TSTRING: return mpq(lua_tostring(L, idx));
case LUA_TUSERDATA:
if (is_mpz(L, idx)) { if (is_mpz(L, idx)) {
return mpq(to_mpz(L, idx)); return mpq(to_mpz<1>(L));
} else { } else {
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt)); return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
} }
} else if (lua_isstring(L, idx)) { default: throw exception(sstream() << "arg #" << idx << " must be a number, string, mpz or mpq");
return mpq(luaL_checkstring(L, idx));
} else {
return mpq(static_cast<long int>(luaL_checkinteger(L, 1)));
} }
} }
@ -201,7 +198,7 @@ static const struct luaL_Reg mpq_m[] = {
{"__eq", safe_function<mpq_eq>}, {"__eq", safe_function<mpq_eq>},
{"__lt", safe_function<mpq_lt>}, {"__lt", safe_function<mpq_lt>},
{"__add", safe_function<mpq_add>}, {"__add", safe_function<mpq_add>},
{"__add", safe_function<mpq_sub>}, {"__sub", safe_function<mpq_sub>},
{"__mul", safe_function<mpq_mul>}, {"__mul", safe_function<mpq_mul>},
{"__div", safe_function<mpq_div>}, {"__div", safe_function<mpq_div>},
{"__pow", safe_function<mpq_power>}, {"__pow", safe_function<mpq_power>},

47
tests/lua/num1.lua Normal file
View file

@ -0,0 +1,47 @@
assert(mpz(10) == mpz("10"))
assert(mpz(mpz(3)) == mpz("3"))
print(mpz(10) + mpz(3))
assert(mpz(10) + mpz(3) == mpz(13))
assert(mpz(10) + 2 == mpz(12))
assert(3 + mpz(15) == mpz(18))
assert(mpz(10) - mpz(3) == mpz(7))
assert(mpz(10) - 2 == mpz(8))
assert(3 - mpz(15) == -mpz(12))
assert(- mpz(15) == mpz("-15"))
assert(- mpz(15) == -mpz("15"))
assert(mpz(10) * mpz(3) == mpz(30))
assert(mpz(10) * 2 == mpz(20))
assert(3 * mpz(15) == mpz(45))
assert(mpz(3)^2 == mpz(9))
local a = -2
assert(not pcall(function() print(mpz(3)^a) end))
assert(mpz(3) < mpz(5))
assert(mpz(3) > mpz(1))
assert(mpq(3) == mpq("3"))
assert(mpq(3) == mpq(mpq(3)))
assert(mpq(3) == mpq(mpz(3)))
assert(mpq(0.5) == mpq(1)/2)
assert(mpq(0.4) ~= mpq(1)/2)
assert(mpq(10) + mpq(3) == mpq(13))
assert(mpq(10) + 2 == mpq(12))
assert(3 + mpq(15) == mpq(18))
assert(mpq(3) + mpz(15) == mpq(18))
assert(mpq(10) - mpq(3) == mpq(7))
assert(mpq(10) - 2 == mpq(8))
assert(3 - mpq(15) == -mpq(12))
assert(mpq(3) - mpz(15) == -mpq(12))
assert(mpq(10) * mpq(3) == mpq(30))
assert(mpq(10) * 2 == mpq(20))
assert(3 * mpq(15) == mpq(45))
assert(mpq(3) * mpz(15) == mpq(45))
assert(mpq(3)^2 == mpq(9))
print(mpq(0.5)^2)
assert(mpq(0.5)^2 == mpq(1)/4)
local a = -2
assert(not pcall(function() print(mpq(3)^a) end))
assert(mpq(10) / mpq(3) == mpq("10/3"))
assert(mpq(10) / 2 == mpq(5))
assert(3 / mpq(15) == mpq(1)/5)
assert(mpq(3) / mpz(15) == mpq("1/5"))
assert(mpq(3) < mpq(5))
assert(mpq(3) > mpq(1))