refactor(util/script_state): remove support for threads and communication channels from the Lua API, the goal is to keep is simple, and use one Lua state object per thread

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-17 21:38:10 -07:00
parent f17e8a853a
commit 1378fa5cbb
18 changed files with 33 additions and 747 deletions

View file

@ -12,9 +12,9 @@ Author: Leonardo de Moura
#include "frontends/lean/interactive.h"
namespace lean {
interactive::interactive(environment const & env, io_state const & ios, script_state const & ss, unsigned num_threads,
interactive::interactive(environment const & env, io_state const & ios, unsigned num_threads,
char const * ack_cmd, char const * snapshot_cmd, char const * restore_cmd, char const * restart_cmd):
m_env(env), m_ios(ios), m_ss(ss), m_num_threads(num_threads), m_line(1),
m_env(env), m_ios(ios), m_num_threads(num_threads), m_line(1),
m_ack_cmd(ack_cmd), m_snapshot_cmd(snapshot_cmd), m_restore_cmd(restore_cmd), m_restart_cmd(restart_cmd) {
save_snapshot();
}
@ -22,7 +22,7 @@ interactive::interactive(environment const & env, io_state const & ios, script_s
void interactive::parse_block(std::string const & str, char const * strm_name) {
if (str.size() > 0) {
std::istringstream block(str);
parser p(m_env, m_ios, block, strm_name, &m_ss, false, m_num_threads, m_lds, m_eds, m_line);
parser p(m_env, m_ios, block, strm_name, false, m_num_threads, m_lds, m_eds, m_line);
p();
m_env = p.env();
m_ios = p.ios();

View file

@ -34,7 +34,6 @@ class interactive {
std::vector<std::string> m_lines;
environment m_env;
io_state m_ios;
script_state m_ss;
unsigned m_num_threads;
local_level_decls m_lds;
local_expr_decls m_eds;
@ -47,8 +46,7 @@ class interactive {
void save_snapshot();
void restore(unsigned new_line, std::string & block);
public:
interactive(environment const & env, io_state const & ios, script_state const & ss,
unsigned num_threads = 1,
interactive(environment const & env, io_state const & ios, unsigned num_threads = 1,
char const * ack_cmd = "#ACK", char const * snapshot_cmd = "#SNAPSHOT",
char const * res_cmd = "#RESTORE", char const * restart_cmd = "#RESTART");
environment const & env() const { return m_env; }

View file

@ -102,10 +102,10 @@ struct scoped_set_parser {
parser::parser(environment const & env, io_state const & ios,
std::istream & strm, char const * strm_name,
script_state * ss, bool use_exceptions, unsigned num_threads,
bool use_exceptions, unsigned num_threads,
local_level_decls const & lds, local_expr_decls const & eds,
unsigned line):
m_env(env), m_ios(ios), m_ss(ss),
m_env(env), m_ios(ios),
m_verbose(true), m_use_exceptions(use_exceptions),
m_scanner(strm, strm_name), m_local_level_decls(lds), m_local_decls(eds),
m_pos_table(std::make_shared<pos_info_table>()) {
@ -731,8 +731,6 @@ expr parser::parse_notation(parse_table t, expr * left) {
break;
}
case notation::action_kind::LuaExt:
if (!m_ss)
throw parser_error("failed to use notation implemented in Lua, parser does not contain a Lua state", p);
using_script([&](lua_State * L) {
scoped_set_parser scope(L, *this);
lua_getglobal(L, a.get_lua_fn().c_str());
@ -955,8 +953,6 @@ void parser::parse_command() {
void parser::parse_script(bool as_expr) {
m_last_script_pos = pos();
if (!m_ss)
throw exception("failed to execute Lua script, parser does not have a Lua interpreter");
std::string script_code = m_scanner.get_str_val();
if (as_expr)
script_code = "return " + script_code;
@ -983,8 +979,6 @@ void parser::parse_imports() {
while (curr_is_identifier()) {
name f = get_name_val();
if (auto it = find_file(f, ".lua")) {
if (!m_ss)
throw parser_error("invalid import, Lua interpreter is not available", pos());
lua_files.push_back(*it);
} else if (auto it = find_file(f, ".olean")) {
olean_files.push_back(*it);
@ -995,11 +989,10 @@ void parser::parse_imports() {
}
}
m_env = import_modules(m_env, olean_files.size(), olean_files.data(), m_num_threads, true, m_ios);
using_script([&](lua_State *) { // NOLINT
m_ss->exec_unprotected([&]() {
for (auto const & f : lua_files) {
m_ss->import_explicit(f.c_str());
}});
using_script([&](lua_State * L) {
for (auto const & f : lua_files) {
to_script_state(L).import_explicit(f.c_str());
}
});
}
@ -1037,29 +1030,27 @@ bool parser::parse_commands() {
return !m_found_errors;
}
bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name, script_state * S, bool use_exceptions,
bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name, bool use_exceptions,
unsigned num_threads) {
parser p(env, ios, in, strm_name, S, use_exceptions, num_threads);
parser p(env, ios, in, strm_name, use_exceptions, num_threads);
bool r = p();
ios = p.ios();
env = p.env();
return r;
}
bool parse_commands(environment & env, io_state & ios, char const * fname, script_state * S, bool use_exceptions, unsigned num_threads) {
bool parse_commands(environment & env, io_state & ios, char const * fname, bool use_exceptions, unsigned num_threads) {
std::ifstream in(fname);
if (in.bad() || in.fail())
throw exception(sstream() << "failed to open file '" << fname << "'");
return parse_commands(env, ios, in, fname, S, use_exceptions, num_threads);
return parse_commands(env, ios, in, fname, use_exceptions, num_threads);
}
static int parse_expr(lua_State * L) {
script_state S = to_script_state(L);
int nargs = lua_gettop(L);
expr r;
S.exec_unprotected([&]() {
r = get_global_parser(L).parse_expr(nargs == 0 ? 0 : lua_tointeger(L, 1));
});
r = get_global_parser(L).parse_expr(nargs == 0 ? 0 : lua_tointeger(L, 1));
return push_expr(L, r);
}

View file

@ -11,6 +11,7 @@ Author: Leonardo de Moura
#include "util/script_state.h"
#include "util/name_map.h"
#include "util/exception.h"
#include "util/thread_script_state.h"
#include "kernel/environment.h"
#include "kernel/expr_maps.h"
#include "library/io_state.h"
@ -49,7 +50,6 @@ typedef std::unordered_map<unsigned, tactic> hint_table;
class parser {
environment m_env;
io_state m_ios;
script_state * m_ss;
bool m_verbose;
bool m_use_exceptions;
bool m_show_errors;
@ -86,11 +86,10 @@ class parser {
void protected_call(std::function<void()> && f, std::function<void()> && sync);
template<typename F>
typename std::result_of<F(lua_State * L)>::type using_script(F && f) {
return m_ss->apply([&](lua_State * L) {
set_io_state set1(L, m_ios);
set_environment set2(L, m_env);
return f(L);
});
script_state S = get_thread_script_state();
set_io_state set1(S.get_state(), m_ios);
set_environment set2(S.get_state(), m_env);
return f(S.get_state());
}
tag get_tag(expr e);
@ -134,15 +133,13 @@ class parser {
public:
parser(environment const & env, io_state const & ios,
std::istream & strm, char const * str_name,
script_state * ss = nullptr, bool use_exceptions = false,
unsigned num_threads = 1,
bool use_exceptions = false, unsigned num_threads = 1,
local_level_decls const & lds = local_level_decls(),
local_expr_decls const & eds = local_expr_decls(),
unsigned line = 1);
environment const & env() const { return m_env; }
io_state const & ios() const { return m_ios; }
script_state * ss() const { return m_ss; }
local_level_decls const & get_local_level_decls() const { return m_local_level_decls; }
local_expr_decls const & get_local_expr_decls() const { return m_local_decls; }
@ -266,9 +263,7 @@ public:
bool operator()() { return parse_commands(); }
};
bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name,
script_state * S, bool use_exceptions, unsigned num_threads);
bool parse_commands(environment & env, io_state & ios, char const * fname, script_state * S,
bool use_exceptions, unsigned num_threads);
bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name, bool use_exceptions, unsigned num_threads);
bool parse_commands(environment & env, io_state & ios, char const * fname, bool use_exceptions, unsigned num_threads);
void open_parser(lua_State * L);
}

View file

@ -104,19 +104,6 @@ FOREACH(T ${LEANLUADOCS})
COMMAND "./test_single.sh" "${CMAKE_CURRENT_BINARY_DIR}/lean -t 100000" ${T})
ENDFOREACH(T)
# LEAN LUA THREAD TESTS
if((${CYGWIN} EQUAL "1") OR (${CMAKE_SYSTEM_NAME} MATCHES "Linux"))
if ((NOT (${CMAKE_CXX_COMPILER} MATCHES "clang")) AND (${MULTI_THREAD} MATCHES "ON"))
file(GLOB LEANLUATHREADTESTS "${LEAN_SOURCE_DIR}/../tests/lua/threads/*.lua")
FOREACH(T ${LEANLUATHREADTESTS})
GET_FILENAME_COMPONENT(T_NAME ${T} NAME)
add_test(NAME "leanluathreadtest_${T_NAME}"
WORKING_DIRECTORY "${LEAN_SOURCE_DIR}/../tests/lua/threads"
COMMAND "../test_single.sh" "${CMAKE_CURRENT_BINARY_DIR}/lean -t 100000" ${T})
ENDFOREACH(T)
endif()
endif()
# # Create the script lean.sh
# # This is used to create a soft dependency on the Lean executable
# # Some rules can only be applied if the lean executable exists,

View file

@ -15,6 +15,7 @@ Author: Leonardo de Moura
#include "util/interrupt.h"
#include "util/script_state.h"
#include "util/thread.h"
#include "util/thread_script_state.h"
#include "util/lean_path.h"
#include "kernel/environment.h"
#include "kernel/kernel_exception.h"
@ -188,11 +189,9 @@ int main(int argc, char ** argv) {
if (quiet)
ios.set_option("verbose", false);
script_state S;
S.apply([&](lua_State * L) {
set_global_environment(L, env);
set_global_io_state(L, ios);
});
script_state S = lean::get_thread_script_state();
set_global_environment(S.get_state(), env);
set_global_io_state(S.get_state(), ios);
try {
bool ok = true;
@ -207,7 +206,7 @@ int main(int argc, char ** argv) {
}
}
if (k == input_kind::Lean) {
if (!parse_commands(env, ios, argv[i], &S, false, num_threads))
if (!parse_commands(env, ios, argv[i], false, num_threads))
ok = false;
} else if (k == input_kind::Lua) {
try {
@ -222,7 +221,7 @@ int main(int argc, char ** argv) {
}
if (ok && interactive && default_k == input_kind::Lean) {
signal(SIGINT, on_ctrl_c);
lean::interactive in(env, ios, S, num_threads);
lean::interactive in(env, ios, num_threads);
in(std::cin, "[stdin]");
}
if (export_objects) {

View file

@ -8,7 +8,6 @@ Author: Leonardo de Moura
#include <string>
#include <vector>
#include <unordered_set>
#include "util/thread.h"
#include "util/lua.h"
#include "util/debug.h"
#include "util/exception.h"
@ -43,7 +42,6 @@ static char g_weak_ptr_key; // key for Lua registry (used at get_weak_ptr and sa
struct script_state::imp {
lua_State * m_state;
mutex m_mutex;
std::unordered_set<std::string> m_imported_modules;
static std::weak_ptr<imp> * get_weak_ptr(lua_State * L) {
@ -104,12 +102,10 @@ struct script_state::imp {
}
void dofile(char const * fname) {
lock_guard<mutex> lock(m_mutex);
::lean::dofile(m_state, fname);
}
void dostring(char const * str) {
lock_guard<mutex> lock(m_mutex);
::lean::dostring(m_state, str);
}
@ -166,426 +162,10 @@ bool script_state::import_explicit(char const * str) {
return m_ptr->import_explicit(str);
}
mutex & script_state::get_mutex() {
return m_ptr->m_mutex;
}
lua_State * script_state::get_state() {
return m_ptr->m_state;
}
constexpr char const * state_mt = "luastate.mt";
bool is_state(lua_State * L, int idx) {
return testudata(L, idx, state_mt);
}
script_state & to_state(lua_State * L, int idx) {
return *static_cast<script_state*>(luaL_checkudata(L, idx, state_mt));
}
int push_state(lua_State * L, script_state const & s) {
void * mem = lua_newuserdata(L, sizeof(script_state));
new (mem) script_state(s);
luaL_getmetatable(L, state_mt);
lua_setmetatable(L, -2);
return 1;
}
static int mk_state(lua_State * L) {
script_state r;
return push_state(L, r);
}
static int state_gc(lua_State * L) {
to_state(L, 1).~script_state();
return 0;
}
static int writer(lua_State *, void const * p, size_t sz, void * buf) {
buffer<char> & _buf = *static_cast<buffer<char>*>(buf);
char const * in = static_cast<char const *>(p);
for (size_t i = 0; i < sz; i++)
_buf.push_back(in[i]);
return 0;
}
struct reader_data {
buffer<char> & m_buffer;
bool m_done;
reader_data(buffer<char> & b):m_buffer(b), m_done(false) {}
};
static char const * reader(lua_State *, void * data, size_t * sz) {
reader_data & _data = *static_cast<reader_data*>(data);
if (_data.m_done) {
*sz = 0;
return nullptr;
} else {
*sz = _data.m_buffer.size();
_data.m_done = true;
return _data.m_buffer.data();
}
}
static void copy_values(lua_State * src, int first, int last, lua_State * tgt) {
for (int i = first; i <= last; i++) {
switch (lua_type(src, i)) {
case LUA_TNUMBER: lua_pushnumber(tgt, lua_tonumber(src, i)); break;
case LUA_TSTRING: lua_pushstring(tgt, lua_tostring(src, i)); break;
case LUA_TNIL: lua_pushnil(tgt); break;
case LUA_TBOOLEAN: lua_pushboolean(tgt, lua_toboolean(src, i)); break;
case LUA_TFUNCTION: {
lua_pushvalue(src, i); // copy function to the top of the stack
buffer<char> buffer;
if (lua_dump(src, writer, &buffer) != 0)
throw exception("falied to copy function between State objects");
lua_pop(src, 1); // remove function from the top of the stack
reader_data data(buffer);
if (load(tgt, reader, &data, "temporary buffer for moving functions between states") != 0)
throw exception("falied to copy function between State objects");
// copy upvalues
int j = 1;
while (true) {
char const * name = lua_getupvalue(src, i, j);
if (name == nullptr)
break;
copy_values(src, lua_gettop(src), lua_gettop(src), tgt); // copy upvalue to tgt stack
lua_pop(src, 1); // remove upvalue from src stack
lua_setupvalue(tgt, -2, j);
j++;
}
break;
}
case LUA_TUSERDATA:
if (lua_migrate_fn f = get_migrate_fn(src, i)) {
f(src, i, tgt);
} else {
throw exception("unsupported value type for inter-State call");
}
break;
default:
throw exception("unsupported value type for inter-State call");
}
}
}
int state_dostring(lua_State * L) {
return to_state(L, 1).apply([&](lua_State * S) {
char const * script = luaL_checkstring(L, 2);
int first = 3;
int last = lua_gettop(L);
int sz_before = lua_gettop(S);
int status = luaL_loadstring(S, script);
if (status)
throw script_exception(lua_tostring(S, -1));
copy_values(L, first, last, S);
pcall(S, first > last ? 0 : last - first + 1, LUA_MULTRET, 0);
int sz_after = lua_gettop(S);
if (sz_after > sz_before) {
copy_values(S, sz_before + 1, sz_after, L);
lua_pop(S, sz_after - sz_before);
}
return sz_after - sz_before;
});
}
int state_set_global(lua_State * L) {
to_state(L, 1).apply([=](lua_State * S) {
char const * name = luaL_checkstring(L, 2);
copy_values(L, 3, 3, S);
lua_setglobal(S, name);
});
return 0;
}
static int state_pred(lua_State * L) {
return push_boolean(L, is_state(L, 1));
}
static const struct luaL_Reg state_m[] = {
{"__gc", state_gc},
{"dostring", safe_function<state_dostring>},
{"eval", safe_function<state_dostring>},
{"set", safe_function<state_set_global>},
{0, 0}
};
static void open_state(lua_State * L) {
luaL_newmetatable(L, state_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, state_m, 0);
SET_GLOBAL_FUN(mk_state, "State");
SET_GLOBAL_FUN(state_pred, "is_State");
}
// TODO(Leo): allow the user to change it?
#define SMALL_DELAY 10 // in ms
chrono::milliseconds g_small_delay(SMALL_DELAY);
#if defined(LEAN_MULTI_THREAD)
/**
\brief Channel for communicating with thread objects in the Lua API
*/
class data_channel {
// We use a lua_State to implement the channel. This is quite hackish,
// but it is a convenient storage for Lua objects sent from one state to
// another.
script_state m_channel;
int m_ini;
mutex m_mutex;
condition_variable m_cv;
public:
data_channel() {
lua_State * channel = m_channel.m_ptr->m_state;
m_ini = lua_gettop(channel);
}
/**
\brief Copy elements from positions [first, last] from src stack
to the channel.
*/
void write(lua_State * src, int first, int last) {
// write the object on the top of the stack of src to the table
// on m_channel.
if (last < first)
return;
lock_guard<mutex> lock(m_mutex);
lua_State * channel = m_channel.m_ptr->m_state;
bool was_empty = lua_gettop(channel) == m_ini;
copy_values(src, first, last, channel);
if (was_empty)
m_cv.notify_one();
}
/**
\brief Retrieve one element from the channel. It will block
the execution of \c tgt if the channel is empty.
*/
int read(lua_State * tgt, int i) {
unique_lock<mutex> lock(m_mutex);
lua_State * channel = m_channel.m_ptr->m_state;
if (i > 0) {
// i is the position of the timeout argument
chrono::milliseconds dura(luaL_checkinteger(tgt, i));
if (lua_gettop(channel) == m_ini)
m_cv.wait_for(lock, dura);
if (lua_gettop(channel) == m_ini) {
// timeout...
lua_pushboolean(tgt, false);
lua_pushnil(tgt);
return 2;
} else {
lua_pushboolean(tgt, true);
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1);
return 2;
}
} else {
while (lua_gettop(channel) == m_ini) {
check_interrupted();
m_cv.wait_for(lock, g_small_delay);
}
copy_values(channel, m_ini + 1, m_ini + 1, tgt);
lua_remove(channel, m_ini + 1);
return 1;
}
}
};
/**
\brief We want the channels to be lazily created.
*/
class data_channel_ref {
std::unique_ptr<data_channel> m_channel;
mutex m_mutex;
public:
data_channel & get() {
lock_guard<mutex> lock(m_mutex);
if (!m_channel)
m_channel.reset(new data_channel());
lean_assert(m_channel);
return *m_channel;
}
};
data_channel_ref g_in_channel;
data_channel_ref g_out_channel;
int channel_read(lua_State * L) {
return g_in_channel.get().read(L, lua_gettop(L));
}
int channel_write(lua_State * L) {
g_out_channel.get().write(L, 1, lua_gettop(L));
return 0;
}
class leanlua_thread {
script_state m_state;
int m_sz_before;
std::unique_ptr<exception> m_exception;
atomic<data_channel_ref *> m_in_channel_addr;
atomic<data_channel_ref *> m_out_channel_addr;
interruptible_thread m_thread;
public:
leanlua_thread(script_state const & st, int sz_before, int num_args):
m_state(st),
m_sz_before(sz_before),
m_in_channel_addr(0),
m_out_channel_addr(0),
m_thread([=]() {
m_in_channel_addr.store(&g_in_channel);
m_out_channel_addr.store(&g_out_channel);
m_state.apply([&](lua_State * S) {
int result = lua_pcall(S, num_args, LUA_MULTRET, 0);
if (result) {
if (is_exception(S, -1))
m_exception.reset(to_exception(S, -1).clone());
else
m_exception.reset(new script_exception(lua_tostring(S, -1)));
}
});
}) {
}
~leanlua_thread() {
if (m_thread.joinable())
m_thread.join();
}
int copy_result(lua_State * src) {
if (m_exception)
m_exception->rethrow();
return m_state.apply([&](lua_State * S) {
int sz_after = lua_gettop(S);
if (sz_after > m_sz_before) {
copy_values(S, m_sz_before + 1, sz_after, src);
lua_pop(S, sz_after - m_sz_before);
}
return sz_after - m_sz_before;
});
}
void wait() {
m_thread.join();
}
void request_interrupt() {
m_thread.request_interrupt();
}
void write(lua_State * src, int first, int last) {
while (!m_in_channel_addr) {
check_interrupted();
this_thread::sleep_for(g_small_delay);
}
data_channel & in = m_in_channel_addr.load()->get();
in.write(src, first, last);
}
int read(lua_State * src) {
if (!m_out_channel_addr) {
check_interrupted();
this_thread::sleep_for(g_small_delay);
}
data_channel & out = m_out_channel_addr.load()->get();
int nargs = lua_gettop(src);
return out.read(src, nargs == 1 ? 0 : 2);
}
bool started() {
return m_in_channel_addr && m_out_channel_addr;
}
};
constexpr char const * thread_mt = "thread.mt";
bool is_thread(lua_State * L, int idx) {
return testudata(L, idx, thread_mt);
}
leanlua_thread & to_thread(lua_State * L, int idx) {
return *static_cast<leanlua_thread*>(luaL_checkudata(L, idx, thread_mt));
}
int mk_thread(lua_State * L) {
check_threadsafe();
script_state & st = to_state(L, 1);
char const * script = luaL_checkstring(L, 2);
int first = 3;
int last = lua_gettop(L);
int nargs = first > last ? 0 : last - first + 1;
int sz_before;
st.apply([&](lua_State * S) {
sz_before = lua_gettop(S);
int result = luaL_loadstring(S, script);
if (result)
throw script_exception(lua_tostring(S, -1));
copy_values(L, first, last, S);
});
void * mem = lua_newuserdata(L, sizeof(leanlua_thread));
new (mem) leanlua_thread(st, sz_before, nargs);
luaL_getmetatable(L, thread_mt);
lua_setmetatable(L, -2);
return 1;
}
static int thread_gc(lua_State * L) {
to_thread(L, 1).~leanlua_thread();
return 0;
}
static int thread_pred(lua_State * L) {
return push_boolean(L, is_thread(L, 1));
}
static int thread_write(lua_State * L) {
to_thread(L, 1).write(L, 2, lua_gettop(L));
return 0;
}
static int thread_read(lua_State * L) {
return to_thread(L, 1).read(L);
}
static int thread_interrupt(lua_State * L) {
to_thread(L, 1).request_interrupt();
return 0;
}
int thread_wait(lua_State * L) {
auto & t = to_thread(L, 1);
script_state st = to_script_state(L);
st.exec_unprotected([&]() { t.wait(); });
return t.copy_result(L);
}
static const struct luaL_Reg thread_m[] = {
{"__gc", thread_gc},
{"wait", safe_function<thread_wait>},
{"join", safe_function<thread_wait>},
{"interrupt", safe_function<thread_interrupt>},
{"write", safe_function<thread_write>},
{"read", safe_function<thread_read>},
{0, 0}
};
static void open_thread(lua_State * L) {
luaL_newmetatable(L, thread_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, thread_m, 0);
SET_GLOBAL_FUN(mk_thread, "thread");
SET_GLOBAL_FUN(thread_pred, "is_thread");
}
#endif
static int check_interrupted(lua_State *) { // NOLINT
check_interrupted();
return 0;
@ -611,27 +191,14 @@ static int yield(lua_State * L) {
static int import(lua_State * L) {
std::string fname = luaL_checkstring(L, 1);
script_state s = to_script_state(L);
s.exec_unprotected([&]() { s.import(fname.c_str()); });
s.import(fname.c_str());
return 0;
}
static void open_interrupt(lua_State * L) {
void open_extra(lua_State * L) {
SET_GLOBAL_FUN(check_interrupted, "check_interrupted");
SET_GLOBAL_FUN(sleep, "sleep");
SET_GLOBAL_FUN(yield, "yield");
#if defined(LEAN_MULTI_THREAD)
SET_GLOBAL_FUN(channel_read, "read");
SET_GLOBAL_FUN(channel_write, "write");
#endif
}
void open_extra(lua_State * L) {
open_state(L);
#if defined(LEAN_MULTI_THREAD)
open_thread(L);
#endif
open_interrupt(L);
SET_GLOBAL_FUN(import, "import");
SET_GLOBAL_FUN(import, "import");
}
}

View file

@ -7,8 +7,6 @@ Author: Leonardo de Moura
#pragma once
#include <memory>
#include <lua.hpp>
#include "util/thread.h"
#include "util/unlock_guard.h"
namespace lean {
/**
@ -20,8 +18,6 @@ public:
private:
std::shared_ptr<imp> m_ptr;
friend script_state to_script_state(lua_State * L);
mutex & get_mutex();
lua_State * get_state();
friend class data_channel;
public:
static void set_check_interrupt_freq(unsigned count);
@ -55,34 +51,10 @@ public:
*/
bool import_explicit(char const * fname);
/**
\brief Execute \c f in the using the internal Lua State.
*/
template<typename F>
typename std::result_of<F(lua_State * L)>::type apply(F && f) {
lock_guard<mutex> lock(get_mutex());
return f(get_state());
}
lua_State * get_state();
typedef void (*reg_fn)(lua_State *); // NOLINT
static void register_module(reg_fn f);
/**
\brief Auxiliary function for writing API bindings
that release the lock to this object while executing
\c f.
*/
template<typename F>
void exec_unprotected(F && f) {
unlock_guard unlock(get_mutex());
f();
}
template<typename F>
void exec_protected(F && f) {
lock_guard<mutex> lock(get_mutex());
f();
}
};
/**
\brief Return a reference to the script_state object that is wrapping \c L.

View file

@ -1,5 +0,0 @@
s = State()
s:eval([[ y = 5; print(x); print(y); ]])
s:set("x", 10)
s:set("y", nil)
s:eval([[ print(x); print(y); ]])

View file

@ -1,11 +0,0 @@
-- Execute f, and make sure is throws an error
function check_error(f)
ok, msg = pcall(function ()
f()
end)
if ok then
error("unexpected success...")
else
print("caught expected error: ", msg)
end
end

View file

@ -1,11 +0,0 @@
S = State()
T = thread(S, [[
sleep(10000)
]])
T:interrupt()
local ok, msg = pcall(function() T:wait() end)
assert(not ok)
assert(is_exception(msg))
print(msg:what():find("interrupted"))

View file

@ -1,12 +0,0 @@
f = Const("f")
a = Const("a")
S = State()
T = thread(S, [[
t = ...
g = Const("g")
return g(t)
]], f(a))
r = T:wait()
print(r)

View file

@ -1,33 +0,0 @@
S1 = State()
S2 = State()
code = [[
function f(env, prefix, num, type)
local r = list_certified_declaration()
for i = 1, num do
r = list_certified_declaration(type_check(env, mk_var_decl(prefix .. "_" .. i, type)), r)
end
return r
end
]]
S1:dostring(code)
S2:dostring(code)
local e = bare_environment()
e = add_decl(e, mk_var_decl("N", Type))
code2 = [[
e, prefix = ...
return f(e, prefix, 1000, Const("N"))
]]
T1 = thread(S1, code2, e, "x")
T2 = thread(S2, code2, e, "y")
local r1 = T1:wait()
local r2 = T2:wait()
while not r1:is_nil() do
e = e:add(r1:car())
r1 = r1:cdr()
end
while not r2:is_nil() do
e = e:add(r2:car())
r2 = r2:cdr()
end
assert(e:find("x_" .. 10))
assert(e:find("y_" .. 100))

View file

@ -1,12 +0,0 @@
S1 = State()
S2 = State()
code = [[
id = ...
for i = 1, 10000 do
print("id: " .. id .. ", val: " .. i)
end
]]
T1 = thread(S1, code, 1)
T2 = thread(S2, code, 2)
T1:wait()
T2:wait()

View file

@ -1,58 +0,0 @@
-- Create a nested lua_State object
S = State()
-- Remarks:
-- '[[ ... ]]' is a multi-line string in Lua
-- obj:method(args) is the syntax for invoking a method
-- it is actually syntax sugar for
-- obj.method(obj, args)
S:dostring([[
x = 10
]])
-- Variable x is not visible in the main State object
print(x) -- it will print nil
S:dostring([[
print(x)
]])
-- Remark: '...' is a reference to varargs in Lua
-- We can pass arguments to/from a nested state
-- The following statement passes 10 and 20 as arguments
-- to the nested lua_State object S.
-- The values returned by the script are stored in r1 and r2
r1, r2 = S:dostring([[
-- extract arguments passed to dostring
a1, a2 = ...
return a1 + a2, a1 - a2
]], 10, 20)
print("r1:", r1)
print("r2:", r2)
-- We can communicate integers, strings and Lean objects
f = Const("f")
a = Const("a")
T = S:dostring([[
t = ...
g = Const("g")
return g(g(g(t)))
]], f(a))
print(T)
-- We can also execute commands in a separate thread.
-- The following command creates a thread for running
-- the given script in the state S.
-- It does not wait the thread to finish.
T = thread(S, [[
t = ...
g = Const("g")
b = Const("b")
return g(b, t)
]], f(a))
-- The method wait makes us wait for the thread T.
-- It return the values returned by the script.
r = T:wait()
-- It will print the Lean expression g(b, f(a))
print(r)

View file

@ -1,34 +0,0 @@
S = State()
assert(is_State(S))
S:eval([[ local x = ...; assert(x == 10) ]], 10)
S:eval([[ local f = ...; assert(f) ]], true)
S:eval([[ local f = ...; assert(not f) ]], false)
S:eval([[ local s = ...; assert(s == "foo") ]], "foo")
S:eval([[ local f = ...; assert(f(1) == 3) ]], function (x) return x + 2 end)
local val = 10
S:eval([[ local f = ...; assert(f(1) == 11) ]], function (x) return x + val end)
S:eval([[ local o = ...; assert(o == name("a")) ]], name("a"))
S:eval([[ local o = ...; assert(o == Const("a")) ]], Const("a"))
S:eval([[ local o = ...; assert(is_environment(o)) ]], bare_environment())
S:eval([[ local o = ...; assert(o == mpz(100)) ]], mpz(100))
S:eval([[ local o = ...; assert(o == mpq(100)/3) ]], mpq(100)/3)
S:eval([[ local o = ...; assert(is_options(o)) ]], options())
S:eval([[ local o = ...; assert(is_sexpr(o)) ]], sexpr())
S:eval([[ local o = ...; assert(o:is_cons()) ]], sexpr(1, 2))
S:eval([[ local o = ...; assert(is_format(o)) ]], format("1"))
S:eval([[ local o1, o2, o3 = ...; assert(is_sexpr(o1)); assert(is_name(o2)); assert(o3 == 10) ]], sexpr(), name("foo"), 10)
assert(not pcall(function() S:eval([[ x = ]]) end))
local T = thread(S, [[ local x = ...; return x + 10, x - 10 ]], 10)
assert(is_thread(T))
local r1, r2 = T:wait()
assert(r1 == 20)
assert(r2 == 0)
assert(not pcall(function() S:eval([[ x = ]]) end))
local T2 = thread(S, [[ local x = ...; error("failed") ]], 10)
local ok, msg = pcall(function() T2:wait() end)
assert(not ok)
assert(is_exception(msg))
print(msg:what())
assert(msg:what():find("failed"))
local T3 = thread(S, [[ local x = ...; return x + 10, x - 10 ]], 10)
T3 = nil

View file

@ -1,10 +0,0 @@
S = State()
T = thread(S, [[
while true do
check_interrupted()
end
]])
sleep(100)
T:interrupt()
assert(not pcall(function() T:wait() end))

View file

@ -1,37 +0,0 @@
S = State()
T = thread(S, [[
print("starting thread...")
pcall(function()
while true do
check_interrupted() -- check if thread was interrupted
local ok, val = read(10) -- 10 ms timeout
if ok then
print("thread received:", val)
write(val + 10, val - 10) -- send result back to main thread
end
end
end)
]])
for i = 1, 10 do
T:write(10 * i)
local r1 = T:read()
local r2 = T:read()
print("main received: ", r1, r2)
end
T:interrupt()
print("done")
-- Channels are quite flexible, we can send closure over the channel
T = thread(S, [[
for i = 1, 10 do
local f = read()
-- send back the result of f(i)
write(f(i))
end
]])
for i = 1, 10 do
T:write(function (x) return x + i end)
print(T:read())
end