feat(lua): add support for multiple execution threads in the Lua API
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
95785c7aaa
commit
69b41eae70
3 changed files with 169 additions and 50 deletions
|
@ -6,6 +6,7 @@ Author: Leonardo de Moura
|
||||||
*/
|
*/
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#include <thread>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <lua.hpp>
|
#include <lua.hpp>
|
||||||
#include "util/debug.h"
|
#include "util/debug.h"
|
||||||
|
@ -29,6 +30,7 @@ extern "C" void * lua_realloc(void *, void * q, size_t, size_t new_size) { retur
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
static void open_state(lua_State * L);
|
static void open_state(lua_State * L);
|
||||||
|
static void open_thread(lua_State * L);
|
||||||
|
|
||||||
static void copy_values(lua_State * src, int first, int last, lua_State * tgt) {
|
static void copy_values(lua_State * src, int first, int last, lua_State * tgt) {
|
||||||
for (int i = first; i <= last; i++) {
|
for (int i = first; i <= last; i++) {
|
||||||
|
@ -77,18 +79,19 @@ struct leanlua_state::imp {
|
||||||
if (m_state == nullptr)
|
if (m_state == nullptr)
|
||||||
throw exception("fail to create Lua interpreter");
|
throw exception("fail to create Lua interpreter");
|
||||||
luaL_openlibs(m_state);
|
luaL_openlibs(m_state);
|
||||||
lean::open_name(m_state);
|
open_name(m_state);
|
||||||
lean::open_mpz(m_state);
|
open_mpz(m_state);
|
||||||
lean::open_mpq(m_state);
|
open_mpq(m_state);
|
||||||
lean::open_options(m_state);
|
open_options(m_state);
|
||||||
lean::open_sexpr(m_state);
|
open_sexpr(m_state);
|
||||||
lean::open_format(m_state);
|
open_format(m_state);
|
||||||
lean::open_level(m_state);
|
open_level(m_state);
|
||||||
lean::open_local_context(m_state);
|
open_local_context(m_state);
|
||||||
lean::open_expr(m_state);
|
open_expr(m_state);
|
||||||
lean::open_context(m_state);
|
open_context(m_state);
|
||||||
lean::open_environment(m_state);
|
open_environment(m_state);
|
||||||
lean::open_state(m_state);
|
open_state(m_state);
|
||||||
|
open_thread(m_state);
|
||||||
dostring(g_leanlua_extra);
|
dostring(g_leanlua_extra);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,30 +113,6 @@ struct leanlua_state::imp {
|
||||||
set_environment set(m_state, env);
|
set_environment set(m_state, env);
|
||||||
dostring(str);
|
dostring(str);
|
||||||
}
|
}
|
||||||
|
|
||||||
int dostring(char const * str, lua_State * src, int first, int last) {
|
|
||||||
std::lock_guard<std::mutex> lock(m_mutex);
|
|
||||||
|
|
||||||
int sz_before = lua_gettop(m_state);
|
|
||||||
|
|
||||||
int result = luaL_loadstring(m_state, str);
|
|
||||||
if (result)
|
|
||||||
throw lua_exception(lua_tostring(m_state, -1));
|
|
||||||
|
|
||||||
copy_values(src, first, last, m_state);
|
|
||||||
|
|
||||||
result = lua_pcall(m_state, first > last ? 0 : last - first + 1, LUA_MULTRET, 0);
|
|
||||||
if (result)
|
|
||||||
throw lua_exception(lua_tostring(m_state, -1));
|
|
||||||
|
|
||||||
int sz_after = lua_gettop(m_state);
|
|
||||||
|
|
||||||
if (sz_after > sz_before) {
|
|
||||||
copy_values(m_state, sz_before + 1, sz_after, src);
|
|
||||||
lua_pop(m_state, sz_after - sz_before);
|
|
||||||
}
|
|
||||||
return sz_after - sz_before;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
leanlua_state::leanlua_state():
|
leanlua_state::leanlua_state():
|
||||||
|
@ -155,10 +134,6 @@ void leanlua_state::dostring(char const * str, environment & env) {
|
||||||
m_ptr->dostring(str, env);
|
m_ptr->dostring(str, env);
|
||||||
}
|
}
|
||||||
|
|
||||||
int leanlua_state::dostring(char const * str, lua_State * src, int first, int last) {
|
|
||||||
return m_ptr->dostring(str, src, first, last);
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr char const * state_mt = "state.mt";
|
constexpr char const * state_mt = "state.mt";
|
||||||
|
|
||||||
bool is_state(lua_State * L, int idx) {
|
bool is_state(lua_State * L, int idx) {
|
||||||
|
@ -187,8 +162,37 @@ static int state_gc(lua_State * L) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int state_dostring(lua_State * L) {
|
int state_dostring(lua_State * L) {
|
||||||
return to_state(L, 1).dostring(luaL_checkstring(L, 2), L, 3, lua_gettop(L));
|
auto S = to_state(L, 1).m_ptr;
|
||||||
|
char const * script = luaL_checkstring(L, 2);
|
||||||
|
int first = 3;
|
||||||
|
int last = lua_gettop(L);
|
||||||
|
std::lock_guard<std::mutex> lock(S->m_mutex);
|
||||||
|
|
||||||
|
int sz_before = lua_gettop(S->m_state);
|
||||||
|
|
||||||
|
int result = luaL_loadstring(S->m_state, script);
|
||||||
|
if (result)
|
||||||
|
throw lua_exception(lua_tostring(S->m_state, -1));
|
||||||
|
|
||||||
|
copy_values(L, first, last, S->m_state);
|
||||||
|
|
||||||
|
result = lua_pcall(S->m_state, first > last ? 0 : last - first + 1, LUA_MULTRET, 0);
|
||||||
|
if (result)
|
||||||
|
throw lua_exception(lua_tostring(S->m_state, -1));
|
||||||
|
|
||||||
|
int sz_after = lua_gettop(S->m_state);
|
||||||
|
|
||||||
|
if (sz_after > sz_before) {
|
||||||
|
copy_values(S->m_state, sz_before + 1, sz_after, L);
|
||||||
|
lua_pop(S->m_state, sz_after - sz_before);
|
||||||
|
}
|
||||||
|
return sz_after - sz_before;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int state_pred(lua_State * L) {
|
||||||
|
lua_pushboolean(L, is_state(L, 1));
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const struct luaL_Reg state_m[] = {
|
static const struct luaL_Reg state_m[] = {
|
||||||
|
@ -204,5 +208,112 @@ static void open_state(lua_State * L) {
|
||||||
setfuncs(L, state_m, 0);
|
setfuncs(L, state_m, 0);
|
||||||
|
|
||||||
set_global_function<mk_state>(L, "State");
|
set_global_function<mk_state>(L, "State");
|
||||||
|
set_global_function<state_pred>(L, "is_State");
|
||||||
|
}
|
||||||
|
|
||||||
|
class leanlua_thread {
|
||||||
|
leanlua_state m_state;
|
||||||
|
int m_sz_before;
|
||||||
|
bool m_error;
|
||||||
|
std::string m_error_msg;
|
||||||
|
std::thread m_thread;
|
||||||
|
public:
|
||||||
|
leanlua_thread(leanlua_state const & st, int sz_before, int num_args):
|
||||||
|
m_state(st),
|
||||||
|
m_sz_before(sz_before),
|
||||||
|
m_error(false),
|
||||||
|
m_thread([=]() {
|
||||||
|
auto S = m_state.m_ptr;
|
||||||
|
std::lock_guard<std::mutex> lock(S->m_mutex);
|
||||||
|
int result = lua_pcall(S->m_state, num_args, LUA_MULTRET, 0);
|
||||||
|
if (result) {
|
||||||
|
m_error = true;
|
||||||
|
m_error_msg = lua_tostring(S->m_state, -1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
}
|
||||||
|
|
||||||
|
~leanlua_thread() {
|
||||||
|
if (m_thread.joinable())
|
||||||
|
m_thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
int wait(lua_State * src) {
|
||||||
|
m_thread.join();
|
||||||
|
if (m_error)
|
||||||
|
throw lua_exception(m_error_msg.c_str());
|
||||||
|
auto S = m_state.m_ptr;
|
||||||
|
int sz_after = lua_gettop(S->m_state);
|
||||||
|
|
||||||
|
if (sz_after > m_sz_before) {
|
||||||
|
copy_values(S->m_state, m_sz_before + 1, sz_after, src);
|
||||||
|
lua_pop(S->m_state, sz_after - m_sz_before);
|
||||||
|
}
|
||||||
|
return sz_after - m_sz_before;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
constexpr char const * thread_mt = "thread.mt";
|
||||||
|
|
||||||
|
bool is_thread(lua_State * L, int idx) {
|
||||||
|
return testudata(L, idx, thread_mt);
|
||||||
|
}
|
||||||
|
|
||||||
|
leanlua_thread & to_thread(lua_State * L, int idx) {
|
||||||
|
return *static_cast<leanlua_thread*>(luaL_checkudata(L, idx, thread_mt));
|
||||||
|
}
|
||||||
|
|
||||||
|
int mk_thread(lua_State * L) {
|
||||||
|
leanlua_state & st = to_state(L, 1);
|
||||||
|
char const * script = luaL_checkstring(L, 2);
|
||||||
|
int first = 3;
|
||||||
|
int last = lua_gettop(L);
|
||||||
|
int nargs = first > last ? 0 : last - first + 1;
|
||||||
|
int sz_before;
|
||||||
|
auto S = st.m_ptr;
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(S->m_mutex);
|
||||||
|
sz_before = lua_gettop(S->m_state);
|
||||||
|
int result = luaL_loadstring(S->m_state, script);
|
||||||
|
if (result)
|
||||||
|
throw lua_exception(lua_tostring(S->m_state, -1));
|
||||||
|
copy_values(L, first, last, S->m_state);
|
||||||
|
}
|
||||||
|
void * mem = lua_newuserdata(L, sizeof(leanlua_thread));
|
||||||
|
new (mem) leanlua_thread(st, sz_before, nargs);
|
||||||
|
luaL_getmetatable(L, thread_mt);
|
||||||
|
lua_setmetatable(L, -2);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int thread_gc(lua_State * L) {
|
||||||
|
to_thread(L, 1).~leanlua_thread();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int thread_pred(lua_State * L) {
|
||||||
|
lua_pushboolean(L, is_thread(L, 1));
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int thread_wait(lua_State * L) {
|
||||||
|
return to_thread(L, 1).wait(L);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const struct luaL_Reg thread_m[] = {
|
||||||
|
{"__gc", thread_gc},
|
||||||
|
{"wait", safe_function<thread_wait>},
|
||||||
|
{0, 0}
|
||||||
|
};
|
||||||
|
|
||||||
|
static void open_thread(lua_State * L) {
|
||||||
|
luaL_newmetatable(L, thread_mt);
|
||||||
|
lua_pushvalue(L, -1);
|
||||||
|
lua_setfield(L, -2, "__index");
|
||||||
|
setfuncs(L, thread_m, 0);
|
||||||
|
|
||||||
|
set_global_function<mk_thread>(L, "thread");
|
||||||
|
set_global_function<thread_pred>(L, "is_thread");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,10 @@ class environment;
|
||||||
class leanlua_state {
|
class leanlua_state {
|
||||||
struct imp;
|
struct imp;
|
||||||
std::shared_ptr<imp> m_ptr;
|
std::shared_ptr<imp> m_ptr;
|
||||||
|
friend class leanlua_thread;
|
||||||
|
friend int state_dostring(lua_State *);
|
||||||
|
friend int mk_thread(lua_State *);
|
||||||
|
friend int thread_wait(lua_State *);
|
||||||
public:
|
public:
|
||||||
leanlua_state();
|
leanlua_state();
|
||||||
~leanlua_state();
|
~leanlua_state();
|
||||||
|
@ -38,13 +42,5 @@ public:
|
||||||
The script \c str should not store a reference to the environment \c env.
|
The script \c str should not store a reference to the environment \c env.
|
||||||
*/
|
*/
|
||||||
void dostring(char const * str, environment & env);
|
void dostring(char const * str, environment & env);
|
||||||
|
|
||||||
/**
|
|
||||||
\brief Execute the given script, but copy the values at positions <tt>[first, last]</tt> from the stack of \c src.
|
|
||||||
The values are passed as arguments to the script \c str.
|
|
||||||
The values returned by the script \c str are copied back to the stack of \c src.
|
|
||||||
The result is the number of values returned by the script \c str.
|
|
||||||
*/
|
|
||||||
int dostring(char const * str, lua_State * src, int first, int last);
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
12
tests/lua/threads/th1.lua
Normal file
12
tests/lua/threads/th1.lua
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
f = Const("f")
|
||||||
|
a = Const("a")
|
||||||
|
|
||||||
|
S = State()
|
||||||
|
T = thread(S, [[
|
||||||
|
t = ...
|
||||||
|
g = Const("g")
|
||||||
|
return g(t)
|
||||||
|
]], f(a))
|
||||||
|
|
||||||
|
r = T:wait()
|
||||||
|
print(r)
|
Loading…
Reference in a new issue