diff --git a/src/frontends/lua/leanlua_state.cpp b/src/frontends/lua/leanlua_state.cpp index 4332cd474..60cdbd836 100644 --- a/src/frontends/lua/leanlua_state.cpp +++ b/src/frontends/lua/leanlua_state.cpp @@ -129,6 +129,15 @@ void leanlua_state::dostring(char const * str, environment & env, io_state & st) m_ptr->dostring(str, env, st); } +std::recursive_mutex & leanlua_state::get_mutex() { + return m_ptr->m_mutex; +} + +lua_State * leanlua_state::get_state() { + return m_ptr->m_state; +} + + constexpr char const * state_mt = "luastate.mt"; bool is_state(lua_State * L, int idx) { @@ -248,37 +257,35 @@ static void copy_values(lua_State * src, int first, int last, lua_State * tgt) { } int state_dostring(lua_State * L) { - auto S = to_state(L, 1).m_ptr; - char const * script = luaL_checkstring(L, 2); - int first = 3; - int last = lua_gettop(L); - std::lock_guard lock(S->m_mutex); + return to_state(L, 1).apply([&](lua_State * S) { + char const * script = luaL_checkstring(L, 2); + int first = 3; + int last = lua_gettop(L); + int sz_before = lua_gettop(S); + int status = luaL_loadstring(S, script); + if (status) + throw lua_exception(lua_tostring(S, -1)); - int sz_before = lua_gettop(S->m_state); + copy_values(L, first, last, S); - int result = luaL_loadstring(S->m_state, script); - if (result) - throw lua_exception(lua_tostring(S->m_state, -1)); + pcall(S, first > last ? 0 : last - first + 1, LUA_MULTRET, 0); - copy_values(L, first, last, S->m_state); + int sz_after = lua_gettop(S); - pcall(S->m_state, first > last ? 0 : last - first + 1, LUA_MULTRET, 0); - - int sz_after = lua_gettop(S->m_state); - - if (sz_after > sz_before) { - copy_values(S->m_state, sz_before + 1, sz_after, L); - lua_pop(S->m_state, sz_after - sz_before); - } - return sz_after - sz_before; + if (sz_after > sz_before) { + copy_values(S, sz_before + 1, sz_after, L); + lua_pop(S, sz_after - sz_before); + } + return sz_after - sz_before; + }); } int state_set_global(lua_State * L) { - auto S = to_state(L, 1).m_ptr; - char const * name = luaL_checkstring(L, 2); - std::lock_guard lock(S->m_mutex); - copy_values(L, 3, 3, S->m_state); - lua_setglobal(S->m_state, name); + to_state(L, 1).apply([=](lua_State * S) { + char const * name = luaL_checkstring(L, 2); + copy_values(L, 3, 3, S); + lua_setglobal(S, name); + }); return 0; } @@ -322,8 +329,9 @@ class data_channel { std::condition_variable m_cv; public: data_channel() { - lua_State * channel = m_channel.m_ptr->m_state; - m_ini = lua_gettop(channel); + m_channel.unguarded_apply([&](lua_State * channel) { + m_ini = lua_gettop(channel); + }); } /** @@ -336,11 +344,12 @@ public: if (last < first) return; std::lock_guard lock(m_mutex); - lua_State * channel = m_channel.m_ptr->m_state; - bool was_empty = lua_gettop(channel) == m_ini; - copy_values(src, first, last, channel); - if (was_empty) - m_cv.notify_one(); + m_channel.unguarded_apply([&](lua_State * channel) { + bool was_empty = lua_gettop(channel) == m_ini; + copy_values(src, first, last, channel); + if (was_empty) + m_cv.notify_one(); + }); } /** @@ -349,32 +358,33 @@ public: */ int read(lua_State * tgt, int i) { std::unique_lock lock(m_mutex); - lua_State * channel = m_channel.m_ptr->m_state; - if (i > 0) { - // i is the position of the timeout argument - std::chrono::milliseconds dura(luaL_checkinteger(tgt, i)); - if (lua_gettop(channel) == m_ini) - m_cv.wait_for(lock, dura); - if (lua_gettop(channel) == m_ini) { - // timeout... - lua_pushboolean(tgt, false); - lua_pushnil(tgt); - return 2; - } else { - lua_pushboolean(tgt, true); - copy_values(channel, m_ini + 1, m_ini + 1, tgt); - lua_remove(channel, m_ini + 1); - return 2; - } - } else { - while (lua_gettop(channel) == m_ini) { - check_interrupted(); - m_cv.wait_for(lock, g_small_delay); - } - copy_values(channel, m_ini + 1, m_ini + 1, tgt); - lua_remove(channel, m_ini + 1); - return 1; - } + return m_channel.unguarded_apply([&](lua_State * channel) { + if (i > 0) { + // i is the position of the timeout argument + std::chrono::milliseconds dura(luaL_checkinteger(tgt, i)); + if (lua_gettop(channel) == m_ini) + m_cv.wait_for(lock, dura); + if (lua_gettop(channel) == m_ini) { + // timeout... + lua_pushboolean(tgt, false); + lua_pushnil(tgt); + return 2; + } else { + lua_pushboolean(tgt, true); + copy_values(channel, m_ini + 1, m_ini + 1, tgt); + lua_remove(channel, m_ini + 1); + return 2; + } + } else { + while (lua_gettop(channel) == m_ini) { + check_interrupted(); + m_cv.wait_for(lock, g_small_delay); + } + copy_values(channel, m_ini + 1, m_ini + 1, tgt); + lua_remove(channel, m_ini + 1); + return 1; + } + }); } }; @@ -424,14 +434,13 @@ public: m_thread([=]() { m_in_channel_addr.store(&g_in_channel); m_out_channel_addr.store(&g_out_channel); - auto S = m_state.m_ptr; - std::lock_guard lock(S->m_mutex); - int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0); - if (result) { - m_error = true; - m_error_msg = lua_tostring(S->m_state, -1); - return; - } + m_state.apply([&](lua_State * S) { + int result = lua_pcall(S, num_args, LUA_MULTRET, 0); + if (result) { + m_error = true; + m_error_msg = lua_tostring(S, -1); + } + }); }) { } @@ -444,14 +453,14 @@ public: m_thread.join(); if (m_error) throw lua_exception(m_error_msg.c_str()); - auto S = m_state.m_ptr; - int sz_after = lua_gettop(S->m_state); - - if (sz_after > m_sz_before) { - copy_values(S->m_state, m_sz_before + 1, sz_after, src); - lua_pop(S->m_state, sz_after - m_sz_before); - } - return sz_after - m_sz_before; + return m_state.apply([&](lua_State * S) { + int sz_after = lua_gettop(S); + if (sz_after > m_sz_before) { + copy_values(S, m_sz_before + 1, sz_after, src); + lua_pop(S, sz_after - m_sz_before); + } + return sz_after - m_sz_before; + }); } void request_interrupt() { @@ -500,15 +509,13 @@ int mk_thread(lua_State * L) { int last = lua_gettop(L); int nargs = first > last ? 0 : last - first + 1; int sz_before; - auto S = st.m_ptr; - { - std::lock_guard lock(S->m_mutex); - sz_before = lua_gettop(S->m_state); - int result = luaL_loadstring(S->m_state, script); - if (result) - throw lua_exception(lua_tostring(S->m_state, -1)); - copy_values(L, first, last, S->m_state); - } + st.apply([&](lua_State * S) { + sz_before = lua_gettop(S); + int result = luaL_loadstring(S, script); + if (result) + throw lua_exception(lua_tostring(S, -1)); + copy_values(L, first, last, S); + }); void * mem = lua_newuserdata(L, sizeof(leanlua_thread)); new (mem) leanlua_thread(st, sz_before, nargs); luaL_getmetatable(L, thread_mt); diff --git a/src/frontends/lua/leanlua_state.h b/src/frontends/lua/leanlua_state.h index a98c9ccb1..7c877ba22 100644 --- a/src/frontends/lua/leanlua_state.h +++ b/src/frontends/lua/leanlua_state.h @@ -6,6 +6,7 @@ Author: Leonardo de Moura */ #pragma once #include +#include #include #include "util/lua_exception.h" #include "library/script_evaluator.h" @@ -22,13 +23,9 @@ public: private: std::shared_ptr m_ptr; leanlua_state(std::weak_ptr const & ptr); - friend class leanlua_thread; - friend class data_channel; - friend int state_dostring(lua_State * L); - friend int state_set_global(lua_State * L); - friend int mk_thread(lua_State * L); - friend int thread_wait(lua_State * L); friend leanlua_state to_leanlua_state(lua_State * L); + std::recursive_mutex & get_mutex(); + lua_State * get_state(); public: leanlua_state(); virtual ~leanlua_state(); @@ -50,6 +47,27 @@ public: The script \c str should not store a reference to the environment \c env. */ virtual void dostring(char const * str, environment & env, io_state & st); + + /** + \brief Execute \c f in the using the internal Lua State. + */ + template + typename std::result_of::type apply(F && f) { + std::lock_guard lock(get_mutex()); + return f(get_state()); + } + + /** + \brief Similar to \c apply, but a lock is not used to guarantee + exclusive access to the lua_State object. + + \warning It is the caller resposability to guarantee that the object is not being + concurrently accessed. + */ + template + typename std::result_of::type unguarded_apply(F && f) { + return f(get_state()); + } }; /** \brief Return a reference to the leanlua_state object that is wrapping \c L.