From e2da8c1f4de10ac826cdb270e357a658c4df4f4e Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 3 Nov 2013 12:02:57 -0800 Subject: [PATCH] feat(lua/numerics): expose mpz and mpq numbers in the Lua bindings Signed-off-by: Leonardo de Moura --- src/bindings/lua/CMakeLists.txt | 2 +- src/bindings/lua/numerics.cpp | 129 ++++++++++++++++++++++++++++++++ src/bindings/lua/numerics.h | 13 ++++ src/shell/lua/leanlua.cpp | 3 + tests/lua/mpz1.lua | 11 +++ 5 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 src/bindings/lua/numerics.cpp create mode 100644 src/bindings/lua/numerics.h create mode 100644 tests/lua/mpz1.lua diff --git a/src/bindings/lua/CMakeLists.txt b/src/bindings/lua/CMakeLists.txt index 77de6e18d..c19ed5737 100644 --- a/src/bindings/lua/CMakeLists.txt +++ b/src/bindings/lua/CMakeLists.txt @@ -1 +1 @@ -add_library(lua name.cpp) +add_library(lua name.cpp numerics.cpp) diff --git a/src/bindings/lua/numerics.cpp b/src/bindings/lua/numerics.cpp new file mode 100644 index 000000000..a232d84b7 --- /dev/null +++ b/src/bindings/lua/numerics.cpp @@ -0,0 +1,129 @@ +/* +Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#ifdef LEAN_USE_LUA +#include +#include +#include "util/debug.h" +#include "util/numerics/mpz.h" +#include "util/numerics/mpq.h" + +namespace lean { +template +class num_bindings { +public: + template + static T const & get_arg(lua_State * L) { + static thread_local T arg; + if (lua_isuserdata(L, idx)) { + return *static_cast(luaL_checkudata(L, idx, M)); + } else if (lua_isstring(L, idx)) { + char const * str = luaL_checkstring(L, idx); + arg = T(str); + return arg; + } else { + arg = luaL_checkinteger(L, 1); + return arg; + } + } + + static int push_result(lua_State * L, T const & val) { + void * mem = lua_newuserdata(L, sizeof(T)); + new (mem) T(val); + luaL_getmetatable(L, M); + lua_setmetatable(L, -2); + return 1; + } + + static int gc(lua_State * L) { + T * n = static_cast(luaL_checkudata(L, 1, M)); + n->~T(); + return 0; + } + + static int tostring(lua_State * L) { + T * n = static_cast(luaL_checkudata(L, 1, M)); + std::ostringstream out; + out << *n; + lua_pushfstring(L, out.str().c_str()); + return 1; + } + + static int eq(lua_State * L) { + lua_pushboolean(L, get_arg<1>(L) == get_arg<2>(L)); + return 1; + } + + static int lt(lua_State * L) { + lua_pushboolean(L, get_arg<1>(L) < get_arg<2>(L)); + return 1; + } + + static int add(lua_State * L) { + return push_result(L, get_arg<1>(L) + get_arg<2>(L)); + } + + static int sub(lua_State * L) { + return push_result(L, get_arg<1>(L) - get_arg<2>(L)); + } + + static int mul(lua_State * L) { + return push_result(L, get_arg<1>(L) * get_arg<2>(L)); + } + + static int div(lua_State * L) { + T const & arg2 = get_arg<2>(L); + if (arg2 == 0) luaL_error(L, "division by zero"); + return push_result(L, get_arg<1>(L) / arg2); + } + + static int umn(lua_State * L) { + return push_result(L, 0 - get_arg<1>(L)); + } + + static int power(lua_State * L) { + int k = luaL_checkinteger(L, 2); + if (k < 0) luaL_error(L, "argument #2 must be positive"); + return push_result(L, pow(get_arg<1>(L), k)); + } + + static const struct luaL_Reg m[]; + + static int mk(lua_State * L) { + T const & arg = get_arg<1>(L); + return push_result(L, arg); + } + + static void init(lua_State * L) { + luaL_newmetatable(L, M); + luaL_setfuncs(L, m, 0); + + lua_pushcfunction(L, mk); + lua_setglobal(L, N); + } +}; + +template +const struct luaL_Reg num_bindings::m[] = { + {"__gc", num_bindings::gc}, {"__tostring", num_bindings::tostring}, {"__eq", num_bindings::eq}, + {"__lt", num_bindings::lt}, {"__add", num_bindings::add}, {"__add", num_bindings::sub}, + {"__mul", num_bindings::mul}, {"__div", num_bindings::div}, {"__pow", num_bindings::power}, + {"__unm", num_bindings::umn}, + {0, 0} +}; + +constexpr char const mpz_name[] = "mpz"; +constexpr char const mpz_metatable[] = "mpz.mt"; +void init_mpz(lua_State * L) { + num_bindings::init(L); +} +constexpr char const mpq_name[] = "mpq"; +constexpr char const mpq_metatable[] = "mpq.mt"; +void init_mpq(lua_State * L) { + num_bindings::init(L); +} +} +#endif diff --git a/src/bindings/lua/numerics.h b/src/bindings/lua/numerics.h new file mode 100644 index 000000000..7e95c2160 --- /dev/null +++ b/src/bindings/lua/numerics.h @@ -0,0 +1,13 @@ +/* +Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#ifdef LEAN_USE_LUA +#include +namespace lean { +void init_mpz(lua_State * L); +void init_mpq(lua_State * L); +} +#endif diff --git a/src/shell/lua/leanlua.cpp b/src/shell/lua/leanlua.cpp index 04d0c6ff4..46e104835 100644 --- a/src/shell/lua/leanlua.cpp +++ b/src/shell/lua/leanlua.cpp @@ -10,6 +10,7 @@ Author: Leonardo de Moura #ifdef LEAN_USE_LUA #include #include "bindings/lua/name.h" +#include "bindings/lua/numerics.h" int main(int argc, char ** argv) { int status, result; @@ -18,6 +19,8 @@ int main(int argc, char ** argv) { L = luaL_newstate(); luaL_openlibs(L); lean::init_name(L); + lean::init_mpz(L); + lean::init_mpq(L); for (int i = 1; i < argc; i++) { status = luaL_loadfile(L, argv[i]); diff --git a/tests/lua/mpz1.lua b/tests/lua/mpz1.lua new file mode 100644 index 000000000..8eefbd87f --- /dev/null +++ b/tests/lua/mpz1.lua @@ -0,0 +1,11 @@ +a = mpz("1000000000000000000000000000") +b = mpz(10) +c = mpz(10) +print(a/3) +print(b == c) +print(a < b) +print(b < a) +print(10 < b) +print(10 <= b) +print(a) +print(b)