From ad039c8cc2a5bb398632b40e39e78654308bedc4 Mon Sep 17 00:00:00 2001 From: ThePhD Date: Thu, 22 Oct 2015 11:20:32 -0400 Subject: [PATCH] Allow for usage of function error handlers with lua TODO: make this the default mode, with a short-cutting mode (sol::no_fail_function?) --- sol/function.hpp | 102 +++++++++++++++++++++++++++++++++------------- sol/reference.hpp | 20 ++++++--- sol/stack.hpp | 22 ++++++++++ sol/types.hpp | 8 ++++ 4 files changed, 119 insertions(+), 33 deletions(-) diff --git a/sol/function.hpp b/sol/function.hpp index e55e0fe8..236a8f40 100644 --- a/sol/function.hpp +++ b/sol/function.hpp @@ -38,7 +38,7 @@ private: lua_State* L; int index; int returncount; - int error; + call_error error; template stack::get_return get(types, indices) const { @@ -53,16 +53,36 @@ private: public: function_result() = default; - function_result(lua_State* L, int index = -1, int returncount = 0, int code = LUA_OK): L(L), index(index), returncount(returncount), error(code) { + function_result(lua_State* L, int index = -1, int returncount = 0, call_error error = call_error::ok): L(L), index(index), returncount(returncount), error(error) { } function_result(const function_result&) = default; function_result& operator=(const function_result&) = default; - function_result(function_result&&) = default; - function_result& operator=(function_result&&) = default; + function_result(function_result&& o) : L(o.L), index(o.index), returncount(o.returncount), error(o.error) { + // Must be manual, otherwise destructor will screw us + // return count being 0 is enough to keep things clean + // but will be thorough + o.L = nullptr; + o.index = 0; + o.returncount = 0; + o.error = call_error::runtime; + } + function_result& operator=(function_result&& o) { + L = o.L; + index = o.index; + returncount = o.returncount; + error = o.error; + // Must be manual, otherwise destructor will screw us + // return count being 0 is enough to keep things clean + // but will be thorough + o.L = nullptr; + o.index = 0; + o.returncount = 0; + o.error = call_error::runtime; + } bool valid() const { - return error == LUA_OK; + return error == call_error::ok; } template @@ -81,23 +101,42 @@ public: } ~function_result() { - lua_pop(L, returncount); + stack::remove(L, index, error == call_error::ok ? returncount : 1); } }; class function : public reference { +public: + static reference default_handler; + private: - int luacodecall(std::ptrdiff_t argcount, std::ptrdiff_t resultcount) const { - return lua_pcallk(state(), static_cast(argcount), static_cast(resultcount), 0, 0, nullptr); + struct handler { + const reference& target; + int stack; + handler(const reference& target) : target(target), stack(0) { + if (target.valid()) { + stack = lua_gettop(target.state()) + 1; + target.push(); + } + } + ~handler() { + if (target.valid()) { + lua_remove(target.state(), stack); + } + } + }; + + int luacodecall(std::ptrdiff_t argcount, std::ptrdiff_t resultcount, const handler& h) const { + return lua_pcallk(state(), static_cast(argcount), static_cast(resultcount), h.stack, 0, nullptr); } - void luacall(std::ptrdiff_t argcount, std::ptrdiff_t resultcount) const { + void luacall(std::ptrdiff_t argcount, std::ptrdiff_t resultcount, const handler& h) const { lua_callk(state(), static_cast(argcount), static_cast(resultcount), 0, nullptr); } template - std::tuple invoke(indices, types, std::ptrdiff_t n) const { - luacall(n, sizeof...(Ret)); + std::tuple invoke(indices, types, std::ptrdiff_t n, const handler& h) const { + luacall(n, sizeof...(Ret), h); const int nreturns = static_cast(sizeof...(Ret)); const int stacksize = lua_gettop(state()); const int firstreturn = std::max(0, stacksize - nreturns) + 1; @@ -107,56 +146,60 @@ private: } template - Ret invoke(indices, types, std::ptrdiff_t n) const { - luacall(n, 1); + Ret invoke(indices, types, std::ptrdiff_t n, const handler& h) const { + luacall(n, 1, h); return stack::pop(state()); } template - void invoke(indices, types, std::ptrdiff_t n) const { - luacall(n, 0); + void invoke(indices, types, std::ptrdiff_t n, const handler& h) const { + luacall(n, 0, h); } - function_result invoke(indices<>, types<>, std::ptrdiff_t n) const { + function_result invoke(indices<>, types<>, std::ptrdiff_t n, const handler& h) const { const int stacksize = lua_gettop(state()); const int firstreturn = std::max(0, stacksize - static_cast(n) - 1); int code = LUA_OK; try { - code = luacodecall( n, LUA_MULTRET ); + code = luacodecall(n, LUA_MULTRET, h); } // Handle C++ errors thrown from C++ functions bound inside of lua - catch ( const std::exception& error ) { + catch (const std::exception& error) { code = LUA_ERRRUN; - stack::push( state(), error.what() ); + stack::push(state(), error.what()); } // TODO: handle idiots? - /*catch ( const char* error ) { + /*catch (const char* error) { code = LUA_ERRRUN; - stack::push( state(), error ); + stack::push(state(), error); } - catch ( const std::string& error ) { + catch (const std::string& error) { code = LUA_ERRRUN; - stack::push( state(), error ); + stack::push(state(), error); } - catch ( ... ) { + catch (...) { code = LUA_ERRRUN; stack::push( state(), "[sol] an unknownable runtime exception occurred" ); }*/ - catch ( ... ) { + catch (...) { throw; } const int poststacksize = lua_gettop(state()); const int returncount = poststacksize - firstreturn; - return function_result(state(), firstreturn + 1, returncount, code); + return function_result(state(), firstreturn + ( error_handler.valid() ? 0 : 1 ), returncount, static_cast(code)); } public: + sol::reference error_handler; + function() = default; function(lua_State* L, int index = -1): reference(L, index) { type_assert(L, index, type::function); } function(const function&) = default; function& operator=(const function&) = default; + function( function&& ) = default; + function& operator=( function&& ) = default; template function_result operator()(Args&&... args) const { @@ -171,14 +214,17 @@ public: template auto call(Args&&... args) const - -> decltype(invoke(types(), types(), 0)) { + -> decltype(invoke(types(), types(), 0, std::declval())) { + handler h(error_handler); push(); int pushcount = stack::push_args(state(), std::forward(args)...); auto tr = types(); - return invoke(tr, tr, pushcount); + return invoke(tr, tr, pushcount, h); } }; +sol::reference function::default_handler; + namespace stack { template struct pusher> { diff --git a/sol/reference.hpp b/sol/reference.hpp index 066b4b21..4d44636a 100644 --- a/sol/reference.hpp +++ b/sol/reference.hpp @@ -31,6 +31,8 @@ private: int ref = LUA_NOREF; int copy() const { + if (ref == LUA_NOREF) + return LUA_NOREF; push(); return luaL_ref(L, LUA_REGISTRYINDEX); } @@ -46,11 +48,6 @@ public: luaL_unref(L, LUA_REGISTRYINDEX, ref); } - int push() const noexcept { - lua_rawgeti(L, LUA_REGISTRYINDEX, ref); - return 1; - } - reference(reference&& o) noexcept { L = o.L; ref = o.ref; @@ -80,6 +77,19 @@ public: return *this; } + int push() const noexcept { + lua_rawgeti(L, LUA_REGISTRYINDEX, ref); + return 1; + } + + int get_index() const { + return ref; + } + + bool valid () const { + return !(ref == LUA_NOREF); + } + type get_type() const { push(); int result = lua_type(L, -1); diff --git a/sol/stack.hpp b/sol/stack.hpp index 138bc851..073d24bb 100644 --- a/sol/stack.hpp +++ b/sol/stack.hpp @@ -575,6 +575,28 @@ inline void call(lua_State* L, int start, indices, types, types(count) ); + return; + } + + // Remove each item one at a time using stack operations + // Probably slower, maybe, haven't benchmarked, + // but necessary + if ( index < 0 ) { + index = lua_gettop( L ) + (index + 1); + } + int last = index + count; + for ( int i = index; i < last; ++i ) { + lua_remove( L, i ); + } +} + template ::value>::type> inline R call(lua_State* L, int start, types tr, types ta, Fx&& fx, FxArgs&&... args) { return detail::call(L, start, ta, tr, ta, std::forward(fx), std::forward(args)...); diff --git a/sol/types.hpp b/sol/types.hpp index 9bbb350a..574e134e 100644 --- a/sol/types.hpp +++ b/sol/types.hpp @@ -62,6 +62,14 @@ enum class call_syntax { colon = 1 }; +enum class call_error : int { + ok = LUA_OK, + runtime = LUA_ERRRUN, + memory = LUA_ERRMEM, + handler = LUA_ERRERR, + gc = LUA_ERRGCMM +}; + enum class type : int { none = LUA_TNONE, nil = LUA_TNIL,