/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.

Author: Leonardo de Moura
*/
#ifdef LEAN_USE_LUA
#include <limits>
#include <new>
#include <sstream>
#include <lua.hpp>
#include "util/debug.h"
#include "util/numerics/mpz.h"
#include "util/numerics/mpq.h"
#include "bindings/lua/util.h"

namespace lean {
// ---------------------------------------------------------------------------
// mpz (arbitrary-precision integer) bindings, metatable "mpz.mt"
// ---------------------------------------------------------------------------

/**
   \brief Convert the Lua value at stack position \c idx into an mpz.

   - An "mpz.mt" userdata is returned by reference (no copy).
   - A string, or a number (lua_isstring also accepts numbers), is parsed
     with the mpz(char const *) constructor.
   - Any other value goes to luaL_checkinteger. Numbers and strings were
     already handled above, so in practice this raises a Lua type error.

   The returned reference may point into a thread-local buffer that is
   overwritten by the next call with the same \c idx. Each \c idx
   instantiation has its own buffer, so to_mpz<1> and to_mpz<2> can be used
   in the same expression.
*/
template<int idx>
static mpz const & to_mpz(lua_State * L) {
    static thread_local mpz arg;
    if (lua_isuserdata(L, idx)) {
        return *static_cast<mpz*>(luaL_checkudata(L, idx, "mpz.mt"));
    } else if (lua_isstring(L, idx)) {
        char const * str = luaL_checkstring(L, idx);
        arg = mpz(str);
        return arg;
    } else {
        // Check argument idx (not 1); otherwise to_mpz<2> validates the
        // wrong operand and blames argument #1 in the error message.
        arg = luaL_checkinteger(L, idx);
        return arg;
    }
}

/** \brief Push a copy of \c val as a new "mpz.mt" userdata. Returns 1 (values pushed). */
static int push_mpz(lua_State * L, mpz const & val) {
    void * mem = lua_newuserdata(L, sizeof(mpz));
    new (mem) mpz(val);
    luaL_getmetatable(L, "mpz.mt");
    lua_setmetatable(L, -2);
    return 1;
}

/** \brief __gc: run the destructor of the mpz stored in the userdata (Lua frees the memory). */
static int mpz_gc(lua_State * L) {
    mpz * n = static_cast<mpz*>(luaL_checkudata(L, 1, "mpz.mt"));
    n->~mpz();
    return 0;
}

/** \brief __tostring: decimal text produced by the mpz stream operator. */
static int mpz_tostring(lua_State * L) {
    mpz * n = static_cast<mpz*>(luaL_checkudata(L, 1, "mpz.mt"));
    std::ostringstream out;
    out << *n;
    // lua_pushstring, not lua_pushfstring: the text is data, not a format string.
    lua_pushstring(L, out.str().c_str());
    return 1;
}

static int mpz_eq(lua_State * L) {
    lua_pushboolean(L, to_mpz<1>(L) == to_mpz<2>(L));
    return 1;
}

static int mpz_lt(lua_State * L) {
    lua_pushboolean(L, to_mpz<1>(L) < to_mpz<2>(L));
    return 1;
}

/** \brief __le: a <= b computed as !(b < a). Lua 5.4 no longer derives <= from __lt. */
static int mpz_le(lua_State * L) {
    lua_pushboolean(L, !(to_mpz<2>(L) < to_mpz<1>(L)));
    return 1;
}

static int mpz_add(lua_State * L) {
    return push_mpz(L, to_mpz<1>(L) + to_mpz<2>(L));
}

static int mpz_sub(lua_State * L) {
    return push_mpz(L, to_mpz<1>(L) - to_mpz<2>(L));
}

static int mpz_mul(lua_State * L) {
    return push_mpz(L, to_mpz<1>(L) * to_mpz<2>(L));
}

/** \brief __div: raises a Lua error when the divisor is zero. */
static int mpz_div(lua_State * L) {
    mpz const & arg2 = to_mpz<2>(L);
    if (arg2 == 0)
        return luaL_error(L, "division by zero");
    return push_mpz(L, to_mpz<1>(L) / arg2);
}

static int mpz_umn(lua_State * L) {
    return push_mpz(L, 0 - to_mpz<1>(L));
}

/**
   \brief __pow: argument 2 must be a Lua integer in [0, INT_MAX].

   The exponent is range-checked before narrowing to int, so a large
   lua_Integer cannot silently wrap into a different exponent.
*/
static int mpz_power(lua_State * L) {
    lua_Integer k = luaL_checkinteger(L, 2);
    if (k < 0)
        return luaL_error(L, "argument #2 must be positive");
    if (k > std::numeric_limits<int>::max())
        return luaL_error(L, "argument #2 is too big");
    return push_mpz(L, pow(to_mpz<1>(L), static_cast<int>(k)));
}

/** \brief Global constructor mpz(v): accepts an mpz, a number, or a numeric string. */
static int mk_mpz(lua_State * L) {
    mpz const & arg = to_mpz<1>(L);
    return push_mpz(L, arg);
}

static const struct luaL_Reg mpz_m[] = {
    {"__gc",       mpz_gc},  // never throws
    {"__tostring", safe_function<mpz_tostring>},
    {"__eq",       safe_function<mpz_eq>},
    {"__lt",       safe_function<mpz_lt>},
    {"__le",       safe_function<mpz_le>},
    {"__add",      safe_function<mpz_add>},
    {"__sub",      safe_function<mpz_sub>},  // was mistakenly registered as a second "__add"
    {"__mul",      safe_function<mpz_mul>},
    {"__div",      safe_function<mpz_div>},
    {"__pow",      safe_function<mpz_power>},
    {"__unm",      safe_function<mpz_umn>},
    {0, 0}
};

/** \brief Create the "mpz.mt" metatable and register the global constructor "mpz". */
void open_mpz(lua_State * L) {
    luaL_newmetatable(L, "mpz.mt");
    // NOTE(review): the metatable appears to stay on the stack after this
    // function returns (nothing pops it); confirm callers expect that.
    setfuncs(L, mpz_m, 0);

    lua_pushcfunction(L, safe_function<mk_mpz>);
    lua_setglobal(L, "mpz");
}

// ---------------------------------------------------------------------------
// mpq (arbitrary-precision rational) bindings, metatable "mpq.mt"
// ---------------------------------------------------------------------------

/**
   \brief Convert the Lua value at stack position \c idx into an mpq.

   Same rules as to_mpz: an "mpq.mt" userdata is returned by reference; a
   string or number is parsed with mpq(char const *); anything else goes to
   luaL_checkinteger, which in practice raises a Lua type error. The result
   may refer to a per-\c idx thread-local buffer reused by the next call.
*/
template<int idx>
static mpq const & to_mpq(lua_State * L) {
    static thread_local mpq arg;
    if (lua_isuserdata(L, idx)) {
        return *static_cast<mpq*>(luaL_checkudata(L, idx, "mpq.mt"));
    } else if (lua_isstring(L, idx)) {
        char const * str = luaL_checkstring(L, idx);
        arg = mpq(str);
        return arg;
    } else {
        // Check argument idx (not 1), matching the branches above.
        arg = luaL_checkinteger(L, idx);
        return arg;
    }
}

/** \brief Push a copy of \c val as a new "mpq.mt" userdata. Returns 1 (values pushed). */
static int push_mpq(lua_State * L, mpq const & val) {
    void * mem = lua_newuserdata(L, sizeof(mpq));
    new (mem) mpq(val);
    luaL_getmetatable(L, "mpq.mt");
    lua_setmetatable(L, -2);
    return 1;
}

/** \brief __gc: run the destructor of the mpq stored in the userdata (Lua frees the memory). */
static int mpq_gc(lua_State * L) {
    mpq * n = static_cast<mpq*>(luaL_checkudata(L, 1, "mpq.mt"));
    n->~mpq();
    return 0;
}

/** \brief __tostring: text produced by the mpq stream operator. */
static int mpq_tostring(lua_State * L) {
    mpq * n = static_cast<mpq*>(luaL_checkudata(L, 1, "mpq.mt"));
    std::ostringstream out;
    out << *n;
    // lua_pushstring, not lua_pushfstring: the text is data, not a format string.
    lua_pushstring(L, out.str().c_str());
    return 1;
}

static int mpq_eq(lua_State * L) {
    lua_pushboolean(L, to_mpq<1>(L) == to_mpq<2>(L));
    return 1;
}

static int mpq_lt(lua_State * L) {
    lua_pushboolean(L, to_mpq<1>(L) < to_mpq<2>(L));
    return 1;
}

/** \brief __le: a <= b computed as !(b < a). Lua 5.4 no longer derives <= from __lt. */
static int mpq_le(lua_State * L) {
    lua_pushboolean(L, !(to_mpq<2>(L) < to_mpq<1>(L)));
    return 1;
}

static int mpq_add(lua_State * L) {
    return push_mpq(L, to_mpq<1>(L) + to_mpq<2>(L));
}

static int mpq_sub(lua_State * L) {
    return push_mpq(L, to_mpq<1>(L) - to_mpq<2>(L));
}

static int mpq_mul(lua_State * L) {
    return push_mpq(L, to_mpq<1>(L) * to_mpq<2>(L));
}

/** \brief __div: raises a Lua error when the divisor is zero. */
static int mpq_div(lua_State * L) {
    mpq const & arg2 = to_mpq<2>(L);
    if (arg2 == 0)
        return luaL_error(L, "division by zero");
    return push_mpq(L, to_mpq<1>(L) / arg2);
}

static int mpq_umn(lua_State * L) {
    return push_mpq(L, 0 - to_mpq<1>(L));
}

/**
   \brief __pow: argument 2 must be a Lua integer in [0, INT_MAX].

   The exponent is range-checked before narrowing to int, so a large
   lua_Integer cannot silently wrap into a different exponent.
*/
static int mpq_power(lua_State * L) {
    lua_Integer k = luaL_checkinteger(L, 2);
    if (k < 0)
        return luaL_error(L, "argument #2 must be positive");
    if (k > std::numeric_limits<int>::max())
        return luaL_error(L, "argument #2 is too big");
    return push_mpq(L, pow(to_mpq<1>(L), static_cast<int>(k)));
}

/** \brief Global constructor mpq(v): accepts an mpq, a number, or a numeric string. */
static int mk_mpq(lua_State * L) {
    mpq const & arg = to_mpq<1>(L);
    return push_mpq(L, arg);
}

static const struct luaL_Reg mpq_m[] = {
    {"__gc",       mpq_gc},  // never throws
    {"__tostring", safe_function<mpq_tostring>},
    {"__eq",       safe_function<mpq_eq>},
    {"__lt",       safe_function<mpq_lt>},
    {"__le",       safe_function<mpq_le>},
    {"__add",      safe_function<mpq_add>},
    {"__sub",      safe_function<mpq_sub>},  // was mistakenly registered as a second "__add"
    {"__mul",      safe_function<mpq_mul>},
    {"__div",      safe_function<mpq_div>},
    {"__pow",      safe_function<mpq_power>},
    {"__unm",      safe_function<mpq_umn>},
    {0, 0}
};

/** \brief Create the "mpq.mt" metatable and register the global constructor "mpq". */
void open_mpq(lua_State * L) {
    luaL_newmetatable(L, "mpq.mt");
    // NOTE(review): the metatable appears to stay on the stack after this
    // function returns (nothing pops it); confirm callers expect that.
    setfuncs(L, mpq_m, 0);

    lua_pushcfunction(L, safe_function<mk_mpq>);
    lua_setglobal(L, "mpq");
}
}
#endif