diff --git a/src/bindings/lua/leanlua_state.cpp b/src/bindings/lua/leanlua_state.cpp index dafde5199..12936e4d2 100644 --- a/src/bindings/lua/leanlua_state.cpp +++ b/src/bindings/lua/leanlua_state.cpp @@ -340,18 +340,123 @@ static void open_state(lua_State * L) { SET_GLOBAL_FUN(state_pred, "is_State"); } +// TODO(Leo): allow the user to change it? +#define SMALL_DELAY 10 // in ms +std::chrono::milliseconds g_small_delay(SMALL_DELAY); + +/** + \brief Channel for communicating with thread objects in the Lua API +*/ +class data_channel { + // We use a lua_State to implement the channel. This is quite hackish, + // but it is a convenient storage for Lua objects sent from one state to + // another. + leanlua_state m_channel; + int m_ini; + std::mutex m_mutex; + std::condition_variable m_cv; +public: + data_channel() { + lua_State * channel = m_channel.m_ptr->m_state; + m_ini = lua_gettop(channel); + } + + /** + \brief Copy elements from positions [first, last] from src stack + to the channel. + */ + void write(lua_State * src, int first, int last) { + // write the object on the top of the stack of src to the table + // on m_channel. + if (last < first) + return; + std::lock_guard lock(m_mutex); + lua_State * channel = m_channel.m_ptr->m_state; + bool was_empty = lua_gettop(channel) == m_ini; + copy_values(src, first, last, channel); + if (was_empty) + m_cv.notify_one(); + } + + /** + \brief Retrieve one element from the channel. It will block + the execution of \c tgt if the channel is empty. + */ + int read(lua_State * tgt, int i) { + std::unique_lock lock(m_mutex); + lua_State * channel = m_channel.m_ptr->m_state; + if (i > 0) { + // i is the position of the timeout argument + std::chrono::milliseconds dura(luaL_checkinteger(tgt, i)); + if (lua_gettop(channel) == m_ini) + m_cv.wait_for(lock, dura); + if (lua_gettop(channel) == m_ini) { + // timeout... + lua_pushboolean(tgt, false); + lua_pushnil(tgt); + return 2; + } else { + lua_pushboolean(tgt, true); + copy_values(channel, m_ini + 1, m_ini + 1, tgt); + lua_remove(channel, m_ini + 1); + return 2; + } + } else { + while (lua_gettop(channel) == m_ini) { + check_interrupted(); + m_cv.wait_for(lock, g_small_delay); + } + copy_values(channel, m_ini + 1, m_ini + 1, tgt); + lua_remove(channel, m_ini + 1); + return 1; + } + } +}; + +/** + \brief We want the channels to be lazily created. +*/ +class data_channel_ref { + std::unique_ptr m_channel; + std::mutex m_mutex; +public: + data_channel & get() { + std::lock_guard lock(m_mutex); + if (!m_channel) + m_channel.reset(new data_channel()); + lean_assert(m_channel); + return *m_channel; + } +}; + +data_channel_ref g_in_channel; +data_channel_ref g_out_channel; + +int channel_read(lua_State * L) { + return g_in_channel.get().read(L, lua_gettop(L)); +} + +int channel_write(lua_State * L) { + g_out_channel.get().write(L, 1, lua_gettop(L)); + return 0; +} + class leanlua_thread { - leanlua_state m_state; - int m_sz_before; - bool m_error; - std::string m_error_msg; - interruptible_thread m_thread; + leanlua_state m_state; + int m_sz_before; + bool m_error; + std::string m_error_msg; + interruptible_thread m_thread; + std::atomic m_in_channel_addr; + std::atomic m_out_channel_addr; public: leanlua_thread(leanlua_state const & st, int sz_before, int num_args): m_state(st), m_sz_before(sz_before), m_error(false), m_thread([=]() { + m_in_channel_addr.store(&g_in_channel); + m_out_channel_addr.store(&g_out_channel); auto S = m_state.m_ptr; std::lock_guard lock(S->m_mutex); int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0); @@ -382,9 +487,34 @@ public: return sz_after - m_sz_before; } - int request_interrupt(lua_State * src) { - lua_pushboolean(src, m_thread.request_interrupt()); - return 1; + void request_interrupt() { + while (!m_thread.request_interrupt()) { + check_interrupted(); + std::this_thread::sleep_for(g_small_delay); + } + } + + void write(lua_State * src, int first, int last) { + while (!m_in_channel_addr) { + check_interrupted(); + std::this_thread::sleep_for(g_small_delay); + } + data_channel & in = m_in_channel_addr.load()->get(); + in.write(src, first, last); + } + + int read(lua_State * src) { + if (!m_out_channel_addr) { + check_interrupted(); + std::this_thread::sleep_for(g_small_delay); + } + data_channel & out = m_out_channel_addr.load()->get(); + int nargs = lua_gettop(src); + return out.read(src, nargs == 1 ? 0 : 2); + } + + bool started() { + return m_in_channel_addr && m_out_channel_addr; } }; @@ -432,10 +562,21 @@ static int thread_pred(lua_State * L) { return 1; } -static int thread_interrupt(lua_State * L) { - return to_thread(L, 1).request_interrupt(L); +static int thread_write(lua_State * L) { + to_thread(L, 1).write(L, 2, lua_gettop(L)); + return 0; } +static int thread_read(lua_State * L) { + return to_thread(L, 1).read(L); +} + +static int thread_interrupt(lua_State * L) { + to_thread(L, 1).request_interrupt(); + return 0; +} + + int thread_wait(lua_State * L) { return to_thread(L, 1).wait(L); } @@ -444,6 +585,8 @@ static const struct luaL_Reg thread_m[] = { {"__gc", thread_gc}, {"wait", safe_function}, {"interrupt", safe_function}, + {"write", safe_function}, + {"read", safe_function}, {0, 0} }; @@ -471,5 +614,7 @@ static int sleep(lua_State * L) { static void open_interrupt(lua_State * L) { SET_GLOBAL_FUN(check_interrupted, "check_interrupted"); SET_GLOBAL_FUN(sleep, "sleep"); + SET_GLOBAL_FUN(channel_read, "read"); + SET_GLOBAL_FUN(channel_write, "write"); } } diff --git a/src/bindings/lua/leanlua_state.h b/src/bindings/lua/leanlua_state.h index a9640112c..baba72e59 100644 --- a/src/bindings/lua/leanlua_state.h +++ b/src/bindings/lua/leanlua_state.h @@ -19,6 +19,7 @@ class leanlua_state { struct imp; std::shared_ptr m_ptr; friend class leanlua_thread; + friend class data_channel; friend int state_dostring(lua_State * L); friend int state_set_global(lua_State * L); friend int mk_thread(lua_State * L); diff --git a/tests/lua/th3.lua b/tests/lua/th3.lua new file mode 100644 index 000000000..bd8887c41 --- /dev/null +++ b/tests/lua/th3.lua @@ -0,0 +1,37 @@ +S = State() +T = thread(S, [[ + print("starting thread...") + pcall(function() + while true do + check_interrupted() -- check if thread was interrupted + local ok, val = read(10) -- 10 ms timeout + if ok then + print("thread received:", val) + write(val + 10, val - 10) -- send result back to main thread + end + end + end) +]]) + +for i = 1, 10 do + T:write(10 * i) + local r1 = T:read() + local r2 = T:read() + print("main received: ", r1, r2) +end +T:interrupt() +print("done") + +-- Channels are quite flexible, we can send closure over the channel +T = thread(S, [[ +for i = 1, 10 do + local f = read() + -- send back the result of f(i) + write(f(i)) +end +]]) + +for i = 1, 10 do + T:write(function (x) return x + i end) + print(T:read()) +end