refactor(frontends/lua/leanlua_state): minimize the use of 'friend' directive

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-11-27 09:25:56 -08:00
parent 4c5ec53a44
commit 4c323093ac
2 changed files with 113 additions and 88 deletions

View file

@ -129,6 +129,15 @@ void leanlua_state::dostring(char const * str, environment & env, io_state & st)
m_ptr->dostring(str, env, st);
}
std::recursive_mutex & leanlua_state::get_mutex() {
return m_ptr->m_mutex;
}
lua_State * leanlua_state::get_state() {
return m_ptr->m_state;
}
constexpr char const * state_mt = "luastate.mt";
bool is_state(lua_State * L, int idx) {
@ -248,37 +257,35 @@ static void copy_values(lua_State * src, int first, int last, lua_State * tgt) {
}
int state_dostring(lua_State * L) {
auto S = to_state(L, 1).m_ptr;
char const * script = luaL_checkstring(L, 2);
int first = 3;
int last = lua_gettop(L);
std::lock_guard<std::recursive_mutex> lock(S->m_mutex);
return to_state(L, 1).apply([&](lua_State * S) {
char const * script = luaL_checkstring(L, 2);
int first = 3;
int last = lua_gettop(L);
int sz_before = lua_gettop(S);
int status = luaL_loadstring(S, script);
if (status)
throw lua_exception(lua_tostring(S, -1));
int sz_before = lua_gettop(S->m_state);
copy_values(L, first, last, S);
int result = luaL_loadstring(S->m_state, script);
if (result)
throw lua_exception(lua_tostring(S->m_state, -1));
pcall(S, first > last ? 0 : last - first + 1, LUA_MULTRET, 0);
copy_values(L, first, last, S->m_state);
int sz_after = lua_gettop(S);
pcall(S->m_state, first > last ? 0 : last - first + 1, LUA_MULTRET, 0);
int sz_after = lua_gettop(S->m_state);
if (sz_after > sz_before) {
copy_values(S->m_state, sz_before + 1, sz_after, L);
lua_pop(S->m_state, sz_after - sz_before);
}
return sz_after - sz_before;
if (sz_after > sz_before) {
copy_values(S, sz_before + 1, sz_after, L);
lua_pop(S, sz_after - sz_before);
}
return sz_after - sz_before;
});
}
int state_set_global(lua_State * L) {
auto S = to_state(L, 1).m_ptr;
char const * name = luaL_checkstring(L, 2);
std::lock_guard<std::recursive_mutex> lock(S->m_mutex);
copy_values(L, 3, 3, S->m_state);
lua_setglobal(S->m_state, name);
to_state(L, 1).apply([=](lua_State * S) {
char const * name = luaL_checkstring(L, 2);
copy_values(L, 3, 3, S);
lua_setglobal(S, name);
});
return 0;
}
@ -322,8 +329,9 @@ class data_channel {
std::condition_variable m_cv;
public:
data_channel() {
lua_State * channel = m_channel.m_ptr->m_state;
m_ini = lua_gettop(channel);
m_channel.unguarded_apply([&](lua_State * channel) {
m_ini = lua_gettop(channel);
});
}
/**
@ -336,11 +344,12 @@ public:
if (last < first)
return;
std::lock_guard<std::mutex> lock(m_mutex);
lua_State * channel = m_channel.m_ptr->m_state;
bool was_empty = lua_gettop(channel) == m_ini;
copy_values(src, first, last, channel);
if (was_empty)
m_cv.notify_one();
m_channel.unguarded_apply([&](lua_State * channel) {
bool was_empty = lua_gettop(channel) == m_ini;
copy_values(src, first, last, channel);
if (was_empty)
m_cv.notify_one();
});
}
/**
@ -349,32 +358,33 @@ public:
*/
int read(lua_State * tgt, int i) {
std::unique_lock<std::mutex> lock(m_mutex);
lua_State * channel = m_channel.m_ptr->m_state;
if (i > 0) {
// i is the position of the timeout argument
std::chrono::milliseconds dura(luaL_checkinteger(tgt, i));
if (lua_gettop(channel) == m_ini)
m_cv.wait_for(lock, dura);
if (lua_gettop(channel) == m_ini) {
// timeout...
lua_pushboolean(tgt, false);
lua_pushnil(tgt);
return 2;
} else {
lua_pushboolean(tgt, true);
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1);
return 2;
}
} else {
while (lua_gettop(channel) == m_ini) {
check_interrupted();
m_cv.wait_for(lock, g_small_delay);
}
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1);
return 1;
}
return m_channel.unguarded_apply([&](lua_State * channel) {
if (i > 0) {
// i is the position of the timeout argument
std::chrono::milliseconds dura(luaL_checkinteger(tgt, i));
if (lua_gettop(channel) == m_ini)
m_cv.wait_for(lock, dura);
if (lua_gettop(channel) == m_ini) {
// timeout...
lua_pushboolean(tgt, false);
lua_pushnil(tgt);
return 2;
} else {
lua_pushboolean(tgt, true);
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1);
return 2;
}
} else {
while (lua_gettop(channel) == m_ini) {
check_interrupted();
m_cv.wait_for(lock, g_small_delay);
}
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1);
return 1;
}
});
}
};
@ -424,14 +434,13 @@ public:
m_thread([=]() {
m_in_channel_addr.store(&g_in_channel);
m_out_channel_addr.store(&g_out_channel);
auto S = m_state.m_ptr;
std::lock_guard<std::recursive_mutex> lock(S->m_mutex);
int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0);
if (result) {
m_error = true;
m_error_msg = lua_tostring(S->m_state, -1);
return;
}
m_state.apply([&](lua_State * S) {
int result = lua_pcall(S, num_args, LUA_MULTRET, 0);
if (result) {
m_error = true;
m_error_msg = lua_tostring(S, -1);
}
});
}) {
}
@ -444,14 +453,14 @@ public:
m_thread.join();
if (m_error)
throw lua_exception(m_error_msg.c_str());
auto S = m_state.m_ptr;
int sz_after = lua_gettop(S->m_state);
if (sz_after > m_sz_before) {
copy_values(S->m_state, m_sz_before + 1, sz_after, src);
lua_pop(S->m_state, sz_after - m_sz_before);
}
return sz_after - m_sz_before;
return m_state.apply([&](lua_State * S) {
int sz_after = lua_gettop(S);
if (sz_after > m_sz_before) {
copy_values(S, m_sz_before + 1, sz_after, src);
lua_pop(S, sz_after - m_sz_before);
}
return sz_after - m_sz_before;
});
}
void request_interrupt() {
@ -500,15 +509,13 @@ int mk_thread(lua_State * L) {
int last = lua_gettop(L);
int nargs = first > last ? 0 : last - first + 1;
int sz_before;
auto S = st.m_ptr;
{
std::lock_guard<std::recursive_mutex> lock(S->m_mutex);
sz_before = lua_gettop(S->m_state);
int result = luaL_loadstring(S->m_state, script);
if (result)
throw lua_exception(lua_tostring(S->m_state, -1));
copy_values(L, first, last, S->m_state);
}
st.apply([&](lua_State * S) {
sz_before = lua_gettop(S);
int result = luaL_loadstring(S, script);
if (result)
throw lua_exception(lua_tostring(S, -1));
copy_values(L, first, last, S);
});
void * mem = lua_newuserdata(L, sizeof(leanlua_thread));
new (mem) leanlua_thread(st, sz_before, nargs);
luaL_getmetatable(L, thread_mt);

View file

@ -6,6 +6,7 @@ Author: Leonardo de Moura
*/
#pragma once
#include <memory>
#include <mutex>
#include <lua.hpp>
#include "util/lua_exception.h"
#include "library/script_evaluator.h"
@ -22,13 +23,9 @@ public:
private:
std::shared_ptr<imp> m_ptr;
leanlua_state(std::weak_ptr<imp> const & ptr);
friend class leanlua_thread;
friend class data_channel;
friend int state_dostring(lua_State * L);
friend int state_set_global(lua_State * L);
friend int mk_thread(lua_State * L);
friend int thread_wait(lua_State * L);
friend leanlua_state to_leanlua_state(lua_State * L);
std::recursive_mutex & get_mutex();
lua_State * get_state();
public:
leanlua_state();
virtual ~leanlua_state();
@ -50,6 +47,27 @@ public:
The script \c str should not store a reference to the environment \c env.
*/
virtual void dostring(char const * str, environment & env, io_state & st);
/**
\brief Execute \c f in the using the internal Lua State.
*/
template<typename F>
typename std::result_of<F(lua_State * L)>::type apply(F && f) {
std::lock_guard<std::recursive_mutex> lock(get_mutex());
return f(get_state());
}
/**
\brief Similar to \c apply, but a lock is not used to guarantee
exclusive access to the lua_State object.
\warning It is the caller resposability to guarantee that the object is not being
concurrently accessed.
*/
template<typename F>
typename std::result_of<F(lua_State * L)>::type unguarded_apply(F && f) {
return f(get_state());
}
};
/**
\brief Return a reference to the leanlua_state object that is wrapping \c L.