feat(lua): communication channels for threads

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-11-14 21:10:46 -08:00
parent bd1e9c7548
commit b5dcb93550
3 changed files with 193 additions and 10 deletions

View file

@ -340,18 +340,123 @@ static void open_state(lua_State * L) {
SET_GLOBAL_FUN(state_pred, "is_State");
}
// TODO(Leo): allow the user to change it?
#define SMALL_DELAY 10 // in ms
std::chrono::milliseconds g_small_delay(SMALL_DELAY);
/**
\brief Channel for communicating with thread objects in the Lua API
*/
class data_channel {
// We use a lua_State to implement the channel. This is quite hackish,
// but it is a convenient storage for Lua objects sent from one state to
// another.
leanlua_state m_channel;
int m_ini;
std::mutex m_mutex;
std::condition_variable m_cv;
public:
data_channel() {
lua_State * channel = m_channel.m_ptr->m_state;
m_ini = lua_gettop(channel);
}
/**
\brief Copy elements from positions [first, last] from src stack
to the channel.
*/
void write(lua_State * src, int first, int last) {
// write the object on the top of the stack of src to the table
// on m_channel.
if (last < first)
return;
std::lock_guard<std::mutex> lock(m_mutex);
lua_State * channel = m_channel.m_ptr->m_state;
bool was_empty = lua_gettop(channel) == m_ini;
copy_values(src, first, last, channel);
if (was_empty)
m_cv.notify_one();
}
/**
\brief Retrieve one element from the channel. It will block
the execution of \c tgt if the channel is empty.
*/
int read(lua_State * tgt, int i) {
std::unique_lock<std::mutex> lock(m_mutex);
lua_State * channel = m_channel.m_ptr->m_state;
if (i > 0) {
// i is the position of the timeout argument
std::chrono::milliseconds dura(luaL_checkinteger(tgt, i));
if (lua_gettop(channel) == m_ini)
m_cv.wait_for(lock, dura);
if (lua_gettop(channel) == m_ini) {
// timeout...
lua_pushboolean(tgt, false);
lua_pushnil(tgt);
return 2;
} else {
lua_pushboolean(tgt, true);
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1);
return 2;
}
} else {
while (lua_gettop(channel) == m_ini) {
check_interrupted();
m_cv.wait_for(lock, g_small_delay);
}
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1);
return 1;
}
}
};
/**
\brief We want the channels to be lazily created.
*/
class data_channel_ref {
std::unique_ptr<data_channel> m_channel;
std::mutex m_mutex;
public:
data_channel & get() {
std::lock_guard<std::mutex> lock(m_mutex);
if (!m_channel)
m_channel.reset(new data_channel());
lean_assert(m_channel);
return *m_channel;
}
};
data_channel_ref g_in_channel;
data_channel_ref g_out_channel;
int channel_read(lua_State * L) {
return g_in_channel.get().read(L, lua_gettop(L));
}
int channel_write(lua_State * L) {
g_out_channel.get().write(L, 1, lua_gettop(L));
return 0;
}
class leanlua_thread {
leanlua_state m_state;
int m_sz_before;
bool m_error;
std::string m_error_msg;
interruptible_thread m_thread;
leanlua_state m_state;
int m_sz_before;
bool m_error;
std::string m_error_msg;
interruptible_thread m_thread;
std::atomic<data_channel_ref *> m_in_channel_addr;
std::atomic<data_channel_ref *> m_out_channel_addr;
public:
leanlua_thread(leanlua_state const & st, int sz_before, int num_args):
m_state(st),
m_sz_before(sz_before),
m_error(false),
m_thread([=]() {
m_in_channel_addr.store(&g_in_channel);
m_out_channel_addr.store(&g_out_channel);
auto S = m_state.m_ptr;
std::lock_guard<std::mutex> lock(S->m_mutex);
int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0);
@ -382,9 +487,34 @@ public:
return sz_after - m_sz_before;
}
int request_interrupt(lua_State * src) {
lua_pushboolean(src, m_thread.request_interrupt());
return 1;
void request_interrupt() {
while (!m_thread.request_interrupt()) {
check_interrupted();
std::this_thread::sleep_for(g_small_delay);
}
}
void write(lua_State * src, int first, int last) {
while (!m_in_channel_addr) {
check_interrupted();
std::this_thread::sleep_for(g_small_delay);
}
data_channel & in = m_in_channel_addr.load()->get();
in.write(src, first, last);
}
int read(lua_State * src) {
if (!m_out_channel_addr) {
check_interrupted();
std::this_thread::sleep_for(g_small_delay);
}
data_channel & out = m_out_channel_addr.load()->get();
int nargs = lua_gettop(src);
return out.read(src, nargs == 1 ? 0 : 2);
}
bool started() {
return m_in_channel_addr && m_out_channel_addr;
}
};
@ -432,10 +562,21 @@ static int thread_pred(lua_State * L) {
return 1;
}
static int thread_interrupt(lua_State * L) {
return to_thread(L, 1).request_interrupt(L);
static int thread_write(lua_State * L) {
to_thread(L, 1).write(L, 2, lua_gettop(L));
return 0;
}
static int thread_read(lua_State * L) {
return to_thread(L, 1).read(L);
}
static int thread_interrupt(lua_State * L) {
to_thread(L, 1).request_interrupt();
return 0;
}
int thread_wait(lua_State * L) {
return to_thread(L, 1).wait(L);
}
@ -444,6 +585,8 @@ static const struct luaL_Reg thread_m[] = {
{"__gc", thread_gc},
{"wait", safe_function<thread_wait>},
{"interrupt", safe_function<thread_interrupt>},
{"write", safe_function<thread_write>},
{"read", safe_function<thread_read>},
{0, 0}
};
@ -471,5 +614,7 @@ static int sleep(lua_State * L) {
static void open_interrupt(lua_State * L) {
SET_GLOBAL_FUN(check_interrupted, "check_interrupted");
SET_GLOBAL_FUN(sleep, "sleep");
SET_GLOBAL_FUN(channel_read, "read");
SET_GLOBAL_FUN(channel_write, "write");
}
}

View file

@ -19,6 +19,7 @@ class leanlua_state {
struct imp;
std::shared_ptr<imp> m_ptr;
friend class leanlua_thread;
friend class data_channel;
friend int state_dostring(lua_State * L);
friend int state_set_global(lua_State * L);
friend int mk_thread(lua_State * L);

37
tests/lua/th3.lua Normal file
View file

@ -0,0 +1,37 @@
S = State()
T = thread(S, [[
print("starting thread...")
pcall(function()
while true do
check_interrupted() -- check if thread was interrupted
local ok, val = read(10) -- 10 ms timeout
if ok then
print("thread received:", val)
write(val + 10, val - 10) -- send result back to main thread
end
end
end)
]])
for i = 1, 10 do
T:write(10 * i)
local r1 = T:read()
local r2 = T:read()
print("main received: ", r1, r2)
end
T:interrupt()
print("done")
-- Channels are quite flexible, we can send closure over the channel
T = thread(S, [[
for i = 1, 10 do
local f = read()
-- send back the result of f(i)
write(f(i))
end
]])
for i = 1, 10 do
T:write(function (x) return x + i end)
print(T:read())
end