feat(lua): communication channels for threads

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-11-14 21:10:46 -08:00
parent bd1e9c7548
commit b5dcb93550
3 changed files with 193 additions and 10 deletions

View file

@ -340,18 +340,123 @@ static void open_state(lua_State * L) {
SET_GLOBAL_FUN(state_pred, "is_State"); SET_GLOBAL_FUN(state_pred, "is_State");
} }
// TODO(Leo): allow the user to change it?
#define SMALL_DELAY 10 // in ms
std::chrono::milliseconds g_small_delay(SMALL_DELAY);
/**
\brief Channel for communicating with thread objects in the Lua API
*/
class data_channel {
// We use a lua_State to implement the channel. This is quite hackish,
// but it is a convenient storage for Lua objects sent from one state to
// another.
leanlua_state m_channel;
int m_ini;
std::mutex m_mutex;
std::condition_variable m_cv;
public:
data_channel() {
lua_State * channel = m_channel.m_ptr->m_state;
m_ini = lua_gettop(channel);
}
/**
\brief Copy elements from positions [first, last] from src stack
to the channel.
*/
void write(lua_State * src, int first, int last) {
// write the object on the top of the stack of src to the table
// on m_channel.
if (last < first)
return;
std::lock_guard<std::mutex> lock(m_mutex);
lua_State * channel = m_channel.m_ptr->m_state;
bool was_empty = lua_gettop(channel) == m_ini;
copy_values(src, first, last, channel);
if (was_empty)
m_cv.notify_one();
}
/**
\brief Retrieve one element from the channel. It will block
the execution of \c tgt if the channel is empty.
*/
int read(lua_State * tgt, int i) {
std::unique_lock<std::mutex> lock(m_mutex);
lua_State * channel = m_channel.m_ptr->m_state;
if (i > 0) {
// i is the position of the timeout argument
std::chrono::milliseconds dura(luaL_checkinteger(tgt, i));
if (lua_gettop(channel) == m_ini)
m_cv.wait_for(lock, dura);
if (lua_gettop(channel) == m_ini) {
// timeout...
lua_pushboolean(tgt, false);
lua_pushnil(tgt);
return 2;
} else {
lua_pushboolean(tgt, true);
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1);
return 2;
}
} else {
while (lua_gettop(channel) == m_ini) {
check_interrupted();
m_cv.wait_for(lock, g_small_delay);
}
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1);
return 1;
}
}
};
/**
\brief We want the channels to be lazily created.
*/
class data_channel_ref {
std::unique_ptr<data_channel> m_channel;
std::mutex m_mutex;
public:
data_channel & get() {
std::lock_guard<std::mutex> lock(m_mutex);
if (!m_channel)
m_channel.reset(new data_channel());
lean_assert(m_channel);
return *m_channel;
}
};
data_channel_ref g_in_channel;
data_channel_ref g_out_channel;
int channel_read(lua_State * L) {
return g_in_channel.get().read(L, lua_gettop(L));
}
int channel_write(lua_State * L) {
g_out_channel.get().write(L, 1, lua_gettop(L));
return 0;
}
class leanlua_thread { class leanlua_thread {
leanlua_state m_state; leanlua_state m_state;
int m_sz_before; int m_sz_before;
bool m_error; bool m_error;
std::string m_error_msg; std::string m_error_msg;
interruptible_thread m_thread; interruptible_thread m_thread;
std::atomic<data_channel_ref *> m_in_channel_addr;
std::atomic<data_channel_ref *> m_out_channel_addr;
public: public:
leanlua_thread(leanlua_state const & st, int sz_before, int num_args): leanlua_thread(leanlua_state const & st, int sz_before, int num_args):
m_state(st), m_state(st),
m_sz_before(sz_before), m_sz_before(sz_before),
m_error(false), m_error(false),
m_thread([=]() { m_thread([=]() {
m_in_channel_addr.store(&g_in_channel);
m_out_channel_addr.store(&g_out_channel);
auto S = m_state.m_ptr; auto S = m_state.m_ptr;
std::lock_guard<std::mutex> lock(S->m_mutex); std::lock_guard<std::mutex> lock(S->m_mutex);
int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0); int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0);
@ -382,9 +487,34 @@ public:
return sz_after - m_sz_before; return sz_after - m_sz_before;
} }
int request_interrupt(lua_State * src) { void request_interrupt() {
lua_pushboolean(src, m_thread.request_interrupt()); while (!m_thread.request_interrupt()) {
return 1; check_interrupted();
std::this_thread::sleep_for(g_small_delay);
}
}
void write(lua_State * src, int first, int last) {
while (!m_in_channel_addr) {
check_interrupted();
std::this_thread::sleep_for(g_small_delay);
}
data_channel & in = m_in_channel_addr.load()->get();
in.write(src, first, last);
}
int read(lua_State * src) {
if (!m_out_channel_addr) {
check_interrupted();
std::this_thread::sleep_for(g_small_delay);
}
data_channel & out = m_out_channel_addr.load()->get();
int nargs = lua_gettop(src);
return out.read(src, nargs == 1 ? 0 : 2);
}
bool started() {
return m_in_channel_addr && m_out_channel_addr;
} }
}; };
@ -432,10 +562,21 @@ static int thread_pred(lua_State * L) {
return 1; return 1;
} }
static int thread_interrupt(lua_State * L) { static int thread_write(lua_State * L) {
return to_thread(L, 1).request_interrupt(L); to_thread(L, 1).write(L, 2, lua_gettop(L));
return 0;
} }
static int thread_read(lua_State * L) {
return to_thread(L, 1).read(L);
}
static int thread_interrupt(lua_State * L) {
to_thread(L, 1).request_interrupt();
return 0;
}
int thread_wait(lua_State * L) { int thread_wait(lua_State * L) {
return to_thread(L, 1).wait(L); return to_thread(L, 1).wait(L);
} }
@ -444,6 +585,8 @@ static const struct luaL_Reg thread_m[] = {
{"__gc", thread_gc}, {"__gc", thread_gc},
{"wait", safe_function<thread_wait>}, {"wait", safe_function<thread_wait>},
{"interrupt", safe_function<thread_interrupt>}, {"interrupt", safe_function<thread_interrupt>},
{"write", safe_function<thread_write>},
{"read", safe_function<thread_read>},
{0, 0} {0, 0}
}; };
@ -471,5 +614,7 @@ static int sleep(lua_State * L) {
static void open_interrupt(lua_State * L) { static void open_interrupt(lua_State * L) {
SET_GLOBAL_FUN(check_interrupted, "check_interrupted"); SET_GLOBAL_FUN(check_interrupted, "check_interrupted");
SET_GLOBAL_FUN(sleep, "sleep"); SET_GLOBAL_FUN(sleep, "sleep");
SET_GLOBAL_FUN(channel_read, "read");
SET_GLOBAL_FUN(channel_write, "write");
} }
} }

View file

@ -19,6 +19,7 @@ class leanlua_state {
struct imp; struct imp;
std::shared_ptr<imp> m_ptr; std::shared_ptr<imp> m_ptr;
friend class leanlua_thread; friend class leanlua_thread;
friend class data_channel;
friend int state_dostring(lua_State * L); friend int state_dostring(lua_State * L);
friend int state_set_global(lua_State * L); friend int state_set_global(lua_State * L);
friend int mk_thread(lua_State * L); friend int mk_thread(lua_State * L);

37
tests/lua/th3.lua Normal file
View file

@ -0,0 +1,37 @@
S = State()
T = thread(S, [[
print("starting thread...")
pcall(function()
while true do
check_interrupted() -- check if thread was interrupted
local ok, val = read(10) -- 10 ms timeout
if ok then
print("thread received:", val)
write(val + 10, val - 10) -- send result back to main thread
end
end
end)
]])
for i = 1, 10 do
T:write(10 * i)
local r1 = T:read()
local r2 = T:read()
print("main received: ", r1, r2)
end
T:interrupt()
print("done")
-- Channels are quite flexible, we can send closure over the channel
T = thread(S, [[
for i = 1, 10 do
local f = read()
-- send back the result of f(i)
write(f(i))
end
]])
for i = 1, 10 do
T:write(function (x) return x + i end)
print(T:read())
end