diff --git a/sol/function.hpp b/sol/function.hpp index 021bca2d..0a9d1f10 100644 --- a/sol/function.hpp +++ b/sol/function.hpp @@ -120,22 +120,22 @@ private: } } ~handler() { - if (target.valid()) { + if (stack > 0) { lua_remove(target.state(), stack); } } }; - int luacodecall(std::ptrdiff_t argcount, std::ptrdiff_t resultcount, const handler& h) const { + int luacodecall(std::ptrdiff_t argcount, std::ptrdiff_t resultcount, handler& h) const { return lua_pcallk(state(), static_cast(argcount), static_cast(resultcount), h.stack, 0, nullptr); } - void luacall(std::ptrdiff_t argcount, std::ptrdiff_t resultcount, const handler& h) const { + void luacall(std::ptrdiff_t argcount, std::ptrdiff_t resultcount, handler& h) const { lua_callk(state(), static_cast(argcount), static_cast(resultcount), 0, nullptr); } template - std::tuple invoke(indices, types, std::ptrdiff_t n, const handler& h) const { + std::tuple invoke(indices, types, std::ptrdiff_t n, handler& h) const { luacall(n, sizeof...(Ret), h); const int nreturns = static_cast(sizeof...(Ret)); const int stacksize = lua_gettop(state()); @@ -146,17 +146,18 @@ private: } template - Ret invoke(indices, types, std::ptrdiff_t n, const handler& h) const { + Ret invoke(indices, types, std::ptrdiff_t n, handler& h) const { luacall(n, 1, h); return stack::pop(state()); } template - void invoke(indices, types, std::ptrdiff_t n, const handler& h) const { + void invoke(indices, types, std::ptrdiff_t n, handler& h) const { luacall(n, 0, h); } - function_result invoke(indices<>, types<>, std::ptrdiff_t n, const handler& h) const { + function_result invoke(indices<>, types<>, std::ptrdiff_t n, handler& h) const { + const bool handlerpushed = error_handler.valid(); const int stacksize = lua_gettop(state()); const int firstreturn = std::max(0, stacksize - static_cast(n) - 1); int code = LUA_OK; @@ -166,6 +167,7 @@ private: // Handle C++ errors thrown from C++ functions bound inside of lua catch (const std::exception& error) { code = LUA_ERRRUN; + h.stack = 0; stack::push(state(), error.what()); } // TODO: handle idiots? @@ -186,7 +188,7 @@ private: } const int poststacksize = lua_gettop(state()); const int returncount = poststacksize - firstreturn; - return function_result(state(), firstreturn + ( error_handler.valid() ? 0 : 1 ), returncount, static_cast(code)); + return function_result(state(), firstreturn + ( handlerpushed ? 0 : 1 ), returncount, static_cast(code)); } public: diff --git a/tests.cpp b/tests.cpp index 2b8f8648..312037f5 100644 --- a/tests.cpp +++ b/tests.cpp @@ -998,3 +998,52 @@ TEST_CASE("interop/null-to-nil-and-back", "nil should be the given type when a p "rofl(x)\n" "assert(x == nil)")); } + +TEST_CASE( "functions/sol::function_result", "Function result should be the beefy return type for sol::function that allows for error checking and error handlers" ) { + sol::state lua; + lua.open_libraries( sol::lib::base, sol::lib::debug ); + + static const char errormessage1[] = "true error message"; + static const char errormessage2[] = "doodle"; + + // Some function; just using a lambda to be cheap + auto doom = []() { + // Bypasses handler function: puts information directly into lua error + throw std::exception( errormessage1 ); + }; + auto luadoom = [&lua]() { + // Does not bypass error function, will call it + luaL_error( lua.lua_state(), "BIG ERROR MESSAGES!" ); + }; + auto specialhandler = []( std::string message ) { + return errormessage2; + }; + + lua.set_function( "doom", doom ); + lua.set_function( "luadoom", luadoom ); + lua.set_function( "cpphandler", specialhandler ); + lua.script( + std::string( "function handler ( message )" ) + + " return '" + errormessage2 + "'" + + "end" + ); + + sol::function func = lua[ "doom" ]; + sol::function luafunc = lua[ "luadoom" ]; + sol::function luahandler = lua[ "handler" ]; + sol::function cpphandler = lua[ "cpphandler" ]; + func.error_handler = luahandler; + luafunc.error_handler = cpphandler; + + sol::function_result result1 = func(); + int test = lua_gettop(lua.lua_state()); + REQUIRE(!result1.valid()); + std::string errorstring = result1; + REQUIRE(errorstring == errormessage1); + + sol::function_result result2 = luafunc(); + REQUIRE(!result2.valid()); + errorstring = result2; + REQUIRE(errorstring == errormessage2); + +}