refactor(frontends/lua/leanlua_state): minimize the use of 'friend' directive

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-11-27 09:25:56 -08:00
parent 4c5ec53a44
commit 4c323093ac
2 changed files with 113 additions and 88 deletions

View file

@ -129,6 +129,15 @@ void leanlua_state::dostring(char const * str, environment & env, io_state & st)
m_ptr->dostring(str, env, st); m_ptr->dostring(str, env, st);
} }
std::recursive_mutex & leanlua_state::get_mutex() {
return m_ptr->m_mutex;
}
lua_State * leanlua_state::get_state() {
return m_ptr->m_state;
}
constexpr char const * state_mt = "luastate.mt"; constexpr char const * state_mt = "luastate.mt";
bool is_state(lua_State * L, int idx) { bool is_state(lua_State * L, int idx) {
@ -248,37 +257,35 @@ static void copy_values(lua_State * src, int first, int last, lua_State * tgt) {
} }
int state_dostring(lua_State * L) { int state_dostring(lua_State * L) {
auto S = to_state(L, 1).m_ptr; return to_state(L, 1).apply([&](lua_State * S) {
char const * script = luaL_checkstring(L, 2); char const * script = luaL_checkstring(L, 2);
int first = 3; int first = 3;
int last = lua_gettop(L); int last = lua_gettop(L);
std::lock_guard<std::recursive_mutex> lock(S->m_mutex); int sz_before = lua_gettop(S);
int status = luaL_loadstring(S, script);
if (status)
throw lua_exception(lua_tostring(S, -1));
int sz_before = lua_gettop(S->m_state); copy_values(L, first, last, S);
int result = luaL_loadstring(S->m_state, script); pcall(S, first > last ? 0 : last - first + 1, LUA_MULTRET, 0);
if (result)
throw lua_exception(lua_tostring(S->m_state, -1));
copy_values(L, first, last, S->m_state); int sz_after = lua_gettop(S);
pcall(S->m_state, first > last ? 0 : last - first + 1, LUA_MULTRET, 0); if (sz_after > sz_before) {
copy_values(S, sz_before + 1, sz_after, L);
int sz_after = lua_gettop(S->m_state); lua_pop(S, sz_after - sz_before);
}
if (sz_after > sz_before) { return sz_after - sz_before;
copy_values(S->m_state, sz_before + 1, sz_after, L); });
lua_pop(S->m_state, sz_after - sz_before);
}
return sz_after - sz_before;
} }
int state_set_global(lua_State * L) { int state_set_global(lua_State * L) {
auto S = to_state(L, 1).m_ptr; to_state(L, 1).apply([=](lua_State * S) {
char const * name = luaL_checkstring(L, 2); char const * name = luaL_checkstring(L, 2);
std::lock_guard<std::recursive_mutex> lock(S->m_mutex); copy_values(L, 3, 3, S);
copy_values(L, 3, 3, S->m_state); lua_setglobal(S, name);
lua_setglobal(S->m_state, name); });
return 0; return 0;
} }
@ -322,8 +329,9 @@ class data_channel {
std::condition_variable m_cv; std::condition_variable m_cv;
public: public:
data_channel() { data_channel() {
lua_State * channel = m_channel.m_ptr->m_state; m_channel.unguarded_apply([&](lua_State * channel) {
m_ini = lua_gettop(channel); m_ini = lua_gettop(channel);
});
} }
/** /**
@ -336,11 +344,12 @@ public:
if (last < first) if (last < first)
return; return;
std::lock_guard<std::mutex> lock(m_mutex); std::lock_guard<std::mutex> lock(m_mutex);
lua_State * channel = m_channel.m_ptr->m_state; m_channel.unguarded_apply([&](lua_State * channel) {
bool was_empty = lua_gettop(channel) == m_ini; bool was_empty = lua_gettop(channel) == m_ini;
copy_values(src, first, last, channel); copy_values(src, first, last, channel);
if (was_empty) if (was_empty)
m_cv.notify_one(); m_cv.notify_one();
});
} }
/** /**
@ -349,32 +358,33 @@ public:
*/ */
int read(lua_State * tgt, int i) { int read(lua_State * tgt, int i) {
std::unique_lock<std::mutex> lock(m_mutex); std::unique_lock<std::mutex> lock(m_mutex);
lua_State * channel = m_channel.m_ptr->m_state; return m_channel.unguarded_apply([&](lua_State * channel) {
if (i > 0) { if (i > 0) {
// i is the position of the timeout argument // i is the position of the timeout argument
std::chrono::milliseconds dura(luaL_checkinteger(tgt, i)); std::chrono::milliseconds dura(luaL_checkinteger(tgt, i));
if (lua_gettop(channel) == m_ini) if (lua_gettop(channel) == m_ini)
m_cv.wait_for(lock, dura); m_cv.wait_for(lock, dura);
if (lua_gettop(channel) == m_ini) { if (lua_gettop(channel) == m_ini) {
// timeout... // timeout...
lua_pushboolean(tgt, false); lua_pushboolean(tgt, false);
lua_pushnil(tgt); lua_pushnil(tgt);
return 2; return 2;
} else { } else {
lua_pushboolean(tgt, true); lua_pushboolean(tgt, true);
copy_values(channel, m_ini + 1, m_ini + 1, tgt); copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1); lua_remove(channel, m_ini + 1);
return 2; return 2;
} }
} else { } else {
while (lua_gettop(channel) == m_ini) { while (lua_gettop(channel) == m_ini) {
check_interrupted(); check_interrupted();
m_cv.wait_for(lock, g_small_delay); m_cv.wait_for(lock, g_small_delay);
} }
copy_values(channel, m_ini + 1, m_ini + 1, tgt); copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1); lua_remove(channel, m_ini + 1);
return 1; return 1;
} }
});
} }
}; };
@ -424,14 +434,13 @@ public:
m_thread([=]() { m_thread([=]() {
m_in_channel_addr.store(&g_in_channel); m_in_channel_addr.store(&g_in_channel);
m_out_channel_addr.store(&g_out_channel); m_out_channel_addr.store(&g_out_channel);
auto S = m_state.m_ptr; m_state.apply([&](lua_State * S) {
std::lock_guard<std::recursive_mutex> lock(S->m_mutex); int result = lua_pcall(S, num_args, LUA_MULTRET, 0);
int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0); if (result) {
if (result) { m_error = true;
m_error = true; m_error_msg = lua_tostring(S, -1);
m_error_msg = lua_tostring(S->m_state, -1); }
return; });
}
}) { }) {
} }
@ -444,14 +453,14 @@ public:
m_thread.join(); m_thread.join();
if (m_error) if (m_error)
throw lua_exception(m_error_msg.c_str()); throw lua_exception(m_error_msg.c_str());
auto S = m_state.m_ptr; return m_state.apply([&](lua_State * S) {
int sz_after = lua_gettop(S->m_state); int sz_after = lua_gettop(S);
if (sz_after > m_sz_before) {
if (sz_after > m_sz_before) { copy_values(S, m_sz_before + 1, sz_after, src);
copy_values(S->m_state, m_sz_before + 1, sz_after, src); lua_pop(S, sz_after - m_sz_before);
lua_pop(S->m_state, sz_after - m_sz_before); }
} return sz_after - m_sz_before;
return sz_after - m_sz_before; });
} }
void request_interrupt() { void request_interrupt() {
@ -500,15 +509,13 @@ int mk_thread(lua_State * L) {
int last = lua_gettop(L); int last = lua_gettop(L);
int nargs = first > last ? 0 : last - first + 1; int nargs = first > last ? 0 : last - first + 1;
int sz_before; int sz_before;
auto S = st.m_ptr; st.apply([&](lua_State * S) {
{ sz_before = lua_gettop(S);
std::lock_guard<std::recursive_mutex> lock(S->m_mutex); int result = luaL_loadstring(S, script);
sz_before = lua_gettop(S->m_state); if (result)
int result = luaL_loadstring(S->m_state, script); throw lua_exception(lua_tostring(S, -1));
if (result) copy_values(L, first, last, S);
throw lua_exception(lua_tostring(S->m_state, -1)); });
copy_values(L, first, last, S->m_state);
}
void * mem = lua_newuserdata(L, sizeof(leanlua_thread)); void * mem = lua_newuserdata(L, sizeof(leanlua_thread));
new (mem) leanlua_thread(st, sz_before, nargs); new (mem) leanlua_thread(st, sz_before, nargs);
luaL_getmetatable(L, thread_mt); luaL_getmetatable(L, thread_mt);

View file

@ -6,6 +6,7 @@ Author: Leonardo de Moura
*/ */
#pragma once #pragma once
#include <memory> #include <memory>
#include <mutex>
#include <lua.hpp> #include <lua.hpp>
#include "util/lua_exception.h" #include "util/lua_exception.h"
#include "library/script_evaluator.h" #include "library/script_evaluator.h"
@ -22,13 +23,9 @@ public:
private: private:
std::shared_ptr<imp> m_ptr; std::shared_ptr<imp> m_ptr;
leanlua_state(std::weak_ptr<imp> const & ptr); leanlua_state(std::weak_ptr<imp> const & ptr);
friend class leanlua_thread;
friend class data_channel;
friend int state_dostring(lua_State * L);
friend int state_set_global(lua_State * L);
friend int mk_thread(lua_State * L);
friend int thread_wait(lua_State * L);
friend leanlua_state to_leanlua_state(lua_State * L); friend leanlua_state to_leanlua_state(lua_State * L);
std::recursive_mutex & get_mutex();
lua_State * get_state();
public: public:
leanlua_state(); leanlua_state();
virtual ~leanlua_state(); virtual ~leanlua_state();
@ -50,6 +47,27 @@ public:
The script \c str should not store a reference to the environment \c env. The script \c str should not store a reference to the environment \c env.
*/ */
virtual void dostring(char const * str, environment & env, io_state & st); virtual void dostring(char const * str, environment & env, io_state & st);
/**
\brief Execute \c f in the using the internal Lua State.
*/
template<typename F>
typename std::result_of<F(lua_State * L)>::type apply(F && f) {
std::lock_guard<std::recursive_mutex> lock(get_mutex());
return f(get_state());
}
/**
\brief Similar to \c apply, but a lock is not used to guarantee
exclusive access to the lua_State object.
\warning It is the caller resposability to guarantee that the object is not being
concurrently accessed.
*/
template<typename F>
typename std::result_of<F(lua_State * L)>::type unguarded_apply(F && f) {
return f(get_state());
}
}; };
/** /**
\brief Return a reference to the leanlua_state object that is wrapping \c L. \brief Return a reference to the leanlua_state object that is wrapping \c L.