feat(lua): communication channels for threads
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
bd1e9c7548
commit
b5dcb93550
3 changed files with 193 additions and 10 deletions
|
@ -340,18 +340,123 @@ static void open_state(lua_State * L) {
|
||||||
SET_GLOBAL_FUN(state_pred, "is_State");
|
SET_GLOBAL_FUN(state_pred, "is_State");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(Leo): allow the user to change it?
|
||||||
|
#define SMALL_DELAY 10 // in ms
|
||||||
|
std::chrono::milliseconds g_small_delay(SMALL_DELAY);
|
||||||
|
|
||||||
|
/**
|
||||||
|
\brief Channel for communicating with thread objects in the Lua API
|
||||||
|
*/
|
||||||
|
class data_channel {
|
||||||
|
// We use a lua_State to implement the channel. This is quite hackish,
|
||||||
|
// but it is a convenient storage for Lua objects sent from one state to
|
||||||
|
// another.
|
||||||
|
leanlua_state m_channel;
|
||||||
|
int m_ini;
|
||||||
|
std::mutex m_mutex;
|
||||||
|
std::condition_variable m_cv;
|
||||||
|
public:
|
||||||
|
data_channel() {
|
||||||
|
lua_State * channel = m_channel.m_ptr->m_state;
|
||||||
|
m_ini = lua_gettop(channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
\brief Copy elements from positions [first, last] from src stack
|
||||||
|
to the channel.
|
||||||
|
*/
|
||||||
|
void write(lua_State * src, int first, int last) {
|
||||||
|
// write the object on the top of the stack of src to the table
|
||||||
|
// on m_channel.
|
||||||
|
if (last < first)
|
||||||
|
return;
|
||||||
|
std::lock_guard<std::mutex> lock(m_mutex);
|
||||||
|
lua_State * channel = m_channel.m_ptr->m_state;
|
||||||
|
bool was_empty = lua_gettop(channel) == m_ini;
|
||||||
|
copy_values(src, first, last, channel);
|
||||||
|
if (was_empty)
|
||||||
|
m_cv.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
\brief Retrieve one element from the channel. It will block
|
||||||
|
the execution of \c tgt if the channel is empty.
|
||||||
|
*/
|
||||||
|
int read(lua_State * tgt, int i) {
|
||||||
|
std::unique_lock<std::mutex> lock(m_mutex);
|
||||||
|
lua_State * channel = m_channel.m_ptr->m_state;
|
||||||
|
if (i > 0) {
|
||||||
|
// i is the position of the timeout argument
|
||||||
|
std::chrono::milliseconds dura(luaL_checkinteger(tgt, i));
|
||||||
|
if (lua_gettop(channel) == m_ini)
|
||||||
|
m_cv.wait_for(lock, dura);
|
||||||
|
if (lua_gettop(channel) == m_ini) {
|
||||||
|
// timeout...
|
||||||
|
lua_pushboolean(tgt, false);
|
||||||
|
lua_pushnil(tgt);
|
||||||
|
return 2;
|
||||||
|
} else {
|
||||||
|
lua_pushboolean(tgt, true);
|
||||||
|
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
|
||||||
|
lua_remove(channel, m_ini + 1);
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
while (lua_gettop(channel) == m_ini) {
|
||||||
|
check_interrupted();
|
||||||
|
m_cv.wait_for(lock, g_small_delay);
|
||||||
|
}
|
||||||
|
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
|
||||||
|
lua_remove(channel, m_ini + 1);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
\brief We want the channels to be lazily created.
|
||||||
|
*/
|
||||||
|
class data_channel_ref {
|
||||||
|
std::unique_ptr<data_channel> m_channel;
|
||||||
|
std::mutex m_mutex;
|
||||||
|
public:
|
||||||
|
data_channel & get() {
|
||||||
|
std::lock_guard<std::mutex> lock(m_mutex);
|
||||||
|
if (!m_channel)
|
||||||
|
m_channel.reset(new data_channel());
|
||||||
|
lean_assert(m_channel);
|
||||||
|
return *m_channel;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
data_channel_ref g_in_channel;
|
||||||
|
data_channel_ref g_out_channel;
|
||||||
|
|
||||||
|
int channel_read(lua_State * L) {
|
||||||
|
return g_in_channel.get().read(L, lua_gettop(L));
|
||||||
|
}
|
||||||
|
|
||||||
|
int channel_write(lua_State * L) {
|
||||||
|
g_out_channel.get().write(L, 1, lua_gettop(L));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
class leanlua_thread {
|
class leanlua_thread {
|
||||||
leanlua_state m_state;
|
leanlua_state m_state;
|
||||||
int m_sz_before;
|
int m_sz_before;
|
||||||
bool m_error;
|
bool m_error;
|
||||||
std::string m_error_msg;
|
std::string m_error_msg;
|
||||||
interruptible_thread m_thread;
|
interruptible_thread m_thread;
|
||||||
|
std::atomic<data_channel_ref *> m_in_channel_addr;
|
||||||
|
std::atomic<data_channel_ref *> m_out_channel_addr;
|
||||||
public:
|
public:
|
||||||
leanlua_thread(leanlua_state const & st, int sz_before, int num_args):
|
leanlua_thread(leanlua_state const & st, int sz_before, int num_args):
|
||||||
m_state(st),
|
m_state(st),
|
||||||
m_sz_before(sz_before),
|
m_sz_before(sz_before),
|
||||||
m_error(false),
|
m_error(false),
|
||||||
m_thread([=]() {
|
m_thread([=]() {
|
||||||
|
m_in_channel_addr.store(&g_in_channel);
|
||||||
|
m_out_channel_addr.store(&g_out_channel);
|
||||||
auto S = m_state.m_ptr;
|
auto S = m_state.m_ptr;
|
||||||
std::lock_guard<std::mutex> lock(S->m_mutex);
|
std::lock_guard<std::mutex> lock(S->m_mutex);
|
||||||
int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0);
|
int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0);
|
||||||
|
@ -382,9 +487,34 @@ public:
|
||||||
return sz_after - m_sz_before;
|
return sz_after - m_sz_before;
|
||||||
}
|
}
|
||||||
|
|
||||||
int request_interrupt(lua_State * src) {
|
void request_interrupt() {
|
||||||
lua_pushboolean(src, m_thread.request_interrupt());
|
while (!m_thread.request_interrupt()) {
|
||||||
return 1;
|
check_interrupted();
|
||||||
|
std::this_thread::sleep_for(g_small_delay);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void write(lua_State * src, int first, int last) {
|
||||||
|
while (!m_in_channel_addr) {
|
||||||
|
check_interrupted();
|
||||||
|
std::this_thread::sleep_for(g_small_delay);
|
||||||
|
}
|
||||||
|
data_channel & in = m_in_channel_addr.load()->get();
|
||||||
|
in.write(src, first, last);
|
||||||
|
}
|
||||||
|
|
||||||
|
int read(lua_State * src) {
|
||||||
|
if (!m_out_channel_addr) {
|
||||||
|
check_interrupted();
|
||||||
|
std::this_thread::sleep_for(g_small_delay);
|
||||||
|
}
|
||||||
|
data_channel & out = m_out_channel_addr.load()->get();
|
||||||
|
int nargs = lua_gettop(src);
|
||||||
|
return out.read(src, nargs == 1 ? 0 : 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool started() {
|
||||||
|
return m_in_channel_addr && m_out_channel_addr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -432,10 +562,21 @@ static int thread_pred(lua_State * L) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int thread_interrupt(lua_State * L) {
|
static int thread_write(lua_State * L) {
|
||||||
return to_thread(L, 1).request_interrupt(L);
|
to_thread(L, 1).write(L, 2, lua_gettop(L));
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int thread_read(lua_State * L) {
|
||||||
|
return to_thread(L, 1).read(L);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int thread_interrupt(lua_State * L) {
|
||||||
|
to_thread(L, 1).request_interrupt();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
int thread_wait(lua_State * L) {
|
int thread_wait(lua_State * L) {
|
||||||
return to_thread(L, 1).wait(L);
|
return to_thread(L, 1).wait(L);
|
||||||
}
|
}
|
||||||
|
@ -444,6 +585,8 @@ static const struct luaL_Reg thread_m[] = {
|
||||||
{"__gc", thread_gc},
|
{"__gc", thread_gc},
|
||||||
{"wait", safe_function<thread_wait>},
|
{"wait", safe_function<thread_wait>},
|
||||||
{"interrupt", safe_function<thread_interrupt>},
|
{"interrupt", safe_function<thread_interrupt>},
|
||||||
|
{"write", safe_function<thread_write>},
|
||||||
|
{"read", safe_function<thread_read>},
|
||||||
{0, 0}
|
{0, 0}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -471,5 +614,7 @@ static int sleep(lua_State * L) {
|
||||||
static void open_interrupt(lua_State * L) {
|
static void open_interrupt(lua_State * L) {
|
||||||
SET_GLOBAL_FUN(check_interrupted, "check_interrupted");
|
SET_GLOBAL_FUN(check_interrupted, "check_interrupted");
|
||||||
SET_GLOBAL_FUN(sleep, "sleep");
|
SET_GLOBAL_FUN(sleep, "sleep");
|
||||||
|
SET_GLOBAL_FUN(channel_read, "read");
|
||||||
|
SET_GLOBAL_FUN(channel_write, "write");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ class leanlua_state {
|
||||||
struct imp;
|
struct imp;
|
||||||
std::shared_ptr<imp> m_ptr;
|
std::shared_ptr<imp> m_ptr;
|
||||||
friend class leanlua_thread;
|
friend class leanlua_thread;
|
||||||
|
friend class data_channel;
|
||||||
friend int state_dostring(lua_State * L);
|
friend int state_dostring(lua_State * L);
|
||||||
friend int state_set_global(lua_State * L);
|
friend int state_set_global(lua_State * L);
|
||||||
friend int mk_thread(lua_State * L);
|
friend int mk_thread(lua_State * L);
|
||||||
|
|
37
tests/lua/th3.lua
Normal file
37
tests/lua/th3.lua
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
S = State()
|
||||||
|
T = thread(S, [[
|
||||||
|
print("starting thread...")
|
||||||
|
pcall(function()
|
||||||
|
while true do
|
||||||
|
check_interrupted() -- check if thread was interrupted
|
||||||
|
local ok, val = read(10) -- 10 ms timeout
|
||||||
|
if ok then
|
||||||
|
print("thread received:", val)
|
||||||
|
write(val + 10, val - 10) -- send result back to main thread
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end)
|
||||||
|
]])
|
||||||
|
|
||||||
|
for i = 1, 10 do
|
||||||
|
T:write(10 * i)
|
||||||
|
local r1 = T:read()
|
||||||
|
local r2 = T:read()
|
||||||
|
print("main received: ", r1, r2)
|
||||||
|
end
|
||||||
|
T:interrupt()
|
||||||
|
print("done")
|
||||||
|
|
||||||
|
-- Channels are quite flexible, we can send closure over the channel
|
||||||
|
T = thread(S, [[
|
||||||
|
for i = 1, 10 do
|
||||||
|
local f = read()
|
||||||
|
-- send back the result of f(i)
|
||||||
|
write(f(i))
|
||||||
|
end
|
||||||
|
]])
|
||||||
|
|
||||||
|
for i = 1, 10 do
|
||||||
|
T:write(function (x) return x + i end)
|
||||||
|
print(T:read())
|
||||||
|
end
|
Loading…
Reference in a new issue