diff --git a/src/util/script_state.cpp b/src/util/script_state.cpp index c345ab867..31fa91094 100644 --- a/src/util/script_state.cpp +++ b/src/util/script_state.cpp @@ -40,8 +40,8 @@ void open_extra(lua_State * L); static char g_weak_ptr_key; // key for Lua registry (used at get_weak_ptr and save_weak_ptr) struct script_state::imp { - lua_State * m_state; - recursive_mutex m_mutex; + lua_State * m_state; + mutex m_mutex; std::unordered_set m_imported_modules; static std::weak_ptr * get_weak_ptr(lua_State * L) { @@ -100,12 +100,12 @@ struct script_state::imp { } void dofile(char const * fname) { - lock_guard lock(m_mutex); + lock_guard lock(m_mutex); ::lean::dofile(m_state, fname); } void dostring(char const * str) { - lock_guard lock(m_mutex); + lock_guard lock(m_mutex); ::lean::dostring(m_state, str); } @@ -162,7 +162,7 @@ bool script_state::import_explicit(char const * str) { return m_ptr->import_explicit(str); } -recursive_mutex & script_state::get_mutex() { +mutex & script_state::get_mutex() { return m_ptr->m_mutex; } @@ -454,8 +454,7 @@ public: m_thread.join(); } - int wait(lua_State * src) { - m_thread.join(); + int copy_result(lua_State * src) { if (m_exception) m_exception->rethrow(); return m_state.apply([&](lua_State * S) { @@ -468,6 +467,10 @@ public: }); } + void wait() { + m_thread.join(); + } + void request_interrupt() { m_thread.request_interrupt(); } @@ -552,7 +555,10 @@ static int thread_interrupt(lua_State * L) { } int thread_wait(lua_State * L) { - return to_thread(L, 1).wait(L); + auto & t = to_thread(L, 1); + script_state st = to_script_state(L); + st.exec_unprotected([&]() { t.wait(); }); + return t.copy_result(L); } static const struct luaL_Reg thread_m[] = { diff --git a/src/util/script_state.h b/src/util/script_state.h index ff1a1d2ba..d4a931288 100644 --- a/src/util/script_state.h +++ b/src/util/script_state.h @@ -20,7 +20,7 @@ public: private: std::shared_ptr m_ptr; friend script_state to_script_state(lua_State * L); - recursive_mutex & get_mutex(); + mutex & get_mutex(); lua_State * get_state(); friend class data_channel; public: @@ -60,7 +60,7 @@ public: */ template typename std::result_of::type apply(F && f) { - lock_guard lock(get_mutex()); + lock_guard lock(get_mutex()); return f(get_state()); } @@ -74,13 +74,13 @@ public: */ template void exec_unprotected(F && f) { - unlock_guard unlock(get_mutex()); + unlock_guard unlock(get_mutex()); f(); } template void exec_protected(F && f) { - lock_guard lock(get_mutex()); + lock_guard lock(get_mutex()); f(); } }; diff --git a/src/util/unlock_guard.h b/src/util/unlock_guard.h index 1ed621974..65cdccd4c 100644 --- a/src/util/unlock_guard.h +++ b/src/util/unlock_guard.h @@ -29,11 +29,10 @@ namespace lean { \warning The calling thread must own the lock to m_mutex */ -template class unlock_guard { - Mutex & m_mutex; + mutex & m_mutex; public: - explicit unlock_guard(Mutex & m):m_mutex(m) { m_mutex.unlock(); } + explicit unlock_guard(mutex & m):m_mutex(m) { m_mutex.unlock(); } unlock_guard(unlock_guard const &) = delete; unlock_guard(unlock_guard &&) = delete; unlock_guard & operator=(unlock_guard const &) = delete;