diff --git a/src/bindings/lua/name.cpp b/src/bindings/lua/name.cpp index a5d63220b..c7cf26f64 100644 --- a/src/bindings/lua/name.cpp +++ b/src/bindings/lua/name.cpp @@ -17,7 +17,11 @@ static int name_eq(lua_State * L); static int name_lt(lua_State * L); static const struct luaL_Reg name_m[] = { - {"__gc", name_gc}, {"__tostring", name_tostring}, {"__eq", name_eq}, {"__lt", name_lt}, {0, 0} + {"__gc", name_gc}, // never throws + {"__tostring", safe_function}, + {"__eq", safe_function}, + {"__lt", safe_function}, + {0, 0} }; static int mk_name(lua_State * L) { @@ -84,7 +88,7 @@ void init_name(lua_State * L) { luaL_newmetatable(L, "name.mt"); setfuncs(L, name_m, 0); - lua_pushcfunction(L, mk_name); + lua_pushcfunction(L, safe_function); lua_setglobal(L, "name"); } } diff --git a/src/bindings/lua/numerics.cpp b/src/bindings/lua/numerics.cpp index c19c50d1a..3f06fd1ae 100644 --- a/src/bindings/lua/numerics.cpp +++ b/src/bindings/lua/numerics.cpp @@ -102,17 +102,23 @@ public: luaL_newmetatable(L, M); setfuncs(L, m, 0); - lua_pushcfunction(L, mk); + lua_pushcfunction(L, safe_function); lua_setglobal(L, N); } }; template const struct luaL_Reg num_bindings::m[] = { - {"__gc", num_bindings::gc}, {"__tostring", num_bindings::tostring}, {"__eq", num_bindings::eq}, - {"__lt", num_bindings::lt}, {"__add", num_bindings::add}, {"__add", num_bindings::sub}, - {"__mul", num_bindings::mul}, {"__div", num_bindings::div}, {"__pow", num_bindings::power}, - {"__unm", num_bindings::umn}, + {"__gc", num_bindings::gc}, // never throws + {"__tostring", safe_function::tostring>}, + {"__eq", safe_function::eq>}, + {"__lt", safe_function::lt>}, + {"__add", safe_function::add>}, + {"__add", safe_function::sub>}, + {"__mul", safe_function::mul>}, + {"__div", safe_function::div>}, + {"__pow", safe_function::power>}, + {"__unm", safe_function::umn>}, {0, 0} }; @@ -121,6 +127,7 @@ constexpr char const mpz_metatable[] = "mpz.mt"; void init_mpz(lua_State * L) { num_bindings::init(L); } + constexpr char const mpq_name[] = "mpq"; constexpr char const mpq_metatable[] = "mpq.mt"; void init_mpq(lua_State * L) { diff --git a/src/bindings/lua/util.cpp b/src/bindings/lua/util.cpp index 72665a950..c7bb74ace 100644 --- a/src/bindings/lua/util.cpp +++ b/src/bindings/lua/util.cpp @@ -6,6 +6,9 @@ Author: Leonardo de Moura */ #ifdef LEAN_USE_LUA #include +#include +#include "util/exception.h" + namespace lean { /** \brief luaL_setfuncs replacement. The function luaL_setfuncs is only available in Lua 5.2. @@ -21,5 +24,24 @@ void setfuncs(lua_State * L, luaL_Reg const * l, int nup) { } lua_pop(L, nup); // remove upvalues } + +int safe_function_wrapper(lua_State * L, lua_CFunction f){ + static thread_local std::string _error_msg; + char const * error_msg; + try { + return f(L); + } catch (exception & e) { + _error_msg = e.what(); + error_msg = _error_msg.c_str(); + } catch (std::bad_alloc &) { + error_msg = "out of memory"; + } catch (std::exception & e) { + _error_msg = e.what(); + error_msg = _error_msg.c_str(); + } catch(...) { + error_msg = "unknown error"; + } + return luaL_error(L, error_msg); +} } #endif diff --git a/src/bindings/lua/util.h b/src/bindings/lua/util.h index b48233736..d3f97b97d 100644 --- a/src/bindings/lua/util.h +++ b/src/bindings/lua/util.h @@ -8,5 +8,13 @@ Author: Leonardo de Moura #include namespace lean { void setfuncs(lua_State * L, luaL_Reg const * l, int nup); +/** + \brief Wrapper for invoking function f, and catching Lean exceptions. +*/ +int safe_function_wrapper(lua_State * L, lua_CFunction f); +template +int safe_function(lua_State * L) { + return safe_function_wrapper(L, F); +} } #endif diff --git a/src/shell/lua/leanlua.cpp b/src/shell/lua/leanlua.cpp index fce92e80b..21dde2884 100644 --- a/src/shell/lua/leanlua.cpp +++ b/src/shell/lua/leanlua.cpp @@ -9,14 +9,13 @@ Author: Leonardo de Moura #ifdef LEAN_USE_LUA #include -#include "util/exception.h" #include "bindings/lua/name.h" #include "bindings/lua/numerics.h" int main(int argc, char ** argv) { int status, result; lua_State *L; - + int exitcode = 0; L = luaL_newstate(); luaL_openlibs(L); lean::init_name(L); @@ -27,21 +26,17 @@ int main(int argc, char ** argv) { status = luaL_loadfile(L, argv[i]); if (status) { std::cerr << "Couldn't load file: " << lua_tostring(L, -1) << "\n"; - return 1; - } - try { + exitcode = 1; + } else { result = lua_pcall(L, 0, LUA_MULTRET, 0); if (result) { std::cerr << "Failed to run script: " << lua_tostring(L, -1) << "\n"; - return 1; + exitcode = 1; } - } catch (lean::exception & ex) { - std::cerr << "Lean exception when running: " << argv[i] << "\n"; - std::cerr << ex.what() << "\n"; } } lua_close(L); - return 0; + return exitcode; } #else int main() {