fix(lua/numerics): bug in bindings, add more tests
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
7359f360db
commit
d0bac61e74
2 changed files with 72 additions and 28 deletions
|
@ -7,6 +7,7 @@ Author: Leonardo de Moura
|
|||
#include <sstream>
|
||||
#include <lua.hpp>
|
||||
#include "util/debug.h"
|
||||
#include "util/sstream.h"
|
||||
#include "util/numerics/mpz.h"
|
||||
#include "util/numerics/mpq.h"
|
||||
#include "bindings/lua/util.h"
|
||||
|
@ -17,24 +18,20 @@ DECL_UDATA(mpz)
|
|||
template<unsigned idx>
|
||||
static mpz const & to_mpz(lua_State * L) {
|
||||
static thread_local mpz arg;
|
||||
if (lua_isuserdata(L, idx)) {
|
||||
return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
|
||||
} else if (lua_isstring(L, idx)) {
|
||||
arg = mpz(luaL_checkstring(L, idx));
|
||||
return arg;
|
||||
} else {
|
||||
arg = static_cast<long int>(luaL_checkinteger(L, 1));
|
||||
return arg;
|
||||
switch (lua_type(L, idx)) {
|
||||
case LUA_TNUMBER: arg = static_cast<long>(lua_tointeger(L, idx)); return arg;
|
||||
case LUA_TSTRING: arg = mpz(lua_tostring(L, idx)); return arg;
|
||||
case LUA_TUSERDATA: return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
|
||||
default: throw exception(sstream() << "arg #" << idx << " must be a number, string or mpz");
|
||||
}
|
||||
}
|
||||
|
||||
mpz to_mpz_ext(lua_State * L, int idx) {
|
||||
if (lua_isuserdata(L, idx)) {
|
||||
return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
|
||||
} else if (lua_isstring(L, idx)) {
|
||||
return mpz(luaL_checkstring(L, idx));
|
||||
} else {
|
||||
return mpz(static_cast<long int>(luaL_checkinteger(L, 1)));
|
||||
switch (lua_type(L, idx)) {
|
||||
case LUA_TNUMBER: return mpz(static_cast<long>(lua_tointeger(L, idx)));
|
||||
case LUA_TSTRING: return mpz(lua_tostring(L, idx));
|
||||
case LUA_TUSERDATA: return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
|
||||
default: throw exception(sstream() << "arg #" << idx << " must be a number, string or mpz");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -95,7 +92,7 @@ static const struct luaL_Reg mpz_m[] = {
|
|||
{"__eq", safe_function<mpz_eq>},
|
||||
{"__lt", safe_function<mpz_lt>},
|
||||
{"__add", safe_function<mpz_add>},
|
||||
{"__add", safe_function<mpz_sub>},
|
||||
{"__sub", safe_function<mpz_sub>},
|
||||
{"__mul", safe_function<mpz_mul>},
|
||||
{"__div", safe_function<mpz_div>},
|
||||
{"__pow", safe_function<mpz_power>},
|
||||
|
@ -116,31 +113,31 @@ DECL_UDATA(mpq)
|
|||
template<unsigned idx>
|
||||
static mpq const & to_mpq(lua_State * L) {
|
||||
static thread_local mpq arg;
|
||||
if (lua_isuserdata(L, idx)) {
|
||||
switch (lua_type(L, idx)) {
|
||||
case LUA_TNUMBER: arg = lua_tonumber(L, idx); return arg;
|
||||
case LUA_TSTRING: arg = mpq(lua_tostring(L, idx)); return arg;
|
||||
case LUA_TUSERDATA:
|
||||
if (is_mpz(L, idx)) {
|
||||
arg = mpq(to_mpz<idx>(L));
|
||||
return arg;
|
||||
} else {
|
||||
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
|
||||
}
|
||||
} else if (lua_isstring(L, idx)) {
|
||||
arg = mpq(luaL_checkstring(L, idx));
|
||||
} else {
|
||||
arg = static_cast<long int>(luaL_checkinteger(L, 1));
|
||||
default: throw exception(sstream() << "arg #" << idx << " must be a number, string, mpz or mpq");
|
||||
}
|
||||
return arg;
|
||||
}
|
||||
|
||||
mpq to_mpq_ext(lua_State * L, int idx) {
|
||||
if (lua_isuserdata(L, idx)) {
|
||||
switch (lua_type(L, idx)) {
|
||||
case LUA_TNUMBER: return mpq(lua_tonumber(L, idx));
|
||||
case LUA_TSTRING: return mpq(lua_tostring(L, idx));
|
||||
case LUA_TUSERDATA:
|
||||
if (is_mpz(L, idx)) {
|
||||
return mpq(to_mpz(L, idx));
|
||||
return mpq(to_mpz<1>(L));
|
||||
} else {
|
||||
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
|
||||
}
|
||||
} else if (lua_isstring(L, idx)) {
|
||||
return mpq(luaL_checkstring(L, idx));
|
||||
} else {
|
||||
return mpq(static_cast<long int>(luaL_checkinteger(L, 1)));
|
||||
default: throw exception(sstream() << "arg #" << idx << " must be a number, string, mpz or mpq");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -201,7 +198,7 @@ static const struct luaL_Reg mpq_m[] = {
|
|||
{"__eq", safe_function<mpq_eq>},
|
||||
{"__lt", safe_function<mpq_lt>},
|
||||
{"__add", safe_function<mpq_add>},
|
||||
{"__add", safe_function<mpq_sub>},
|
||||
{"__sub", safe_function<mpq_sub>},
|
||||
{"__mul", safe_function<mpq_mul>},
|
||||
{"__div", safe_function<mpq_div>},
|
||||
{"__pow", safe_function<mpq_power>},
|
||||
|
|
47
tests/lua/num1.lua
Normal file
47
tests/lua/num1.lua
Normal file
|
@ -0,0 +1,47 @@
|
|||
assert(mpz(10) == mpz("10"))
|
||||
assert(mpz(mpz(3)) == mpz("3"))
|
||||
print(mpz(10) + mpz(3))
|
||||
assert(mpz(10) + mpz(3) == mpz(13))
|
||||
assert(mpz(10) + 2 == mpz(12))
|
||||
assert(3 + mpz(15) == mpz(18))
|
||||
assert(mpz(10) - mpz(3) == mpz(7))
|
||||
assert(mpz(10) - 2 == mpz(8))
|
||||
assert(3 - mpz(15) == -mpz(12))
|
||||
assert(- mpz(15) == mpz("-15"))
|
||||
assert(- mpz(15) == -mpz("15"))
|
||||
assert(mpz(10) * mpz(3) == mpz(30))
|
||||
assert(mpz(10) * 2 == mpz(20))
|
||||
assert(3 * mpz(15) == mpz(45))
|
||||
assert(mpz(3)^2 == mpz(9))
|
||||
local a = -2
|
||||
assert(not pcall(function() print(mpz(3)^a) end))
|
||||
assert(mpz(3) < mpz(5))
|
||||
assert(mpz(3) > mpz(1))
|
||||
assert(mpq(3) == mpq("3"))
|
||||
assert(mpq(3) == mpq(mpq(3)))
|
||||
assert(mpq(3) == mpq(mpz(3)))
|
||||
assert(mpq(0.5) == mpq(1)/2)
|
||||
assert(mpq(0.4) ~= mpq(1)/2)
|
||||
assert(mpq(10) + mpq(3) == mpq(13))
|
||||
assert(mpq(10) + 2 == mpq(12))
|
||||
assert(3 + mpq(15) == mpq(18))
|
||||
assert(mpq(3) + mpz(15) == mpq(18))
|
||||
assert(mpq(10) - mpq(3) == mpq(7))
|
||||
assert(mpq(10) - 2 == mpq(8))
|
||||
assert(3 - mpq(15) == -mpq(12))
|
||||
assert(mpq(3) - mpz(15) == -mpq(12))
|
||||
assert(mpq(10) * mpq(3) == mpq(30))
|
||||
assert(mpq(10) * 2 == mpq(20))
|
||||
assert(3 * mpq(15) == mpq(45))
|
||||
assert(mpq(3) * mpz(15) == mpq(45))
|
||||
assert(mpq(3)^2 == mpq(9))
|
||||
print(mpq(0.5)^2)
|
||||
assert(mpq(0.5)^2 == mpq(1)/4)
|
||||
local a = -2
|
||||
assert(not pcall(function() print(mpq(3)^a) end))
|
||||
assert(mpq(10) / mpq(3) == mpq("10/3"))
|
||||
assert(mpq(10) / 2 == mpq(5))
|
||||
assert(3 / mpq(15) == mpq(1)/5)
|
||||
assert(mpq(3) / mpz(15) == mpq("1/5"))
|
||||
assert(mpq(3) < mpq(5))
|
||||
assert(mpq(3) > mpq(1))
|
Loading…
Reference in a new issue