From 69b41eae707fc81ee118f708d65acb42847d7f2c Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 11 Nov 2013 16:25:17 -0800 Subject: [PATCH] feat(lua): add support for multiple execution threads in the Lua API Signed-off-by: Leonardo de Moura --- src/bindings/lua/leanlua_state.cpp | 195 ++++++++++++++++++++++------- src/bindings/lua/leanlua_state.h | 12 +- tests/lua/threads/th1.lua | 12 ++ 3 files changed, 169 insertions(+), 50 deletions(-) create mode 100644 tests/lua/threads/th1.lua diff --git a/src/bindings/lua/leanlua_state.cpp b/src/bindings/lua/leanlua_state.cpp index 358ed7874..3acd91267 100644 --- a/src/bindings/lua/leanlua_state.cpp +++ b/src/bindings/lua/leanlua_state.cpp @@ -6,6 +6,7 @@ Author: Leonardo de Moura */ #include #include +#include #include #include #include "util/debug.h" @@ -29,6 +30,7 @@ extern "C" void * lua_realloc(void *, void * q, size_t, size_t new_size) { retur namespace lean { static void open_state(lua_State * L); +static void open_thread(lua_State * L); static void copy_values(lua_State * src, int first, int last, lua_State * tgt) { for (int i = first; i <= last; i++) { @@ -77,18 +79,19 @@ struct leanlua_state::imp { if (m_state == nullptr) throw exception("fail to create Lua interpreter"); luaL_openlibs(m_state); - lean::open_name(m_state); - lean::open_mpz(m_state); - lean::open_mpq(m_state); - lean::open_options(m_state); - lean::open_sexpr(m_state); - lean::open_format(m_state); - lean::open_level(m_state); - lean::open_local_context(m_state); - lean::open_expr(m_state); - lean::open_context(m_state); - lean::open_environment(m_state); - lean::open_state(m_state); + open_name(m_state); + open_mpz(m_state); + open_mpq(m_state); + open_options(m_state); + open_sexpr(m_state); + open_format(m_state); + open_level(m_state); + open_local_context(m_state); + open_expr(m_state); + open_context(m_state); + open_environment(m_state); + open_state(m_state); + open_thread(m_state); dostring(g_leanlua_extra); } @@ -110,30 +113,6 @@ struct leanlua_state::imp { set_environment set(m_state, env); dostring(str); } - - int dostring(char const * str, lua_State * src, int first, int last) { - std::lock_guard lock(m_mutex); - - int sz_before = lua_gettop(m_state); - - int result = luaL_loadstring(m_state, str); - if (result) - throw lua_exception(lua_tostring(m_state, -1)); - - copy_values(src, first, last, m_state); - - result = lua_pcall(m_state, first > last ? 0 : last - first + 1, LUA_MULTRET, 0); - if (result) - throw lua_exception(lua_tostring(m_state, -1)); - - int sz_after = lua_gettop(m_state); - - if (sz_after > sz_before) { - copy_values(m_state, sz_before + 1, sz_after, src); - lua_pop(m_state, sz_after - sz_before); - } - return sz_after - sz_before; - } }; leanlua_state::leanlua_state(): @@ -155,10 +134,6 @@ void leanlua_state::dostring(char const * str, environment & env) { m_ptr->dostring(str, env); } -int leanlua_state::dostring(char const * str, lua_State * src, int first, int last) { - return m_ptr->dostring(str, src, first, last); -} - constexpr char const * state_mt = "state.mt"; bool is_state(lua_State * L, int idx) { @@ -187,8 +162,37 @@ static int state_gc(lua_State * L) { return 0; } -static int state_dostring(lua_State * L) { - return to_state(L, 1).dostring(luaL_checkstring(L, 2), L, 3, lua_gettop(L)); +int state_dostring(lua_State * L) { + auto S = to_state(L, 1).m_ptr; + char const * script = luaL_checkstring(L, 2); + int first = 3; + int last = lua_gettop(L); + std::lock_guard lock(S->m_mutex); + + int sz_before = lua_gettop(S->m_state); + + int result = luaL_loadstring(S->m_state, script); + if (result) + throw lua_exception(lua_tostring(S->m_state, -1)); + + copy_values(L, first, last, S->m_state); + + result = lua_pcall(S->m_state, first > last ? 0 : last - first + 1, LUA_MULTRET, 0); + if (result) + throw lua_exception(lua_tostring(S->m_state, -1)); + + int sz_after = lua_gettop(S->m_state); + + if (sz_after > sz_before) { + copy_values(S->m_state, sz_before + 1, sz_after, L); + lua_pop(S->m_state, sz_after - sz_before); + } + return sz_after - sz_before; +} + +static int state_pred(lua_State * L) { + lua_pushboolean(L, is_state(L, 1)); + return 1; } static const struct luaL_Reg state_m[] = { @@ -204,5 +208,112 @@ static void open_state(lua_State * L) { setfuncs(L, state_m, 0); set_global_function(L, "State"); + set_global_function(L, "is_State"); +} + +class leanlua_thread { + leanlua_state m_state; + int m_sz_before; + bool m_error; + std::string m_error_msg; + std::thread m_thread; +public: + leanlua_thread(leanlua_state const & st, int sz_before, int num_args): + m_state(st), + m_sz_before(sz_before), + m_error(false), + m_thread([=]() { + auto S = m_state.m_ptr; + std::lock_guard lock(S->m_mutex); + int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0); + if (result) { + m_error = true; + m_error_msg = lua_tostring(S->m_state, -1); + return; + } + }) { + } + + ~leanlua_thread() { + if (m_thread.joinable()) + m_thread.join(); + } + + int wait(lua_State * src) { + m_thread.join(); + if (m_error) + throw lua_exception(m_error_msg.c_str()); + auto S = m_state.m_ptr; + int sz_after = lua_gettop(S->m_state); + + if (sz_after > m_sz_before) { + copy_values(S->m_state, m_sz_before + 1, sz_after, src); + lua_pop(S->m_state, sz_after - m_sz_before); + } + return sz_after - m_sz_before; + } +}; + +constexpr char const * thread_mt = "thread.mt"; + +bool is_thread(lua_State * L, int idx) { + return testudata(L, idx, thread_mt); +} + +leanlua_thread & to_thread(lua_State * L, int idx) { + return *static_cast(luaL_checkudata(L, idx, thread_mt)); +} + +int mk_thread(lua_State * L) { + leanlua_state & st = to_state(L, 1); + char const * script = luaL_checkstring(L, 2); + int first = 3; + int last = lua_gettop(L); + int nargs = first > last ? 0 : last - first + 1; + int sz_before; + auto S = st.m_ptr; + { + std::lock_guard lock(S->m_mutex); + sz_before = lua_gettop(S->m_state); + int result = luaL_loadstring(S->m_state, script); + if (result) + throw lua_exception(lua_tostring(S->m_state, -1)); + copy_values(L, first, last, S->m_state); + } + void * mem = lua_newuserdata(L, sizeof(leanlua_thread)); + new (mem) leanlua_thread(st, sz_before, nargs); + luaL_getmetatable(L, thread_mt); + lua_setmetatable(L, -2); + return 1; +} + +static int thread_gc(lua_State * L) { + to_thread(L, 1).~leanlua_thread(); + return 0; +} + +static int thread_pred(lua_State * L) { + lua_pushboolean(L, is_thread(L, 1)); + return 1; +} + +int thread_wait(lua_State * L) { + return to_thread(L, 1).wait(L); +} + +static const struct luaL_Reg thread_m[] = { + {"__gc", thread_gc}, + {"wait", safe_function}, + {0, 0} +}; + +static void open_thread(lua_State * L) { + luaL_newmetatable(L, thread_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, thread_m, 0); + + set_global_function(L, "thread"); + set_global_function(L, "is_thread"); } } diff --git a/src/bindings/lua/leanlua_state.h b/src/bindings/lua/leanlua_state.h index 49b43690d..bac30a2b7 100644 --- a/src/bindings/lua/leanlua_state.h +++ b/src/bindings/lua/leanlua_state.h @@ -17,6 +17,10 @@ class environment; class leanlua_state { struct imp; std::shared_ptr m_ptr; + friend class leanlua_thread; + friend int state_dostring(lua_State *); + friend int mk_thread(lua_State *); + friend int thread_wait(lua_State *); public: leanlua_state(); ~leanlua_state(); @@ -38,13 +42,5 @@ public: The script \c str should not store a reference to the environment \c env. */ void dostring(char const * str, environment & env); - - /** - \brief Execute the given script, but copy the values at positions [first, last] from the stack of \c src. - The values are passed as arguments to the script \c str. - The values returned by the script \c str are copied back to the stack of \c src. - The result is the number of values returned by the script \c str. - */ - int dostring(char const * str, lua_State * src, int first, int last); }; } diff --git a/tests/lua/threads/th1.lua b/tests/lua/threads/th1.lua new file mode 100644 index 000000000..758fdaa64 --- /dev/null +++ b/tests/lua/threads/th1.lua @@ -0,0 +1,12 @@ +f = Const("f") +a = Const("a") + +S = State() +T = thread(S, [[ + t = ... + g = Const("g") + return g(t) +]], f(a)) + +r = T:wait() +print(r)