diff --git a/src/library/tactic/tactic.cpp b/src/library/tactic/tactic.cpp index 6ceb064f2..d97834399 100644 --- a/src/library/tactic/tactic.cpp +++ b/src/library/tactic/tactic.cpp @@ -386,20 +386,41 @@ static int mk_lua_tactic01(lua_State * L) { luaref ref(L, 1); return push_tactic(L, mk_tactic01([=](environment const & env, io_state const & ios, proof_state const & s) -> optional { - optional r; script_state _S(S); - _S.exec_protected([&]() { - ref.push(); // push user-fun on the stack - push_environment(L, env); - push_io_state(L, ios); - push_proof_state(L, s); - pcall(L, 3, 1, 0); - if (is_proof_state(L, -1)) { - r = to_proof_state(L, -1); - } - lua_pop(L, 1); - }); - return r; + optional r; + luaref coref; // Remark: we have to release the reference in a protected block. + try { + bool done = false; + lua_State * co; + _S.exec_protected([&]() { + co = lua_newthread(L); // create a coroutine for executing user-fun + coref = luaref(L, -1); // make sure co-routine in not deleted + lua_pop(L, 1); + ref.push(); // push user-fun on the stack + push_environment(L, env); // push args... + push_io_state(L, ios); + push_proof_state(L, s); + lua_xmove(L, co, 4); // move function and arguments to co + done = resume(co, 3); + }); + while (!done) { + check_interrupted(); + std::this_thread::yield(); // give another thread a chance to execute + _S.exec_protected([&]() { + done = resume(co, 0); + }); + } + _S.exec_protected([&]() { + if (is_proof_state(co, -1)) { + r = to_proof_state(co, -1); + } + coref.release(); + }); + return r; + } catch (...) { + _S.exec_protected([&]() { coref.release(); }); + throw; + } })); } diff --git a/src/util/lua.cpp b/src/util/lua.cpp index 9468fc7df..267d18099 100644 --- a/src/util/lua.cpp +++ b/src/util/lua.cpp @@ -115,6 +115,21 @@ void pcall(lua_State * L, int nargs, int nresults, int errorfun) { check_result(L, result); } +bool resume(lua_State * L, int nargs) { + #if LUA_VERSION_NUM < 502 + int result = lua_resume(L, nargs); + #else + int result = lua_resume(L, nullptr, nargs); + #endif + if (result == LUA_YIELD) + return false; + if (result == 0) + return true; + check_result(L, result); + lean_unreachable(); + return true; +} + /** \brief Wrapper for "customers" that are only using a subset of Lean libraries. diff --git a/src/util/lua.h b/src/util/lua.h index 43eefc2f8..fa9332dd7 100644 --- a/src/util/lua.h +++ b/src/util/lua.h @@ -20,6 +20,11 @@ size_t objlen(lua_State * L, int idx); void dofile(lua_State * L, char const * fname); void dostring(lua_State * L, char const * str); void pcall(lua_State * L, int nargs, int nresults, int errorfun); +/** + \brief Return true iff coroutine is done, false if it has yielded, + and throws an exception if error. +*/ +bool resume(lua_State * L, int nargs); int lessthan(lua_State * L, int idx1, int idx2); int equal(lua_State * L, int idx1, int idx2); int get_nonnil_top(lua_State * L); diff --git a/src/util/luaref.cpp b/src/util/luaref.cpp index abd2dc82e..d88053fc6 100644 --- a/src/util/luaref.cpp +++ b/src/util/luaref.cpp @@ -34,6 +34,13 @@ luaref::~luaref() { luaL_unref(m_state, LUA_REGISTRYINDEX, m_ref); } +void luaref::release() { + if (m_state) { + luaL_unref(m_state, LUA_REGISTRYINDEX, m_ref); + m_state = nullptr; + } +} + luaref & luaref::operator=(luaref const & r) { if (m_ref == r.m_ref) return *this; diff --git a/src/util/luaref.h b/src/util/luaref.h index d4abd696e..6b9eaab84 100644 --- a/src/util/luaref.h +++ b/src/util/luaref.h @@ -23,6 +23,7 @@ public: luaref(luaref const & r); luaref(luaref && r); ~luaref(); + void release(); luaref & operator=(luaref const & r); void push() const; lua_State * get_state() const { return m_state; } diff --git a/tests/lua/threads/tactic2.lua b/tests/lua/threads/tactic2.lua new file mode 100644 index 000000000..93b023018 --- /dev/null +++ b/tests/lua/threads/tactic2.lua @@ -0,0 +1,32 @@ +local env = environment() +local ios = io_state() +local Bool = Const("Bool") +env:add_var("p", Bool) +env:add_var("q", Bool) +local p, q = Consts("p, q") +local ctx = context() + +S = State() +-- tactics t1 and t2 uses yield to implement cooperative +-- multitasking +local counter1 = 0 +local t1 = tactic(function(env, ios, s) + while true do + counter1 = counter1 + 1 + coroutine.yield() + end +end) + +local counter2 = 0 +local t2 = tactic(function(env, ios, s) + while true do + counter2 = counter2 + 1 + coroutine.yield() + end + end) + +local T = (t1:par(t2)):try_for(150) +T:solve(env, ios, ctx, p) +print(counter1, counter2) +assert(counter1 > 2) +assert(counter2 > 2)