feat(lua): communication channels for threads
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
bd1e9c7548
commit
b5dcb93550
3 changed files with 193 additions and 10 deletions
|
@ -340,18 +340,123 @@ static void open_state(lua_State * L) {
|
|||
SET_GLOBAL_FUN(state_pred, "is_State");
|
||||
}
|
||||
|
||||
// TODO(Leo): allow the user to change it?
|
||||
#define SMALL_DELAY 10 // in ms
|
||||
std::chrono::milliseconds g_small_delay(SMALL_DELAY);
|
||||
|
||||
/**
|
||||
\brief Channel for communicating with thread objects in the Lua API
|
||||
*/
|
||||
class data_channel {
|
||||
// We use a lua_State to implement the channel. This is quite hackish,
|
||||
// but it is a convenient storage for Lua objects sent from one state to
|
||||
// another.
|
||||
leanlua_state m_channel;
|
||||
int m_ini;
|
||||
std::mutex m_mutex;
|
||||
std::condition_variable m_cv;
|
||||
public:
|
||||
data_channel() {
|
||||
lua_State * channel = m_channel.m_ptr->m_state;
|
||||
m_ini = lua_gettop(channel);
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Copy elements from positions [first, last] from src stack
|
||||
to the channel.
|
||||
*/
|
||||
void write(lua_State * src, int first, int last) {
|
||||
// write the object on the top of the stack of src to the table
|
||||
// on m_channel.
|
||||
if (last < first)
|
||||
return;
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
lua_State * channel = m_channel.m_ptr->m_state;
|
||||
bool was_empty = lua_gettop(channel) == m_ini;
|
||||
copy_values(src, first, last, channel);
|
||||
if (was_empty)
|
||||
m_cv.notify_one();
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Retrieve one element from the channel. It will block
|
||||
the execution of \c tgt if the channel is empty.
|
||||
*/
|
||||
int read(lua_State * tgt, int i) {
|
||||
std::unique_lock<std::mutex> lock(m_mutex);
|
||||
lua_State * channel = m_channel.m_ptr->m_state;
|
||||
if (i > 0) {
|
||||
// i is the position of the timeout argument
|
||||
std::chrono::milliseconds dura(luaL_checkinteger(tgt, i));
|
||||
if (lua_gettop(channel) == m_ini)
|
||||
m_cv.wait_for(lock, dura);
|
||||
if (lua_gettop(channel) == m_ini) {
|
||||
// timeout...
|
||||
lua_pushboolean(tgt, false);
|
||||
lua_pushnil(tgt);
|
||||
return 2;
|
||||
} else {
|
||||
lua_pushboolean(tgt, true);
|
||||
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
|
||||
lua_remove(channel, m_ini + 1);
|
||||
return 2;
|
||||
}
|
||||
} else {
|
||||
while (lua_gettop(channel) == m_ini) {
|
||||
check_interrupted();
|
||||
m_cv.wait_for(lock, g_small_delay);
|
||||
}
|
||||
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
|
||||
lua_remove(channel, m_ini + 1);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
\brief We want the channels to be lazily created.
|
||||
*/
|
||||
class data_channel_ref {
|
||||
std::unique_ptr<data_channel> m_channel;
|
||||
std::mutex m_mutex;
|
||||
public:
|
||||
data_channel & get() {
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
if (!m_channel)
|
||||
m_channel.reset(new data_channel());
|
||||
lean_assert(m_channel);
|
||||
return *m_channel;
|
||||
}
|
||||
};
|
||||
|
||||
data_channel_ref g_in_channel;
|
||||
data_channel_ref g_out_channel;
|
||||
|
||||
int channel_read(lua_State * L) {
|
||||
return g_in_channel.get().read(L, lua_gettop(L));
|
||||
}
|
||||
|
||||
int channel_write(lua_State * L) {
|
||||
g_out_channel.get().write(L, 1, lua_gettop(L));
|
||||
return 0;
|
||||
}
|
||||
|
||||
class leanlua_thread {
|
||||
leanlua_state m_state;
|
||||
int m_sz_before;
|
||||
bool m_error;
|
||||
std::string m_error_msg;
|
||||
interruptible_thread m_thread;
|
||||
std::atomic<data_channel_ref *> m_in_channel_addr;
|
||||
std::atomic<data_channel_ref *> m_out_channel_addr;
|
||||
public:
|
||||
leanlua_thread(leanlua_state const & st, int sz_before, int num_args):
|
||||
m_state(st),
|
||||
m_sz_before(sz_before),
|
||||
m_error(false),
|
||||
m_thread([=]() {
|
||||
m_in_channel_addr.store(&g_in_channel);
|
||||
m_out_channel_addr.store(&g_out_channel);
|
||||
auto S = m_state.m_ptr;
|
||||
std::lock_guard<std::mutex> lock(S->m_mutex);
|
||||
int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0);
|
||||
|
@ -382,9 +487,34 @@ public:
|
|||
return sz_after - m_sz_before;
|
||||
}
|
||||
|
||||
int request_interrupt(lua_State * src) {
|
||||
lua_pushboolean(src, m_thread.request_interrupt());
|
||||
return 1;
|
||||
void request_interrupt() {
|
||||
while (!m_thread.request_interrupt()) {
|
||||
check_interrupted();
|
||||
std::this_thread::sleep_for(g_small_delay);
|
||||
}
|
||||
}
|
||||
|
||||
void write(lua_State * src, int first, int last) {
|
||||
while (!m_in_channel_addr) {
|
||||
check_interrupted();
|
||||
std::this_thread::sleep_for(g_small_delay);
|
||||
}
|
||||
data_channel & in = m_in_channel_addr.load()->get();
|
||||
in.write(src, first, last);
|
||||
}
|
||||
|
||||
int read(lua_State * src) {
|
||||
if (!m_out_channel_addr) {
|
||||
check_interrupted();
|
||||
std::this_thread::sleep_for(g_small_delay);
|
||||
}
|
||||
data_channel & out = m_out_channel_addr.load()->get();
|
||||
int nargs = lua_gettop(src);
|
||||
return out.read(src, nargs == 1 ? 0 : 2);
|
||||
}
|
||||
|
||||
bool started() {
|
||||
return m_in_channel_addr && m_out_channel_addr;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -432,10 +562,21 @@ static int thread_pred(lua_State * L) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
static int thread_interrupt(lua_State * L) {
|
||||
return to_thread(L, 1).request_interrupt(L);
|
||||
static int thread_write(lua_State * L) {
|
||||
to_thread(L, 1).write(L, 2, lua_gettop(L));
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int thread_read(lua_State * L) {
|
||||
return to_thread(L, 1).read(L);
|
||||
}
|
||||
|
||||
static int thread_interrupt(lua_State * L) {
|
||||
to_thread(L, 1).request_interrupt();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
int thread_wait(lua_State * L) {
|
||||
return to_thread(L, 1).wait(L);
|
||||
}
|
||||
|
@ -444,6 +585,8 @@ static const struct luaL_Reg thread_m[] = {
|
|||
{"__gc", thread_gc},
|
||||
{"wait", safe_function<thread_wait>},
|
||||
{"interrupt", safe_function<thread_interrupt>},
|
||||
{"write", safe_function<thread_write>},
|
||||
{"read", safe_function<thread_read>},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
|
@ -471,5 +614,7 @@ static int sleep(lua_State * L) {
|
|||
static void open_interrupt(lua_State * L) {
|
||||
SET_GLOBAL_FUN(check_interrupted, "check_interrupted");
|
||||
SET_GLOBAL_FUN(sleep, "sleep");
|
||||
SET_GLOBAL_FUN(channel_read, "read");
|
||||
SET_GLOBAL_FUN(channel_write, "write");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ class leanlua_state {
|
|||
struct imp;
|
||||
std::shared_ptr<imp> m_ptr;
|
||||
friend class leanlua_thread;
|
||||
friend class data_channel;
|
||||
friend int state_dostring(lua_State * L);
|
||||
friend int state_set_global(lua_State * L);
|
||||
friend int mk_thread(lua_State * L);
|
||||
|
|
37
tests/lua/th3.lua
Normal file
37
tests/lua/th3.lua
Normal file
|
@ -0,0 +1,37 @@
|
|||
S = State()
|
||||
T = thread(S, [[
|
||||
print("starting thread...")
|
||||
pcall(function()
|
||||
while true do
|
||||
check_interrupted() -- check if thread was interrupted
|
||||
local ok, val = read(10) -- 10 ms timeout
|
||||
if ok then
|
||||
print("thread received:", val)
|
||||
write(val + 10, val - 10) -- send result back to main thread
|
||||
end
|
||||
end
|
||||
end)
|
||||
]])
|
||||
|
||||
for i = 1, 10 do
|
||||
T:write(10 * i)
|
||||
local r1 = T:read()
|
||||
local r2 = T:read()
|
||||
print("main received: ", r1, r2)
|
||||
end
|
||||
T:interrupt()
|
||||
print("done")
|
||||
|
||||
-- Channels are quite flexible, we can send closure over the channel
|
||||
T = thread(S, [[
|
||||
for i = 1, 10 do
|
||||
local f = read()
|
||||
-- send back the result of f(i)
|
||||
write(f(i))
|
||||
end
|
||||
]])
|
||||
|
||||
for i = 1, 10 do
|
||||
T:write(function (x) return x + i end)
|
||||
print(T:read())
|
||||
end
|
Loading…
Reference in a new issue