feat(lua): add support for multiple execution threads in the Lua API

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-11-11 16:25:17 -08:00
parent 95785c7aaa
commit 69b41eae70
3 changed files with 169 additions and 50 deletions

View file

@ -6,6 +6,7 @@ Author: Leonardo de Moura
*/ */
#include <iostream> #include <iostream>
#include <mutex> #include <mutex>
#include <thread>
#include <string> #include <string>
#include <lua.hpp> #include <lua.hpp>
#include "util/debug.h" #include "util/debug.h"
@ -29,6 +30,7 @@ extern "C" void * lua_realloc(void *, void * q, size_t, size_t new_size) { retur
namespace lean { namespace lean {
static void open_state(lua_State * L); static void open_state(lua_State * L);
static void open_thread(lua_State * L);
static void copy_values(lua_State * src, int first, int last, lua_State * tgt) { static void copy_values(lua_State * src, int first, int last, lua_State * tgt) {
for (int i = first; i <= last; i++) { for (int i = first; i <= last; i++) {
@ -77,18 +79,19 @@ struct leanlua_state::imp {
if (m_state == nullptr) if (m_state == nullptr)
throw exception("fail to create Lua interpreter"); throw exception("fail to create Lua interpreter");
luaL_openlibs(m_state); luaL_openlibs(m_state);
lean::open_name(m_state); open_name(m_state);
lean::open_mpz(m_state); open_mpz(m_state);
lean::open_mpq(m_state); open_mpq(m_state);
lean::open_options(m_state); open_options(m_state);
lean::open_sexpr(m_state); open_sexpr(m_state);
lean::open_format(m_state); open_format(m_state);
lean::open_level(m_state); open_level(m_state);
lean::open_local_context(m_state); open_local_context(m_state);
lean::open_expr(m_state); open_expr(m_state);
lean::open_context(m_state); open_context(m_state);
lean::open_environment(m_state); open_environment(m_state);
lean::open_state(m_state); open_state(m_state);
open_thread(m_state);
dostring(g_leanlua_extra); dostring(g_leanlua_extra);
} }
@ -110,30 +113,6 @@ struct leanlua_state::imp {
set_environment set(m_state, env); set_environment set(m_state, env);
dostring(str); dostring(str);
} }
int dostring(char const * str, lua_State * src, int first, int last) {
std::lock_guard<std::mutex> lock(m_mutex);
int sz_before = lua_gettop(m_state);
int result = luaL_loadstring(m_state, str);
if (result)
throw lua_exception(lua_tostring(m_state, -1));
copy_values(src, first, last, m_state);
result = lua_pcall(m_state, first > last ? 0 : last - first + 1, LUA_MULTRET, 0);
if (result)
throw lua_exception(lua_tostring(m_state, -1));
int sz_after = lua_gettop(m_state);
if (sz_after > sz_before) {
copy_values(m_state, sz_before + 1, sz_after, src);
lua_pop(m_state, sz_after - sz_before);
}
return sz_after - sz_before;
}
}; };
leanlua_state::leanlua_state(): leanlua_state::leanlua_state():
@ -155,10 +134,6 @@ void leanlua_state::dostring(char const * str, environment & env) {
m_ptr->dostring(str, env); m_ptr->dostring(str, env);
} }
int leanlua_state::dostring(char const * str, lua_State * src, int first, int last) {
return m_ptr->dostring(str, src, first, last);
}
constexpr char const * state_mt = "state.mt"; constexpr char const * state_mt = "state.mt";
bool is_state(lua_State * L, int idx) { bool is_state(lua_State * L, int idx) {
@ -187,8 +162,37 @@ static int state_gc(lua_State * L) {
return 0; return 0;
} }
static int state_dostring(lua_State * L) { int state_dostring(lua_State * L) {
return to_state(L, 1).dostring(luaL_checkstring(L, 2), L, 3, lua_gettop(L)); auto S = to_state(L, 1).m_ptr;
char const * script = luaL_checkstring(L, 2);
int first = 3;
int last = lua_gettop(L);
std::lock_guard<std::mutex> lock(S->m_mutex);
int sz_before = lua_gettop(S->m_state);
int result = luaL_loadstring(S->m_state, script);
if (result)
throw lua_exception(lua_tostring(S->m_state, -1));
copy_values(L, first, last, S->m_state);
result = lua_pcall(S->m_state, first > last ? 0 : last - first + 1, LUA_MULTRET, 0);
if (result)
throw lua_exception(lua_tostring(S->m_state, -1));
int sz_after = lua_gettop(S->m_state);
if (sz_after > sz_before) {
copy_values(S->m_state, sz_before + 1, sz_after, L);
lua_pop(S->m_state, sz_after - sz_before);
}
return sz_after - sz_before;
}
static int state_pred(lua_State * L) {
lua_pushboolean(L, is_state(L, 1));
return 1;
} }
static const struct luaL_Reg state_m[] = { static const struct luaL_Reg state_m[] = {
@ -204,5 +208,112 @@ static void open_state(lua_State * L) {
setfuncs(L, state_m, 0); setfuncs(L, state_m, 0);
set_global_function<mk_state>(L, "State"); set_global_function<mk_state>(L, "State");
set_global_function<state_pred>(L, "is_State");
}
class leanlua_thread {
leanlua_state m_state;
int m_sz_before;
bool m_error;
std::string m_error_msg;
std::thread m_thread;
public:
leanlua_thread(leanlua_state const & st, int sz_before, int num_args):
m_state(st),
m_sz_before(sz_before),
m_error(false),
m_thread([=]() {
auto S = m_state.m_ptr;
std::lock_guard<std::mutex> lock(S->m_mutex);
int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0);
if (result) {
m_error = true;
m_error_msg = lua_tostring(S->m_state, -1);
return;
}
}) {
}
~leanlua_thread() {
if (m_thread.joinable())
m_thread.join();
}
int wait(lua_State * src) {
m_thread.join();
if (m_error)
throw lua_exception(m_error_msg.c_str());
auto S = m_state.m_ptr;
int sz_after = lua_gettop(S->m_state);
if (sz_after > m_sz_before) {
copy_values(S->m_state, m_sz_before + 1, sz_after, src);
lua_pop(S->m_state, sz_after - m_sz_before);
}
return sz_after - m_sz_before;
}
};
constexpr char const * thread_mt = "thread.mt";
bool is_thread(lua_State * L, int idx) {
return testudata(L, idx, thread_mt);
}
leanlua_thread & to_thread(lua_State * L, int idx) {
return *static_cast<leanlua_thread*>(luaL_checkudata(L, idx, thread_mt));
}
int mk_thread(lua_State * L) {
leanlua_state & st = to_state(L, 1);
char const * script = luaL_checkstring(L, 2);
int first = 3;
int last = lua_gettop(L);
int nargs = first > last ? 0 : last - first + 1;
int sz_before;
auto S = st.m_ptr;
{
std::lock_guard<std::mutex> lock(S->m_mutex);
sz_before = lua_gettop(S->m_state);
int result = luaL_loadstring(S->m_state, script);
if (result)
throw lua_exception(lua_tostring(S->m_state, -1));
copy_values(L, first, last, S->m_state);
}
void * mem = lua_newuserdata(L, sizeof(leanlua_thread));
new (mem) leanlua_thread(st, sz_before, nargs);
luaL_getmetatable(L, thread_mt);
lua_setmetatable(L, -2);
return 1;
}
static int thread_gc(lua_State * L) {
to_thread(L, 1).~leanlua_thread();
return 0;
}
static int thread_pred(lua_State * L) {
lua_pushboolean(L, is_thread(L, 1));
return 1;
}
int thread_wait(lua_State * L) {
return to_thread(L, 1).wait(L);
}
static const struct luaL_Reg thread_m[] = {
{"__gc", thread_gc},
{"wait", safe_function<thread_wait>},
{0, 0}
};
static void open_thread(lua_State * L) {
luaL_newmetatable(L, thread_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, thread_m, 0);
set_global_function<mk_thread>(L, "thread");
set_global_function<thread_pred>(L, "is_thread");
} }
} }

View file

@ -17,6 +17,10 @@ class environment;
class leanlua_state { class leanlua_state {
struct imp; struct imp;
std::shared_ptr<imp> m_ptr; std::shared_ptr<imp> m_ptr;
friend class leanlua_thread;
friend int state_dostring(lua_State *);
friend int mk_thread(lua_State *);
friend int thread_wait(lua_State *);
public: public:
leanlua_state(); leanlua_state();
~leanlua_state(); ~leanlua_state();
@ -38,13 +42,5 @@ public:
The script \c str should not store a reference to the environment \c env. The script \c str should not store a reference to the environment \c env.
*/ */
void dostring(char const * str, environment & env); void dostring(char const * str, environment & env);
/**
\brief Execute the given script, but copy the values at positions <tt>[first, last]</tt> from the stack of \c src.
The values are passed as arguments to the script \c str.
The values returned by the script \c str are copied back to the stack of \c src.
The result is the number of values returned by the script \c str.
*/
int dostring(char const * str, lua_State * src, int first, int last);
}; };
} }

12
tests/lua/threads/th1.lua Normal file
View file

@ -0,0 +1,12 @@
f = Const("f")
a = Const("a")
S = State()
T = thread(S, [[
t = ...
g = Const("g")
return g(t)
]], f(a))
r = T:wait()
print(r)