refactor(util/script_state): remove support for threads and communication channels from the Lua API, the goal is to keep is simple, and use one Lua state object per thread
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
f17e8a853a
commit
1378fa5cbb
18 changed files with 33 additions and 747 deletions
|
@ -12,9 +12,9 @@ Author: Leonardo de Moura
|
|||
#include "frontends/lean/interactive.h"
|
||||
|
||||
namespace lean {
|
||||
interactive::interactive(environment const & env, io_state const & ios, script_state const & ss, unsigned num_threads,
|
||||
interactive::interactive(environment const & env, io_state const & ios, unsigned num_threads,
|
||||
char const * ack_cmd, char const * snapshot_cmd, char const * restore_cmd, char const * restart_cmd):
|
||||
m_env(env), m_ios(ios), m_ss(ss), m_num_threads(num_threads), m_line(1),
|
||||
m_env(env), m_ios(ios), m_num_threads(num_threads), m_line(1),
|
||||
m_ack_cmd(ack_cmd), m_snapshot_cmd(snapshot_cmd), m_restore_cmd(restore_cmd), m_restart_cmd(restart_cmd) {
|
||||
save_snapshot();
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ interactive::interactive(environment const & env, io_state const & ios, script_s
|
|||
void interactive::parse_block(std::string const & str, char const * strm_name) {
|
||||
if (str.size() > 0) {
|
||||
std::istringstream block(str);
|
||||
parser p(m_env, m_ios, block, strm_name, &m_ss, false, m_num_threads, m_lds, m_eds, m_line);
|
||||
parser p(m_env, m_ios, block, strm_name, false, m_num_threads, m_lds, m_eds, m_line);
|
||||
p();
|
||||
m_env = p.env();
|
||||
m_ios = p.ios();
|
||||
|
|
|
@ -34,7 +34,6 @@ class interactive {
|
|||
std::vector<std::string> m_lines;
|
||||
environment m_env;
|
||||
io_state m_ios;
|
||||
script_state m_ss;
|
||||
unsigned m_num_threads;
|
||||
local_level_decls m_lds;
|
||||
local_expr_decls m_eds;
|
||||
|
@ -47,8 +46,7 @@ class interactive {
|
|||
void save_snapshot();
|
||||
void restore(unsigned new_line, std::string & block);
|
||||
public:
|
||||
interactive(environment const & env, io_state const & ios, script_state const & ss,
|
||||
unsigned num_threads = 1,
|
||||
interactive(environment const & env, io_state const & ios, unsigned num_threads = 1,
|
||||
char const * ack_cmd = "#ACK", char const * snapshot_cmd = "#SNAPSHOT",
|
||||
char const * res_cmd = "#RESTORE", char const * restart_cmd = "#RESTART");
|
||||
environment const & env() const { return m_env; }
|
||||
|
|
|
@ -102,10 +102,10 @@ struct scoped_set_parser {
|
|||
|
||||
parser::parser(environment const & env, io_state const & ios,
|
||||
std::istream & strm, char const * strm_name,
|
||||
script_state * ss, bool use_exceptions, unsigned num_threads,
|
||||
bool use_exceptions, unsigned num_threads,
|
||||
local_level_decls const & lds, local_expr_decls const & eds,
|
||||
unsigned line):
|
||||
m_env(env), m_ios(ios), m_ss(ss),
|
||||
m_env(env), m_ios(ios),
|
||||
m_verbose(true), m_use_exceptions(use_exceptions),
|
||||
m_scanner(strm, strm_name), m_local_level_decls(lds), m_local_decls(eds),
|
||||
m_pos_table(std::make_shared<pos_info_table>()) {
|
||||
|
@ -731,8 +731,6 @@ expr parser::parse_notation(parse_table t, expr * left) {
|
|||
break;
|
||||
}
|
||||
case notation::action_kind::LuaExt:
|
||||
if (!m_ss)
|
||||
throw parser_error("failed to use notation implemented in Lua, parser does not contain a Lua state", p);
|
||||
using_script([&](lua_State * L) {
|
||||
scoped_set_parser scope(L, *this);
|
||||
lua_getglobal(L, a.get_lua_fn().c_str());
|
||||
|
@ -955,8 +953,6 @@ void parser::parse_command() {
|
|||
|
||||
void parser::parse_script(bool as_expr) {
|
||||
m_last_script_pos = pos();
|
||||
if (!m_ss)
|
||||
throw exception("failed to execute Lua script, parser does not have a Lua interpreter");
|
||||
std::string script_code = m_scanner.get_str_val();
|
||||
if (as_expr)
|
||||
script_code = "return " + script_code;
|
||||
|
@ -983,8 +979,6 @@ void parser::parse_imports() {
|
|||
while (curr_is_identifier()) {
|
||||
name f = get_name_val();
|
||||
if (auto it = find_file(f, ".lua")) {
|
||||
if (!m_ss)
|
||||
throw parser_error("invalid import, Lua interpreter is not available", pos());
|
||||
lua_files.push_back(*it);
|
||||
} else if (auto it = find_file(f, ".olean")) {
|
||||
olean_files.push_back(*it);
|
||||
|
@ -995,11 +989,10 @@ void parser::parse_imports() {
|
|||
}
|
||||
}
|
||||
m_env = import_modules(m_env, olean_files.size(), olean_files.data(), m_num_threads, true, m_ios);
|
||||
using_script([&](lua_State *) { // NOLINT
|
||||
m_ss->exec_unprotected([&]() {
|
||||
for (auto const & f : lua_files) {
|
||||
m_ss->import_explicit(f.c_str());
|
||||
}});
|
||||
using_script([&](lua_State * L) {
|
||||
for (auto const & f : lua_files) {
|
||||
to_script_state(L).import_explicit(f.c_str());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -1037,29 +1030,27 @@ bool parser::parse_commands() {
|
|||
return !m_found_errors;
|
||||
}
|
||||
|
||||
bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name, script_state * S, bool use_exceptions,
|
||||
bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name, bool use_exceptions,
|
||||
unsigned num_threads) {
|
||||
parser p(env, ios, in, strm_name, S, use_exceptions, num_threads);
|
||||
parser p(env, ios, in, strm_name, use_exceptions, num_threads);
|
||||
bool r = p();
|
||||
ios = p.ios();
|
||||
env = p.env();
|
||||
return r;
|
||||
}
|
||||
|
||||
bool parse_commands(environment & env, io_state & ios, char const * fname, script_state * S, bool use_exceptions, unsigned num_threads) {
|
||||
bool parse_commands(environment & env, io_state & ios, char const * fname, bool use_exceptions, unsigned num_threads) {
|
||||
std::ifstream in(fname);
|
||||
if (in.bad() || in.fail())
|
||||
throw exception(sstream() << "failed to open file '" << fname << "'");
|
||||
return parse_commands(env, ios, in, fname, S, use_exceptions, num_threads);
|
||||
return parse_commands(env, ios, in, fname, use_exceptions, num_threads);
|
||||
}
|
||||
|
||||
static int parse_expr(lua_State * L) {
|
||||
script_state S = to_script_state(L);
|
||||
int nargs = lua_gettop(L);
|
||||
expr r;
|
||||
S.exec_unprotected([&]() {
|
||||
r = get_global_parser(L).parse_expr(nargs == 0 ? 0 : lua_tointeger(L, 1));
|
||||
});
|
||||
r = get_global_parser(L).parse_expr(nargs == 0 ? 0 : lua_tointeger(L, 1));
|
||||
return push_expr(L, r);
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ Author: Leonardo de Moura
|
|||
#include "util/script_state.h"
|
||||
#include "util/name_map.h"
|
||||
#include "util/exception.h"
|
||||
#include "util/thread_script_state.h"
|
||||
#include "kernel/environment.h"
|
||||
#include "kernel/expr_maps.h"
|
||||
#include "library/io_state.h"
|
||||
|
@ -49,7 +50,6 @@ typedef std::unordered_map<unsigned, tactic> hint_table;
|
|||
class parser {
|
||||
environment m_env;
|
||||
io_state m_ios;
|
||||
script_state * m_ss;
|
||||
bool m_verbose;
|
||||
bool m_use_exceptions;
|
||||
bool m_show_errors;
|
||||
|
@ -86,11 +86,10 @@ class parser {
|
|||
void protected_call(std::function<void()> && f, std::function<void()> && sync);
|
||||
template<typename F>
|
||||
typename std::result_of<F(lua_State * L)>::type using_script(F && f) {
|
||||
return m_ss->apply([&](lua_State * L) {
|
||||
set_io_state set1(L, m_ios);
|
||||
set_environment set2(L, m_env);
|
||||
return f(L);
|
||||
});
|
||||
script_state S = get_thread_script_state();
|
||||
set_io_state set1(S.get_state(), m_ios);
|
||||
set_environment set2(S.get_state(), m_env);
|
||||
return f(S.get_state());
|
||||
}
|
||||
|
||||
tag get_tag(expr e);
|
||||
|
@ -134,15 +133,13 @@ class parser {
|
|||
public:
|
||||
parser(environment const & env, io_state const & ios,
|
||||
std::istream & strm, char const * str_name,
|
||||
script_state * ss = nullptr, bool use_exceptions = false,
|
||||
unsigned num_threads = 1,
|
||||
bool use_exceptions = false, unsigned num_threads = 1,
|
||||
local_level_decls const & lds = local_level_decls(),
|
||||
local_expr_decls const & eds = local_expr_decls(),
|
||||
unsigned line = 1);
|
||||
|
||||
environment const & env() const { return m_env; }
|
||||
io_state const & ios() const { return m_ios; }
|
||||
script_state * ss() const { return m_ss; }
|
||||
local_level_decls const & get_local_level_decls() const { return m_local_level_decls; }
|
||||
local_expr_decls const & get_local_expr_decls() const { return m_local_decls; }
|
||||
|
||||
|
@ -266,9 +263,7 @@ public:
|
|||
bool operator()() { return parse_commands(); }
|
||||
};
|
||||
|
||||
bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name,
|
||||
script_state * S, bool use_exceptions, unsigned num_threads);
|
||||
bool parse_commands(environment & env, io_state & ios, char const * fname, script_state * S,
|
||||
bool use_exceptions, unsigned num_threads);
|
||||
bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name, bool use_exceptions, unsigned num_threads);
|
||||
bool parse_commands(environment & env, io_state & ios, char const * fname, bool use_exceptions, unsigned num_threads);
|
||||
void open_parser(lua_State * L);
|
||||
}
|
||||
|
|
|
@ -104,19 +104,6 @@ FOREACH(T ${LEANLUADOCS})
|
|||
COMMAND "./test_single.sh" "${CMAKE_CURRENT_BINARY_DIR}/lean -t 100000" ${T})
|
||||
ENDFOREACH(T)
|
||||
|
||||
# LEAN LUA THREAD TESTS
|
||||
if((${CYGWIN} EQUAL "1") OR (${CMAKE_SYSTEM_NAME} MATCHES "Linux"))
|
||||
if ((NOT (${CMAKE_CXX_COMPILER} MATCHES "clang")) AND (${MULTI_THREAD} MATCHES "ON"))
|
||||
file(GLOB LEANLUATHREADTESTS "${LEAN_SOURCE_DIR}/../tests/lua/threads/*.lua")
|
||||
FOREACH(T ${LEANLUATHREADTESTS})
|
||||
GET_FILENAME_COMPONENT(T_NAME ${T} NAME)
|
||||
add_test(NAME "leanluathreadtest_${T_NAME}"
|
||||
WORKING_DIRECTORY "${LEAN_SOURCE_DIR}/../tests/lua/threads"
|
||||
COMMAND "../test_single.sh" "${CMAKE_CURRENT_BINARY_DIR}/lean -t 100000" ${T})
|
||||
ENDFOREACH(T)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# # Create the script lean.sh
|
||||
# # This is used to create a soft dependency on the Lean executable
|
||||
# # Some rules can only be applied if the lean executable exists,
|
||||
|
|
|
@ -15,6 +15,7 @@ Author: Leonardo de Moura
|
|||
#include "util/interrupt.h"
|
||||
#include "util/script_state.h"
|
||||
#include "util/thread.h"
|
||||
#include "util/thread_script_state.h"
|
||||
#include "util/lean_path.h"
|
||||
#include "kernel/environment.h"
|
||||
#include "kernel/kernel_exception.h"
|
||||
|
@ -188,11 +189,9 @@ int main(int argc, char ** argv) {
|
|||
if (quiet)
|
||||
ios.set_option("verbose", false);
|
||||
|
||||
script_state S;
|
||||
S.apply([&](lua_State * L) {
|
||||
set_global_environment(L, env);
|
||||
set_global_io_state(L, ios);
|
||||
});
|
||||
script_state S = lean::get_thread_script_state();
|
||||
set_global_environment(S.get_state(), env);
|
||||
set_global_io_state(S.get_state(), ios);
|
||||
|
||||
try {
|
||||
bool ok = true;
|
||||
|
@ -207,7 +206,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
if (k == input_kind::Lean) {
|
||||
if (!parse_commands(env, ios, argv[i], &S, false, num_threads))
|
||||
if (!parse_commands(env, ios, argv[i], false, num_threads))
|
||||
ok = false;
|
||||
} else if (k == input_kind::Lua) {
|
||||
try {
|
||||
|
@ -222,7 +221,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
if (ok && interactive && default_k == input_kind::Lean) {
|
||||
signal(SIGINT, on_ctrl_c);
|
||||
lean::interactive in(env, ios, S, num_threads);
|
||||
lean::interactive in(env, ios, num_threads);
|
||||
in(std::cin, "[stdin]");
|
||||
}
|
||||
if (export_objects) {
|
||||
|
|
|
@ -8,7 +8,6 @@ Author: Leonardo de Moura
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
#include "util/thread.h"
|
||||
#include "util/lua.h"
|
||||
#include "util/debug.h"
|
||||
#include "util/exception.h"
|
||||
|
@ -43,7 +42,6 @@ static char g_weak_ptr_key; // key for Lua registry (used at get_weak_ptr and sa
|
|||
|
||||
struct script_state::imp {
|
||||
lua_State * m_state;
|
||||
mutex m_mutex;
|
||||
std::unordered_set<std::string> m_imported_modules;
|
||||
|
||||
static std::weak_ptr<imp> * get_weak_ptr(lua_State * L) {
|
||||
|
@ -104,12 +102,10 @@ struct script_state::imp {
|
|||
}
|
||||
|
||||
void dofile(char const * fname) {
|
||||
lock_guard<mutex> lock(m_mutex);
|
||||
::lean::dofile(m_state, fname);
|
||||
}
|
||||
|
||||
void dostring(char const * str) {
|
||||
lock_guard<mutex> lock(m_mutex);
|
||||
::lean::dostring(m_state, str);
|
||||
}
|
||||
|
||||
|
@ -166,426 +162,10 @@ bool script_state::import_explicit(char const * str) {
|
|||
return m_ptr->import_explicit(str);
|
||||
}
|
||||
|
||||
mutex & script_state::get_mutex() {
|
||||
return m_ptr->m_mutex;
|
||||
}
|
||||
|
||||
lua_State * script_state::get_state() {
|
||||
return m_ptr->m_state;
|
||||
}
|
||||
|
||||
constexpr char const * state_mt = "luastate.mt";
|
||||
|
||||
bool is_state(lua_State * L, int idx) {
|
||||
return testudata(L, idx, state_mt);
|
||||
}
|
||||
|
||||
script_state & to_state(lua_State * L, int idx) {
|
||||
return *static_cast<script_state*>(luaL_checkudata(L, idx, state_mt));
|
||||
}
|
||||
|
||||
int push_state(lua_State * L, script_state const & s) {
|
||||
void * mem = lua_newuserdata(L, sizeof(script_state));
|
||||
new (mem) script_state(s);
|
||||
luaL_getmetatable(L, state_mt);
|
||||
lua_setmetatable(L, -2);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int mk_state(lua_State * L) {
|
||||
script_state r;
|
||||
return push_state(L, r);
|
||||
}
|
||||
|
||||
static int state_gc(lua_State * L) {
|
||||
to_state(L, 1).~script_state();
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int writer(lua_State *, void const * p, size_t sz, void * buf) {
|
||||
buffer<char> & _buf = *static_cast<buffer<char>*>(buf);
|
||||
char const * in = static_cast<char const *>(p);
|
||||
for (size_t i = 0; i < sz; i++)
|
||||
_buf.push_back(in[i]);
|
||||
return 0;
|
||||
}
|
||||
struct reader_data {
|
||||
buffer<char> & m_buffer;
|
||||
bool m_done;
|
||||
reader_data(buffer<char> & b):m_buffer(b), m_done(false) {}
|
||||
};
|
||||
static char const * reader(lua_State *, void * data, size_t * sz) {
|
||||
reader_data & _data = *static_cast<reader_data*>(data);
|
||||
if (_data.m_done) {
|
||||
*sz = 0;
|
||||
return nullptr;
|
||||
} else {
|
||||
*sz = _data.m_buffer.size();
|
||||
_data.m_done = true;
|
||||
return _data.m_buffer.data();
|
||||
}
|
||||
}
|
||||
|
||||
static void copy_values(lua_State * src, int first, int last, lua_State * tgt) {
|
||||
for (int i = first; i <= last; i++) {
|
||||
switch (lua_type(src, i)) {
|
||||
case LUA_TNUMBER: lua_pushnumber(tgt, lua_tonumber(src, i)); break;
|
||||
case LUA_TSTRING: lua_pushstring(tgt, lua_tostring(src, i)); break;
|
||||
case LUA_TNIL: lua_pushnil(tgt); break;
|
||||
case LUA_TBOOLEAN: lua_pushboolean(tgt, lua_toboolean(src, i)); break;
|
||||
case LUA_TFUNCTION: {
|
||||
lua_pushvalue(src, i); // copy function to the top of the stack
|
||||
buffer<char> buffer;
|
||||
if (lua_dump(src, writer, &buffer) != 0)
|
||||
throw exception("falied to copy function between State objects");
|
||||
lua_pop(src, 1); // remove function from the top of the stack
|
||||
reader_data data(buffer);
|
||||
if (load(tgt, reader, &data, "temporary buffer for moving functions between states") != 0)
|
||||
throw exception("falied to copy function between State objects");
|
||||
// copy upvalues
|
||||
int j = 1;
|
||||
while (true) {
|
||||
char const * name = lua_getupvalue(src, i, j);
|
||||
if (name == nullptr)
|
||||
break;
|
||||
copy_values(src, lua_gettop(src), lua_gettop(src), tgt); // copy upvalue to tgt stack
|
||||
lua_pop(src, 1); // remove upvalue from src stack
|
||||
lua_setupvalue(tgt, -2, j);
|
||||
j++;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case LUA_TUSERDATA:
|
||||
if (lua_migrate_fn f = get_migrate_fn(src, i)) {
|
||||
f(src, i, tgt);
|
||||
} else {
|
||||
throw exception("unsupported value type for inter-State call");
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw exception("unsupported value type for inter-State call");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int state_dostring(lua_State * L) {
|
||||
return to_state(L, 1).apply([&](lua_State * S) {
|
||||
char const * script = luaL_checkstring(L, 2);
|
||||
int first = 3;
|
||||
int last = lua_gettop(L);
|
||||
int sz_before = lua_gettop(S);
|
||||
int status = luaL_loadstring(S, script);
|
||||
if (status)
|
||||
throw script_exception(lua_tostring(S, -1));
|
||||
|
||||
copy_values(L, first, last, S);
|
||||
|
||||
pcall(S, first > last ? 0 : last - first + 1, LUA_MULTRET, 0);
|
||||
|
||||
int sz_after = lua_gettop(S);
|
||||
|
||||
if (sz_after > sz_before) {
|
||||
copy_values(S, sz_before + 1, sz_after, L);
|
||||
lua_pop(S, sz_after - sz_before);
|
||||
}
|
||||
return sz_after - sz_before;
|
||||
});
|
||||
}
|
||||
|
||||
int state_set_global(lua_State * L) {
|
||||
to_state(L, 1).apply([=](lua_State * S) {
|
||||
char const * name = luaL_checkstring(L, 2);
|
||||
copy_values(L, 3, 3, S);
|
||||
lua_setglobal(S, name);
|
||||
});
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int state_pred(lua_State * L) {
|
||||
return push_boolean(L, is_state(L, 1));
|
||||
}
|
||||
|
||||
static const struct luaL_Reg state_m[] = {
|
||||
{"__gc", state_gc},
|
||||
{"dostring", safe_function<state_dostring>},
|
||||
{"eval", safe_function<state_dostring>},
|
||||
{"set", safe_function<state_set_global>},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
static void open_state(lua_State * L) {
|
||||
luaL_newmetatable(L, state_mt);
|
||||
lua_pushvalue(L, -1);
|
||||
lua_setfield(L, -2, "__index");
|
||||
setfuncs(L, state_m, 0);
|
||||
|
||||
SET_GLOBAL_FUN(mk_state, "State");
|
||||
SET_GLOBAL_FUN(state_pred, "is_State");
|
||||
}
|
||||
|
||||
// TODO(Leo): allow the user to change it?
|
||||
#define SMALL_DELAY 10 // in ms
|
||||
chrono::milliseconds g_small_delay(SMALL_DELAY);
|
||||
|
||||
#if defined(LEAN_MULTI_THREAD)
|
||||
/**
|
||||
\brief Channel for communicating with thread objects in the Lua API
|
||||
*/
|
||||
class data_channel {
|
||||
// We use a lua_State to implement the channel. This is quite hackish,
|
||||
// but it is a convenient storage for Lua objects sent from one state to
|
||||
// another.
|
||||
script_state m_channel;
|
||||
int m_ini;
|
||||
mutex m_mutex;
|
||||
condition_variable m_cv;
|
||||
public:
|
||||
data_channel() {
|
||||
lua_State * channel = m_channel.m_ptr->m_state;
|
||||
m_ini = lua_gettop(channel);
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Copy elements from positions [first, last] from src stack
|
||||
to the channel.
|
||||
*/
|
||||
void write(lua_State * src, int first, int last) {
|
||||
// write the object on the top of the stack of src to the table
|
||||
// on m_channel.
|
||||
if (last < first)
|
||||
return;
|
||||
lock_guard<mutex> lock(m_mutex);
|
||||
lua_State * channel = m_channel.m_ptr->m_state;
|
||||
bool was_empty = lua_gettop(channel) == m_ini;
|
||||
copy_values(src, first, last, channel);
|
||||
if (was_empty)
|
||||
m_cv.notify_one();
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Retrieve one element from the channel. It will block
|
||||
the execution of \c tgt if the channel is empty.
|
||||
*/
|
||||
int read(lua_State * tgt, int i) {
|
||||
unique_lock<mutex> lock(m_mutex);
|
||||
lua_State * channel = m_channel.m_ptr->m_state;
|
||||
if (i > 0) {
|
||||
// i is the position of the timeout argument
|
||||
chrono::milliseconds dura(luaL_checkinteger(tgt, i));
|
||||
if (lua_gettop(channel) == m_ini)
|
||||
m_cv.wait_for(lock, dura);
|
||||
if (lua_gettop(channel) == m_ini) {
|
||||
// timeout...
|
||||
lua_pushboolean(tgt, false);
|
||||
lua_pushnil(tgt);
|
||||
return 2;
|
||||
} else {
|
||||
lua_pushboolean(tgt, true);
|
||||
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
|
||||
lua_remove(channel, m_ini + 1);
|
||||
return 2;
|
||||
}
|
||||
} else {
|
||||
while (lua_gettop(channel) == m_ini) {
|
||||
check_interrupted();
|
||||
m_cv.wait_for(lock, g_small_delay);
|
||||
}
|
||||
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
|
||||
lua_remove(channel, m_ini + 1);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
\brief We want the channels to be lazily created.
|
||||
*/
|
||||
class data_channel_ref {
|
||||
std::unique_ptr<data_channel> m_channel;
|
||||
mutex m_mutex;
|
||||
public:
|
||||
data_channel & get() {
|
||||
lock_guard<mutex> lock(m_mutex);
|
||||
if (!m_channel)
|
||||
m_channel.reset(new data_channel());
|
||||
lean_assert(m_channel);
|
||||
return *m_channel;
|
||||
}
|
||||
};
|
||||
|
||||
data_channel_ref g_in_channel;
|
||||
data_channel_ref g_out_channel;
|
||||
|
||||
int channel_read(lua_State * L) {
|
||||
return g_in_channel.get().read(L, lua_gettop(L));
|
||||
}
|
||||
|
||||
int channel_write(lua_State * L) {
|
||||
g_out_channel.get().write(L, 1, lua_gettop(L));
|
||||
return 0;
|
||||
}
|
||||
|
||||
class leanlua_thread {
|
||||
script_state m_state;
|
||||
int m_sz_before;
|
||||
std::unique_ptr<exception> m_exception;
|
||||
atomic<data_channel_ref *> m_in_channel_addr;
|
||||
atomic<data_channel_ref *> m_out_channel_addr;
|
||||
interruptible_thread m_thread;
|
||||
public:
|
||||
leanlua_thread(script_state const & st, int sz_before, int num_args):
|
||||
m_state(st),
|
||||
m_sz_before(sz_before),
|
||||
m_in_channel_addr(0),
|
||||
m_out_channel_addr(0),
|
||||
m_thread([=]() {
|
||||
m_in_channel_addr.store(&g_in_channel);
|
||||
m_out_channel_addr.store(&g_out_channel);
|
||||
m_state.apply([&](lua_State * S) {
|
||||
int result = lua_pcall(S, num_args, LUA_MULTRET, 0);
|
||||
if (result) {
|
||||
if (is_exception(S, -1))
|
||||
m_exception.reset(to_exception(S, -1).clone());
|
||||
else
|
||||
m_exception.reset(new script_exception(lua_tostring(S, -1)));
|
||||
}
|
||||
});
|
||||
}) {
|
||||
}
|
||||
|
||||
~leanlua_thread() {
|
||||
if (m_thread.joinable())
|
||||
m_thread.join();
|
||||
}
|
||||
|
||||
int copy_result(lua_State * src) {
|
||||
if (m_exception)
|
||||
m_exception->rethrow();
|
||||
return m_state.apply([&](lua_State * S) {
|
||||
int sz_after = lua_gettop(S);
|
||||
if (sz_after > m_sz_before) {
|
||||
copy_values(S, m_sz_before + 1, sz_after, src);
|
||||
lua_pop(S, sz_after - m_sz_before);
|
||||
}
|
||||
return sz_after - m_sz_before;
|
||||
});
|
||||
}
|
||||
|
||||
void wait() {
|
||||
m_thread.join();
|
||||
}
|
||||
|
||||
void request_interrupt() {
|
||||
m_thread.request_interrupt();
|
||||
}
|
||||
|
||||
void write(lua_State * src, int first, int last) {
|
||||
while (!m_in_channel_addr) {
|
||||
check_interrupted();
|
||||
this_thread::sleep_for(g_small_delay);
|
||||
}
|
||||
data_channel & in = m_in_channel_addr.load()->get();
|
||||
in.write(src, first, last);
|
||||
}
|
||||
|
||||
int read(lua_State * src) {
|
||||
if (!m_out_channel_addr) {
|
||||
check_interrupted();
|
||||
this_thread::sleep_for(g_small_delay);
|
||||
}
|
||||
data_channel & out = m_out_channel_addr.load()->get();
|
||||
int nargs = lua_gettop(src);
|
||||
return out.read(src, nargs == 1 ? 0 : 2);
|
||||
}
|
||||
|
||||
bool started() {
|
||||
return m_in_channel_addr && m_out_channel_addr;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr char const * thread_mt = "thread.mt";
|
||||
|
||||
bool is_thread(lua_State * L, int idx) {
|
||||
return testudata(L, idx, thread_mt);
|
||||
}
|
||||
|
||||
leanlua_thread & to_thread(lua_State * L, int idx) {
|
||||
return *static_cast<leanlua_thread*>(luaL_checkudata(L, idx, thread_mt));
|
||||
}
|
||||
|
||||
int mk_thread(lua_State * L) {
|
||||
check_threadsafe();
|
||||
script_state & st = to_state(L, 1);
|
||||
char const * script = luaL_checkstring(L, 2);
|
||||
int first = 3;
|
||||
int last = lua_gettop(L);
|
||||
int nargs = first > last ? 0 : last - first + 1;
|
||||
int sz_before;
|
||||
st.apply([&](lua_State * S) {
|
||||
sz_before = lua_gettop(S);
|
||||
int result = luaL_loadstring(S, script);
|
||||
if (result)
|
||||
throw script_exception(lua_tostring(S, -1));
|
||||
copy_values(L, first, last, S);
|
||||
});
|
||||
void * mem = lua_newuserdata(L, sizeof(leanlua_thread));
|
||||
new (mem) leanlua_thread(st, sz_before, nargs);
|
||||
luaL_getmetatable(L, thread_mt);
|
||||
lua_setmetatable(L, -2);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int thread_gc(lua_State * L) {
|
||||
to_thread(L, 1).~leanlua_thread();
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int thread_pred(lua_State * L) {
|
||||
return push_boolean(L, is_thread(L, 1));
|
||||
}
|
||||
|
||||
static int thread_write(lua_State * L) {
|
||||
to_thread(L, 1).write(L, 2, lua_gettop(L));
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int thread_read(lua_State * L) {
|
||||
return to_thread(L, 1).read(L);
|
||||
}
|
||||
|
||||
static int thread_interrupt(lua_State * L) {
|
||||
to_thread(L, 1).request_interrupt();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int thread_wait(lua_State * L) {
|
||||
auto & t = to_thread(L, 1);
|
||||
script_state st = to_script_state(L);
|
||||
st.exec_unprotected([&]() { t.wait(); });
|
||||
return t.copy_result(L);
|
||||
}
|
||||
|
||||
static const struct luaL_Reg thread_m[] = {
|
||||
{"__gc", thread_gc},
|
||||
{"wait", safe_function<thread_wait>},
|
||||
{"join", safe_function<thread_wait>},
|
||||
{"interrupt", safe_function<thread_interrupt>},
|
||||
{"write", safe_function<thread_write>},
|
||||
{"read", safe_function<thread_read>},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
static void open_thread(lua_State * L) {
|
||||
luaL_newmetatable(L, thread_mt);
|
||||
lua_pushvalue(L, -1);
|
||||
lua_setfield(L, -2, "__index");
|
||||
setfuncs(L, thread_m, 0);
|
||||
|
||||
SET_GLOBAL_FUN(mk_thread, "thread");
|
||||
SET_GLOBAL_FUN(thread_pred, "is_thread");
|
||||
}
|
||||
#endif
|
||||
|
||||
static int check_interrupted(lua_State *) { // NOLINT
|
||||
check_interrupted();
|
||||
return 0;
|
||||
|
@ -611,27 +191,14 @@ static int yield(lua_State * L) {
|
|||
static int import(lua_State * L) {
|
||||
std::string fname = luaL_checkstring(L, 1);
|
||||
script_state s = to_script_state(L);
|
||||
s.exec_unprotected([&]() { s.import(fname.c_str()); });
|
||||
s.import(fname.c_str());
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void open_interrupt(lua_State * L) {
|
||||
void open_extra(lua_State * L) {
|
||||
SET_GLOBAL_FUN(check_interrupted, "check_interrupted");
|
||||
SET_GLOBAL_FUN(sleep, "sleep");
|
||||
SET_GLOBAL_FUN(yield, "yield");
|
||||
#if defined(LEAN_MULTI_THREAD)
|
||||
SET_GLOBAL_FUN(channel_read, "read");
|
||||
SET_GLOBAL_FUN(channel_write, "write");
|
||||
#endif
|
||||
}
|
||||
|
||||
void open_extra(lua_State * L) {
|
||||
open_state(L);
|
||||
#if defined(LEAN_MULTI_THREAD)
|
||||
open_thread(L);
|
||||
#endif
|
||||
open_interrupt(L);
|
||||
|
||||
SET_GLOBAL_FUN(import, "import");
|
||||
SET_GLOBAL_FUN(import, "import");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,8 +7,6 @@ Author: Leonardo de Moura
|
|||
#pragma once
|
||||
#include <memory>
|
||||
#include <lua.hpp>
|
||||
#include "util/thread.h"
|
||||
#include "util/unlock_guard.h"
|
||||
|
||||
namespace lean {
|
||||
/**
|
||||
|
@ -20,8 +18,6 @@ public:
|
|||
private:
|
||||
std::shared_ptr<imp> m_ptr;
|
||||
friend script_state to_script_state(lua_State * L);
|
||||
mutex & get_mutex();
|
||||
lua_State * get_state();
|
||||
friend class data_channel;
|
||||
public:
|
||||
static void set_check_interrupt_freq(unsigned count);
|
||||
|
@ -55,34 +51,10 @@ public:
|
|||
*/
|
||||
bool import_explicit(char const * fname);
|
||||
|
||||
/**
|
||||
\brief Execute \c f in the using the internal Lua State.
|
||||
*/
|
||||
template<typename F>
|
||||
typename std::result_of<F(lua_State * L)>::type apply(F && f) {
|
||||
lock_guard<mutex> lock(get_mutex());
|
||||
return f(get_state());
|
||||
}
|
||||
lua_State * get_state();
|
||||
|
||||
typedef void (*reg_fn)(lua_State *); // NOLINT
|
||||
static void register_module(reg_fn f);
|
||||
|
||||
/**
|
||||
\brief Auxiliary function for writing API bindings
|
||||
that release the lock to this object while executing
|
||||
\c f.
|
||||
*/
|
||||
template<typename F>
|
||||
void exec_unprotected(F && f) {
|
||||
unlock_guard unlock(get_mutex());
|
||||
f();
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
void exec_protected(F && f) {
|
||||
lock_guard<mutex> lock(get_mutex());
|
||||
f();
|
||||
}
|
||||
};
|
||||
/**
|
||||
\brief Return a reference to the script_state object that is wrapping \c L.
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
s = State()
|
||||
s:eval([[ y = 5; print(x); print(y); ]])
|
||||
s:set("x", 10)
|
||||
s:set("y", nil)
|
||||
s:eval([[ print(x); print(y); ]])
|
|
@ -1,11 +0,0 @@
|
|||
-- Execute f, and make sure is throws an error
|
||||
function check_error(f)
|
||||
ok, msg = pcall(function ()
|
||||
f()
|
||||
end)
|
||||
if ok then
|
||||
error("unexpected success...")
|
||||
else
|
||||
print("caught expected error: ", msg)
|
||||
end
|
||||
end
|
|
@ -1,11 +0,0 @@
|
|||
|
||||
S = State()
|
||||
T = thread(S, [[
|
||||
sleep(10000)
|
||||
]])
|
||||
|
||||
T:interrupt()
|
||||
local ok, msg = pcall(function() T:wait() end)
|
||||
assert(not ok)
|
||||
assert(is_exception(msg))
|
||||
print(msg:what():find("interrupted"))
|
|
@ -1,12 +0,0 @@
|
|||
f = Const("f")
|
||||
a = Const("a")
|
||||
|
||||
S = State()
|
||||
T = thread(S, [[
|
||||
t = ...
|
||||
g = Const("g")
|
||||
return g(t)
|
||||
]], f(a))
|
||||
|
||||
r = T:wait()
|
||||
print(r)
|
|
@ -1,33 +0,0 @@
|
|||
S1 = State()
|
||||
S2 = State()
|
||||
code = [[
|
||||
function f(env, prefix, num, type)
|
||||
local r = list_certified_declaration()
|
||||
for i = 1, num do
|
||||
r = list_certified_declaration(type_check(env, mk_var_decl(prefix .. "_" .. i, type)), r)
|
||||
end
|
||||
return r
|
||||
end
|
||||
]]
|
||||
S1:dostring(code)
|
||||
S2:dostring(code)
|
||||
local e = bare_environment()
|
||||
e = add_decl(e, mk_var_decl("N", Type))
|
||||
code2 = [[
|
||||
e, prefix = ...
|
||||
return f(e, prefix, 1000, Const("N"))
|
||||
]]
|
||||
T1 = thread(S1, code2, e, "x")
|
||||
T2 = thread(S2, code2, e, "y")
|
||||
local r1 = T1:wait()
|
||||
local r2 = T2:wait()
|
||||
while not r1:is_nil() do
|
||||
e = e:add(r1:car())
|
||||
r1 = r1:cdr()
|
||||
end
|
||||
while not r2:is_nil() do
|
||||
e = e:add(r2:car())
|
||||
r2 = r2:cdr()
|
||||
end
|
||||
assert(e:find("x_" .. 10))
|
||||
assert(e:find("y_" .. 100))
|
|
@ -1,12 +0,0 @@
|
|||
S1 = State()
|
||||
S2 = State()
|
||||
code = [[
|
||||
id = ...
|
||||
for i = 1, 10000 do
|
||||
print("id: " .. id .. ", val: " .. i)
|
||||
end
|
||||
]]
|
||||
T1 = thread(S1, code, 1)
|
||||
T2 = thread(S2, code, 2)
|
||||
T1:wait()
|
||||
T2:wait()
|
|
@ -1,58 +0,0 @@
|
|||
-- Create a nested lua_State object
|
||||
S = State()
|
||||
|
||||
-- Remarks:
|
||||
-- '[[ ... ]]' is a multi-line string in Lua
|
||||
-- obj:method(args) is the syntax for invoking a method
|
||||
-- it is actually syntax sugar for
|
||||
-- obj.method(obj, args)
|
||||
S:dostring([[
|
||||
x = 10
|
||||
]])
|
||||
|
||||
-- Variable x is not visible in the main State object
|
||||
print(x) -- it will print nil
|
||||
S:dostring([[
|
||||
print(x)
|
||||
]])
|
||||
|
||||
-- Remark: '...' is a reference to varargs in Lua
|
||||
-- We can pass arguments to/from a nested state
|
||||
|
||||
-- The following statement passes 10 and 20 as arguments
|
||||
-- to the nested lua_State object S.
|
||||
-- The values returned by the script are stored in r1 and r2
|
||||
r1, r2 = S:dostring([[
|
||||
-- extract arguments passed to dostring
|
||||
a1, a2 = ...
|
||||
return a1 + a2, a1 - a2
|
||||
]], 10, 20)
|
||||
print("r1:", r1)
|
||||
print("r2:", r2)
|
||||
|
||||
-- We can communicate integers, strings and Lean objects
|
||||
f = Const("f")
|
||||
a = Const("a")
|
||||
T = S:dostring([[
|
||||
t = ...
|
||||
g = Const("g")
|
||||
return g(g(g(t)))
|
||||
]], f(a))
|
||||
print(T)
|
||||
|
||||
-- We can also execute commands in a separate thread.
|
||||
-- The following command creates a thread for running
|
||||
-- the given script in the state S.
|
||||
-- It does not wait the thread to finish.
|
||||
T = thread(S, [[
|
||||
t = ...
|
||||
g = Const("g")
|
||||
b = Const("b")
|
||||
return g(b, t)
|
||||
]], f(a))
|
||||
|
||||
-- The method wait makes us wait for the thread T.
|
||||
-- It return the values returned by the script.
|
||||
r = T:wait()
|
||||
-- It will print the Lean expression g(b, f(a))
|
||||
print(r)
|
|
@ -1,34 +0,0 @@
|
|||
S = State()
|
||||
assert(is_State(S))
|
||||
S:eval([[ local x = ...; assert(x == 10) ]], 10)
|
||||
S:eval([[ local f = ...; assert(f) ]], true)
|
||||
S:eval([[ local f = ...; assert(not f) ]], false)
|
||||
S:eval([[ local s = ...; assert(s == "foo") ]], "foo")
|
||||
S:eval([[ local f = ...; assert(f(1) == 3) ]], function (x) return x + 2 end)
|
||||
local val = 10
|
||||
S:eval([[ local f = ...; assert(f(1) == 11) ]], function (x) return x + val end)
|
||||
S:eval([[ local o = ...; assert(o == name("a")) ]], name("a"))
|
||||
S:eval([[ local o = ...; assert(o == Const("a")) ]], Const("a"))
|
||||
S:eval([[ local o = ...; assert(is_environment(o)) ]], bare_environment())
|
||||
S:eval([[ local o = ...; assert(o == mpz(100)) ]], mpz(100))
|
||||
S:eval([[ local o = ...; assert(o == mpq(100)/3) ]], mpq(100)/3)
|
||||
S:eval([[ local o = ...; assert(is_options(o)) ]], options())
|
||||
S:eval([[ local o = ...; assert(is_sexpr(o)) ]], sexpr())
|
||||
S:eval([[ local o = ...; assert(o:is_cons()) ]], sexpr(1, 2))
|
||||
S:eval([[ local o = ...; assert(is_format(o)) ]], format("1"))
|
||||
S:eval([[ local o1, o2, o3 = ...; assert(is_sexpr(o1)); assert(is_name(o2)); assert(o3 == 10) ]], sexpr(), name("foo"), 10)
|
||||
assert(not pcall(function() S:eval([[ x = ]]) end))
|
||||
local T = thread(S, [[ local x = ...; return x + 10, x - 10 ]], 10)
|
||||
assert(is_thread(T))
|
||||
local r1, r2 = T:wait()
|
||||
assert(r1 == 20)
|
||||
assert(r2 == 0)
|
||||
assert(not pcall(function() S:eval([[ x = ]]) end))
|
||||
local T2 = thread(S, [[ local x = ...; error("failed") ]], 10)
|
||||
local ok, msg = pcall(function() T2:wait() end)
|
||||
assert(not ok)
|
||||
assert(is_exception(msg))
|
||||
print(msg:what())
|
||||
assert(msg:what():find("failed"))
|
||||
local T3 = thread(S, [[ local x = ...; return x + 10, x - 10 ]], 10)
|
||||
T3 = nil
|
|
@ -1,10 +0,0 @@
|
|||
S = State()
|
||||
T = thread(S, [[
|
||||
while true do
|
||||
check_interrupted()
|
||||
end
|
||||
]])
|
||||
|
||||
sleep(100)
|
||||
T:interrupt()
|
||||
assert(not pcall(function() T:wait() end))
|
|
@ -1,37 +0,0 @@
|
|||
S = State()
|
||||
T = thread(S, [[
|
||||
print("starting thread...")
|
||||
pcall(function()
|
||||
while true do
|
||||
check_interrupted() -- check if thread was interrupted
|
||||
local ok, val = read(10) -- 10 ms timeout
|
||||
if ok then
|
||||
print("thread received:", val)
|
||||
write(val + 10, val - 10) -- send result back to main thread
|
||||
end
|
||||
end
|
||||
end)
|
||||
]])
|
||||
|
||||
for i = 1, 10 do
|
||||
T:write(10 * i)
|
||||
local r1 = T:read()
|
||||
local r2 = T:read()
|
||||
print("main received: ", r1, r2)
|
||||
end
|
||||
T:interrupt()
|
||||
print("done")
|
||||
|
||||
-- Channels are quite flexible, we can send closure over the channel
|
||||
T = thread(S, [[
|
||||
for i = 1, 10 do
|
||||
local f = read()
|
||||
-- send back the result of f(i)
|
||||
write(f(i))
|
||||
end
|
||||
]])
|
||||
|
||||
for i = 1, 10 do
|
||||
T:write(function (x) return x + i end)
|
||||
print(T:read())
|
||||
end
|
Loading…
Add table
Reference in a new issue