fix(lua): problem when compiling with clang++

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-11-04 15:05:04 -08:00
parent 0579970fc5
commit 32d3990fc7
2 changed files with 182 additions and 100 deletions

View file

@ -13,125 +13,208 @@ Author: Leonardo de Moura
#include "bindings/lua/util.h"
namespace lean {
template<typename T, char const * N, char const * M>
class num_bindings {
public:
template<unsigned idx>
static T const & get_arg(lua_State * L) {
static thread_local T arg;
if (lua_isuserdata(L, idx)) {
return *static_cast<T*>(luaL_checkudata(L, idx, M));
} else if (lua_isstring(L, idx)) {
char const * str = luaL_checkstring(L, idx);
arg = T(str);
return arg;
} else {
arg = luaL_checkinteger(L, 1);
return arg;
}
template<unsigned idx>
static mpz const & to_mpz(lua_State * L) {
static thread_local mpz arg;
if (lua_isuserdata(L, idx)) {
return *static_cast<mpz*>(luaL_checkudata(L, idx, "mpz.mt"));
} else if (lua_isstring(L, idx)) {
char const * str = luaL_checkstring(L, idx);
arg = mpz(str);
return arg;
} else {
arg = luaL_checkinteger(L, 1);
return arg;
}
}
static int push_result(lua_State * L, T const & val) {
void * mem = lua_newuserdata(L, sizeof(T));
new (mem) T(val);
luaL_getmetatable(L, M);
lua_setmetatable(L, -2);
return 1;
}
static int push_mpz(lua_State * L, mpz const & val) {
void * mem = lua_newuserdata(L, sizeof(mpz));
new (mem) mpz(val);
luaL_getmetatable(L, "mpz.mt");
lua_setmetatable(L, -2);
return 1;
}
static int gc(lua_State * L) {
T * n = static_cast<T*>(luaL_checkudata(L, 1, M));
n->~T();
return 0;
}
static int mpz_gc(lua_State * L) {
mpz * n = static_cast<mpz*>(luaL_checkudata(L, 1, "mpz.mt"));
n->~mpz();
return 0;
}
static int tostring(lua_State * L) {
T * n = static_cast<T*>(luaL_checkudata(L, 1, M));
std::ostringstream out;
out << *n;
lua_pushfstring(L, out.str().c_str());
return 1;
}
static int mpz_tostring(lua_State * L) {
mpz * n = static_cast<mpz*>(luaL_checkudata(L, 1, "mpz.mt"));
std::ostringstream out;
out << *n;
lua_pushfstring(L, out.str().c_str());
return 1;
}
static int eq(lua_State * L) {
lua_pushboolean(L, get_arg<1>(L) == get_arg<2>(L));
return 1;
}
static int mpz_eq(lua_State * L) {
lua_pushboolean(L, to_mpz<1>(L) == to_mpz<2>(L));
return 1;
}
static int lt(lua_State * L) {
lua_pushboolean(L, get_arg<1>(L) < get_arg<2>(L));
return 1;
}
static int mpz_lt(lua_State * L) {
lua_pushboolean(L, to_mpz<1>(L) < to_mpz<2>(L));
return 1;
}
static int add(lua_State * L) {
return push_result(L, get_arg<1>(L) + get_arg<2>(L));
}
static int mpz_add(lua_State * L) {
return push_mpz(L, to_mpz<1>(L) + to_mpz<2>(L));
}
static int sub(lua_State * L) {
return push_result(L, get_arg<1>(L) - get_arg<2>(L));
}
static int mpz_sub(lua_State * L) {
return push_mpz(L, to_mpz<1>(L) - to_mpz<2>(L));
}
static int mul(lua_State * L) {
return push_result(L, get_arg<1>(L) * get_arg<2>(L));
}
static int mpz_mul(lua_State * L) {
return push_mpz(L, to_mpz<1>(L) * to_mpz<2>(L));
}
static int div(lua_State * L) {
T const & arg2 = get_arg<2>(L);
if (arg2 == 0) luaL_error(L, "division by zero");
return push_result(L, get_arg<1>(L) / arg2);
}
static int mpz_div(lua_State * L) {
mpz const & arg2 = to_mpz<2>(L);
if (arg2 == 0) luaL_error(L, "division by zero");
return push_mpz(L, to_mpz<1>(L) / arg2);
}
static int umn(lua_State * L) {
return push_result(L, 0 - get_arg<1>(L));
}
static int mpz_umn(lua_State * L) {
return push_mpz(L, 0 - to_mpz<1>(L));
}
static int power(lua_State * L) {
int k = luaL_checkinteger(L, 2);
if (k < 0) luaL_error(L, "argument #2 must be positive");
return push_result(L, pow(get_arg<1>(L), k));
}
static int mpz_power(lua_State * L) {
int k = luaL_checkinteger(L, 2);
if (k < 0) luaL_error(L, "argument #2 must be positive");
return push_mpz(L, pow(to_mpz<1>(L), k));
}
static const struct luaL_Reg m[];
static int mk_mpz(lua_State * L) {
mpz const & arg = to_mpz<1>(L);
return push_mpz(L, arg);
}
static int mk(lua_State * L) {
T const & arg = get_arg<1>(L);
return push_result(L, arg);
}
static void open(lua_State * L) {
luaL_newmetatable(L, M);
setfuncs(L, m, 0);
lua_pushcfunction(L, safe_function<mk>);
lua_setglobal(L, N);
}
};
template<typename T, char const * N, char const * M>
const struct luaL_Reg num_bindings<T, N, M>::m[] = {
{"__gc", num_bindings<T, N, M>::gc}, // never throws
{"__tostring", safe_function<num_bindings<T, N, M>::tostring>},
{"__eq", safe_function<num_bindings<T, N, M>::eq>},
{"__lt", safe_function<num_bindings<T, N, M>::lt>},
{"__add", safe_function<num_bindings<T, N, M>::add>},
{"__add", safe_function<num_bindings<T, N, M>::sub>},
{"__mul", safe_function<num_bindings<T, N, M>::mul>},
{"__div", safe_function<num_bindings<T, N, M>::div>},
{"__pow", safe_function<num_bindings<T, N, M>::power>},
{"__unm", safe_function<num_bindings<T, N, M>::umn>},
static const struct luaL_Reg mpz_m[] = {
{"__gc", mpz_gc}, // never throws
{"__tostring", safe_function<mpz_tostring>},
{"__eq", safe_function<mpz_eq>},
{"__lt", safe_function<mpz_lt>},
{"__add", safe_function<mpz_add>},
{"__add", safe_function<mpz_sub>},
{"__mul", safe_function<mpz_mul>},
{"__div", safe_function<mpz_div>},
{"__pow", safe_function<mpz_power>},
{"__unm", safe_function<mpz_umn>},
{0, 0}
};
constexpr char const mpz_name[] = "mpz";
constexpr char const mpz_metatable[] = "mpz.mt";
void open_mpz(lua_State * L) {
num_bindings<mpz, mpz_name, mpz_metatable>::open(L);
luaL_newmetatable(L, "mpz.mt");
setfuncs(L, mpz_m, 0);
lua_pushcfunction(L, safe_function<mk_mpz>);
lua_setglobal(L, "mpz");
}
constexpr char const mpq_name[] = "mpq";
constexpr char const mpq_metatable[] = "mpq.mt";
template<unsigned idx>
static mpq const & to_mpq(lua_State * L) {
static thread_local mpq arg;
if (lua_isuserdata(L, idx)) {
return *static_cast<mpq*>(luaL_checkudata(L, idx, "mpq.mt"));
} else if (lua_isstring(L, idx)) {
char const * str = luaL_checkstring(L, idx);
arg = mpq(str);
return arg;
} else {
arg = luaL_checkinteger(L, 1);
return arg;
}
}
static int push_mpq(lua_State * L, mpq const & val) {
void * mem = lua_newuserdata(L, sizeof(mpq));
new (mem) mpq(val);
luaL_getmetatable(L, "mpq.mt");
lua_setmetatable(L, -2);
return 1;
}
static int mpq_gc(lua_State * L) {
mpq * n = static_cast<mpq*>(luaL_checkudata(L, 1, "mpq.mt"));
n->~mpq();
return 0;
}
static int mpq_tostring(lua_State * L) {
mpq * n = static_cast<mpq*>(luaL_checkudata(L, 1, "mpq.mt"));
std::ostringstream out;
out << *n;
lua_pushfstring(L, out.str().c_str());
return 1;
}
static int mpq_eq(lua_State * L) {
lua_pushboolean(L, to_mpq<1>(L) == to_mpq<2>(L));
return 1;
}
static int mpq_lt(lua_State * L) {
lua_pushboolean(L, to_mpq<1>(L) < to_mpq<2>(L));
return 1;
}
static int mpq_add(lua_State * L) {
return push_mpq(L, to_mpq<1>(L) + to_mpq<2>(L));
}
static int mpq_sub(lua_State * L) {
return push_mpq(L, to_mpq<1>(L) - to_mpq<2>(L));
}
static int mpq_mul(lua_State * L) {
return push_mpq(L, to_mpq<1>(L) * to_mpq<2>(L));
}
static int mpq_div(lua_State * L) {
mpq const & arg2 = to_mpq<2>(L);
if (arg2 == 0) luaL_error(L, "division by zero");
return push_mpq(L, to_mpq<1>(L) / arg2);
}
static int mpq_umn(lua_State * L) {
return push_mpq(L, 0 - to_mpq<1>(L));
}
static int mpq_power(lua_State * L) {
int k = luaL_checkinteger(L, 2);
if (k < 0) luaL_error(L, "argument #2 must be positive");
return push_mpq(L, pow(to_mpq<1>(L), k));
}
static int mk_mpq(lua_State * L) {
mpq const & arg = to_mpq<1>(L);
return push_mpq(L, arg);
}
static const struct luaL_Reg mpq_m[] = {
{"__gc", mpq_gc}, // never throws
{"__tostring", safe_function<mpq_tostring>},
{"__eq", safe_function<mpq_eq>},
{"__lt", safe_function<mpq_lt>},
{"__add", safe_function<mpq_add>},
{"__add", safe_function<mpq_sub>},
{"__mul", safe_function<mpq_mul>},
{"__div", safe_function<mpq_div>},
{"__pow", safe_function<mpq_power>},
{"__unm", safe_function<mpq_umn>},
{0, 0}
};
void open_mpq(lua_State * L) {
num_bindings<mpq, mpq_name, mpq_metatable>::open(L);
luaL_newmetatable(L, "mpq.mt");
setfuncs(L, mpq_m, 0);
lua_pushcfunction(L, safe_function<mk_mpq>);
lua_setglobal(L, "mpq");
}
}
#endif

View file

@ -12,8 +12,7 @@ void setfuncs(lua_State * L, luaL_Reg const * l, int nup);
\brief Wrapper for invoking function f, and catching Lean exceptions.
*/
int safe_function_wrapper(lua_State * L, lua_CFunction f);
template<lua_CFunction F>
int safe_function(lua_State * L) {
template<lua_CFunction F> int safe_function(lua_State * L) {
return safe_function_wrapper(L, F);
}
}