From d42efd7fdf172eb4e06e2c4e059f43f11fd9a6e8 Mon Sep 17 00:00:00 2001 From: ThePhD Date: Sat, 27 Feb 2016 07:56:28 -0500 Subject: [PATCH] Proper trampolines to allow luajit to play nice with all the other kids in `sol`. --- bootstrap.py | 5 +- sol/function.hpp | 10 +++- sol/function_types_allocator.hpp | 12 ++--- sol/function_types_core.hpp | 13 +++--- sol/function_types_overload.hpp | 5 +- sol/function_types_static.hpp | 16 +++++-- sol/function_types_usertype.hpp | 6 +-- sol/stack.hpp | 20 +------- sol/state.hpp | 1 - sol/types.hpp | 40 ++++++++++++++++ tests.cpp | 79 ++++++++++++++++---------------- 11 files changed, 123 insertions(+), 84 deletions(-) diff --git a/bootstrap.py b/bootstrap.py index 14acc66e..acc4fa68 100755 --- a/bootstrap.py +++ b/bootstrap.py @@ -125,7 +125,10 @@ if 'linux' in sys.platform: builddir = 'bin' objdir = 'obj' -tests = os.path.join(builddir, 'tests') +if 'win32' in sys.platform: + tests = os.path.join(builddir, 'tests.exe') +else: + tests = os.path.join(builddir, 'tests') # ninja file ninja = ninja_syntax.Writer(open('build.ninja', 'w')) diff --git a/sol/function.hpp b/sol/function.hpp index b72b7e79..a1b8f925 100644 --- a/sol/function.hpp +++ b/sol/function.hpp @@ -170,6 +170,12 @@ private: returncount = poststacksize - firstreturn; } // Handle C++ errors thrown from C++ functions bound inside of lua + catch (const char* error) { + h.stackindex = 0; + stack::push(lua_state(), error); + firstreturn = lua_gettop(lua_state()); + return protected_function_result(lua_state(), firstreturn, 0, 1, call_status::runtime); + } catch (const std::exception& error) { h.stackindex = 0; stack::push(lua_state(), error.what()); @@ -296,7 +302,7 @@ struct pusher> { dFx memfxptr(std::forward(fx)); auto userptr = detail::ptr(obj); void* userobjdata = static_cast(userptr); - lua_CFunction freefunc = &function_detail::static_member_function, uFx>::call; + lua_CFunction freefunc = &function_detail::upvalue_member_function, uFx>::call; int upvalues = stack::stack_detail::push_as_upvalues(L, memfxptr); upvalues += stack::push(L, userobjdata); @@ -307,7 +313,7 @@ struct pusher> { template static void set_fx(std::false_type, lua_State* L, Fx&& fx) { std::decay_t target(std::forward(fx)); - lua_CFunction freefunc = &function_detail::static_function::call; + lua_CFunction freefunc = &function_detail::upvalue_free_function::call; int upvalues = stack::stack_detail::push_as_upvalues(L, target); stack::push(L, freefunc, upvalues); diff --git a/sol/function_types_allocator.hpp b/sol/function_types_allocator.hpp index 275c4250..39f9e6c6 100644 --- a/sol/function_types_allocator.hpp +++ b/sol/function_types_allocator.hpp @@ -75,9 +75,9 @@ inline int construct(lua_State* L) { luaL_getmetatable(L, &meta[0]); if (stack::get(L) == type::nil) { lua_pop(L, 1); - std::string err = "unable to get usertype metatable for "; - err += meta; - throw error(err); + std::string err = "sol: unable to get usertype metatable for "; + err += usertype_traits::name; + return luaL_error(L, err.c_str()); } lua_setmetatable(L, -2); @@ -126,9 +126,9 @@ struct usertype_constructor_function : base_function { luaL_getmetatable(L, &meta[0]); if (stack::get(L) == type::nil) { lua_pop(L, 1); - std::string err = "unable to get usertype metatable for "; - err += meta; - throw error(err); + std::string err = "sol: unable to get usertype metatable for "; + err += usertype_traits::name; + return luaL_error(L, err.c_str()); } lua_setmetatable(L, -2); diff --git a/sol/function_types_core.hpp b/sol/function_types_core.hpp index ea437553..d26da485 100644 --- a/sol/function_types_core.hpp +++ b/sol/function_types_core.hpp @@ -169,8 +169,8 @@ public: }; struct base_function { - virtual int operator()(lua_State*) { - throw error("failure to call specialized wrapped C++ function from Lua"); + virtual int operator()(lua_State* L) { + return luaL_error(L, "sol: failure to call specialized wrapped C++ function from Lua"); } virtual ~base_function() {} @@ -178,18 +178,17 @@ struct base_function { static int base_call(lua_State* L, void* inheritancedata) { if (inheritancedata == nullptr) { - throw error("call from Lua to C++ function has null data"); + return luaL_error(L, "sol: call from Lua to C++ function has null data"); } base_function* pfx = static_cast(inheritancedata); base_function& fx = *pfx; - int r = fx(L); - return r; + return detail::trampoline(L, fx); } -static int base_gc(lua_State*, void* udata) { +static int base_gc(lua_State* L, void* udata) { if (udata == nullptr) { - throw error("call from lua to C++ gc function with null data"); + return luaL_error(L, "sol: call from lua to C++ gc function with null data"); } base_function* ptr = static_cast(udata); diff --git a/sol/function_types_overload.hpp b/sol/function_types_overload.hpp index df0cc34a..7fcf564f 100644 --- a/sol/function_types_overload.hpp +++ b/sol/function_types_overload.hpp @@ -40,8 +40,8 @@ struct overload_traits> { }; template -inline int overload_match_arity(types<>, std::index_sequence<>, std::index_sequence, Match&&, lua_State*, int, int, Args&&...) { - throw error("no matching function call takes this number of arguments and the specified types"); +inline int overload_match_arity(types<>, std::index_sequence<>, std::index_sequence, Match&&, lua_State* L, int, int, Args&&...) { + return luaL_error(L, "sol: no matching function call takes this number of arguments and the specified types"); } template @@ -124,7 +124,6 @@ struct usertype_overloaded_function : base_function { auto mfx = [&](auto&&... args){ return this->call(std::forward(args)...); }; return overload_match>>...>(mfx, L, 2); } - }; } // function_detail } // sol diff --git a/sol/function_types_static.hpp b/sol/function_types_static.hpp index abb87003..1af19f07 100644 --- a/sol/function_types_static.hpp +++ b/sol/function_types_static.hpp @@ -27,28 +27,32 @@ namespace sol { namespace function_detail { template -struct static_function { +struct upvalue_free_function { typedef std::remove_pointer_t> function_type; typedef meta::function_traits traits_type; - static int call(lua_State* L) { + static int real_call(lua_State* L) { auto udata = stack::stack_detail::get_as_upvalues(L); function_type* fx = udata.first; int r = stack::call_into_lua(meta::tuple_types(), typename traits_type::args_type(), fx, L, 1); return r; } + static int call (lua_State* L) { + return detail::static_trampoline<&real_call>(L); + } + int operator()(lua_State* L) { return call(L); } }; template -struct static_member_function { +struct upvalue_member_function { typedef std::remove_pointer_t> function_type; typedef meta::function_traits traits_type; - static int call(lua_State* L) { + static int real_call(lua_State* L) { auto memberdata = stack::stack_detail::get_as_upvalues(L, 1); auto objdata = stack::stack_detail::get_as_upvalues(L, memberdata.second); function_type& memfx = memberdata.first; @@ -57,6 +61,10 @@ struct static_member_function { return stack::call_into_lua(meta::tuple_types(), typename traits_type::args_type(), fx, L, 1); } + static int call (lua_State* L) { + return detail::static_trampoline<&real_call>(L); + } + int operator()(lua_State* L) { return call(L); } diff --git a/sol/function_types_usertype.hpp b/sol/function_types_usertype.hpp index 0bf610bc..84d0f914 100644 --- a/sol/function_types_usertype.hpp +++ b/sol/function_types_usertype.hpp @@ -95,7 +95,7 @@ struct usertype_function : public usertype_function_core { int prelude(lua_State* L) { this->fx.item = detail::ptr(stack::get(L, 1)); if(this->fx.item == nullptr) { - throw error("userdata for function call is null: are you using the wrong syntax? (use item:function/variable(...) syntax)"); + return luaL_error(L, "sol: userdata for function call is null: are you using the wrong syntax? (use item:function/variable(...) syntax)"); } return static_cast(*this)(meta::tuple_types(), args_type(), Index<2>(), L); } @@ -120,7 +120,7 @@ struct usertype_variable_function : public usertype_function_core int argcount = lua_gettop(L); this->fx.item = stack::get(L, 1); if(this->fx.item == nullptr) { - throw error("userdata for member variable is null"); + return luaL_error(L, "sol: userdata for member variable is null"); } switch(argcount) { case 2: @@ -128,7 +128,7 @@ struct usertype_variable_function : public usertype_function_core case 3: return static_cast(*this)(meta::tuple_types(), args_type(), Index<3>(), L); default: - throw error("cannot get/set userdata member variable with inappropriate number of arguments"); + return luaL_error(L, "sol: cannot get/set userdata member variable with inappropriate number of arguments"); } } diff --git a/sol/stack.hpp b/sol/stack.hpp index 517cd0df..6ce25537 100644 --- a/sol/stack.hpp +++ b/sol/stack.hpp @@ -315,7 +315,7 @@ template<> struct getter { static nil_t get(lua_State* L, int index = -1) { if(lua_isnil(L, index) == 0) { - throw sol::error("not nil"); + throw error("not nil"); } return nil_t{ }; } @@ -796,22 +796,6 @@ inline void call(types, types ta, std::index_sequence tai, check_arguments{}.check(ta, tai, L, start); fx(std::forward(args)..., stack::get(L, start + I)...); } - -inline int luajit_exception_jump (lua_State* L, lua_CFunction func) { - try { - return func(L); - } - catch (const char *s) { // Catch and convert exceptions. - lua_pushstring(L, s); - } - catch (const std::exception& e) { - lua_pushstring(L, e.what()); - } - catch (...) { - lua_pushstring(L, "caught (...)"); - } - return lua_error(L); // Rethrow as a Lua error. -} } // stack_detail inline void remove( lua_State* L, int index, int count ) { @@ -895,7 +879,7 @@ inline call_syntax get_call_syntax(lua_State* L, const std::string& meta) { return call_syntax::dot; } -inline void luajit_exception_handler(lua_State* L, int(*handler)(lua_State*, lua_CFunction) = stack_detail::luajit_exception_jump) { +inline void luajit_exception_handler(lua_State* L, int(*handler)(lua_State*, lua_CFunction) = detail::c_trampoline) { #ifdef SOL_LUAJIT lua_pushlightuserdata(L, (void*)handler); luaJIT_setmode(L, -1, LUAJIT_MODE_WRAPCFUNC | LUAJIT_MODE_ON); diff --git a/sol/state.hpp b/sol/state.hpp index ca3ecd8d..a27e3121 100644 --- a/sol/state.hpp +++ b/sol/state.hpp @@ -32,7 +32,6 @@ public: state(lua_CFunction panic = detail::atpanic) : unique_base(luaL_newstate(), lua_close), state_view(unique_base::get()) { set_panic(panic); - sol::stack::luajit_exception_handler(unique_base::get()); } using state_view::get; diff --git a/sol/types.hpp b/sol/types.hpp index 8b802cb8..158ba31c 100644 --- a/sol/types.hpp +++ b/sol/types.hpp @@ -27,6 +27,46 @@ #include namespace sol { +namespace detail { + +template +inline int static_trampoline (lua_State* L) { + try { + return f(L); + } + catch (const char *s) { // Catch and convert exceptions. + lua_pushstring(L, s); + } + catch (const std::exception& e) { + lua_pushstring(L, e.what()); + } + catch (...) { + lua_pushstring(L, "caught (...) exception"); + } + return lua_error(L); +} + +template +inline int trampoline(lua_State* L, Fx&& f) { + try { + return f(L); + } + catch (const char *s) { // Catch and convert exceptions. + lua_pushstring(L, s); + } + catch (const std::exception& e) { + lua_pushstring(L, e.what()); + } + catch (...) { + lua_pushstring(L, "caught (...) exception"); + } + return lua_error(L); +} + +inline int c_trampoline(lua_State* L, lua_CFunction f) { + return trampoline(L, f); +} +} struct nil_t {}; const nil_t nil {}; inline bool operator==(nil_t, nil_t) { return true; } diff --git a/tests.cpp b/tests.cpp index 6616894b..9b3b6bf9 100644 --- a/tests.cpp +++ b/tests.cpp @@ -1165,48 +1165,53 @@ TEST_CASE("interop/null-to-nil-and-back", "nil should be the given type when a p TEST_CASE( "functions/function_result-protected_function_result", "Function result should be the beefy return type for sol::function that allows for error checking and error handlers" ) { sol::state lua; lua.open_libraries( sol::lib::base, sol::lib::debug ); - static const char errormessage1[] = "true error message"; - static const char errormessage2[] = "doodle"; + static const char unhandlederrormessage[] = "true error message"; + static const char handlederrormessage[] = "doodle"; // Some function; just using a lambda to be cheap - auto doom = []() { - // Bypasses handler function: puts information directly into lua error - throw std::runtime_error( errormessage1 ); + auto doomfx = []() { + std::cout << "doomfx called" << std::endl; + throw std::runtime_error( unhandlederrormessage ); }; - auto luadoom = [&lua]() { + auto luadoomfx = [&lua]() { + std::cout << "luadoomfx called" << std::endl; // Does not bypass error function, will call it - luaL_error( lua.lua_state(), "BIG ERROR MESSAGES!" ); - }; - auto specialhandler = []( std::string ) { - return errormessage2; + luaL_error( lua.lua_state(), unhandlederrormessage ); }; + lua.set_function("doom", doomfx); + lua.set_function("luadoom", luadoomfx); - lua.set_function( "doom", doom ); - lua.set_function( "luadoom", luadoom ); - lua.set_function( "cpphandler", specialhandler ); + auto cpphandlerfx = []( std::string x ) { + std::cout << "c++ handler called with: " << x << std::endl; + return handlederrormessage; + }; + lua.set_function( "cpphandler", cpphandlerfx ); lua.script( - std::string( "function handler ( message )" ) - + " return '" + errormessage2 + "'" + std::string( "function luahandler ( message )" ) + + " print('lua handler called with: ' .. message)" + + " return '" + handlederrormessage + "'" + "end" ); - sol::protected_function func = lua[ "doom" ]; - sol::protected_function luafunc = lua[ "luadoom" ]; - sol::function luahandler = lua[ "handler" ]; + sol::protected_function doom = lua[ "doom" ]; + sol::protected_function luadoom = lua[ "luadoom" ]; + sol::function luahandler = lua[ "luahandler" ]; sol::function cpphandler = lua[ "cpphandler" ]; - func.error_handler = luahandler; - luafunc.error_handler = cpphandler; - - sol::protected_function_result result1 = func(); - int test = lua_gettop(lua.lua_state()); - REQUIRE(!result1.valid()); - std::string errorstring = result1; - REQUIRE(errorstring == errormessage1); + doom.error_handler = luahandler; + luadoom.error_handler = cpphandler; - sol::protected_function_result result2 = luafunc(); - REQUIRE(!result2.valid()); - errorstring = result2; - REQUIRE(errorstring == errormessage2); + { + sol::protected_function_result result = doom(); + REQUIRE(!result.valid()); + std::string errorstring = result; + REQUIRE(errorstring == handlederrormessage); + } + { + sol::protected_function_result result = luadoom(); + REQUIRE(!result.valid()); + std::string errorstring = result; + REQUIRE(errorstring == handlederrormessage); + } } TEST_CASE("functions/destructor-tests", "Show that proper copies / destruction happens") { @@ -1413,11 +1418,9 @@ TEST_CASE("threading/coroutines", "ensure calling a coroutine works") { function loop() while counter ~= 30 do - print("Sending " .. counter); coroutine.yield(counter); counter = counter + 1; end - print("Sending " .. counter); return counter end )"; @@ -1429,8 +1432,8 @@ end int counter; for (counter = 20; counter < 31 && cr; ++counter) { - int x = cr(); - if (counter != x) { + int value = cr(); + if (counter != value) { throw std::logic_error("fuck"); } } @@ -1438,17 +1441,15 @@ end REQUIRE(counter == 30); } -TEST_CASE("threading/new-thread-coroutines", "ensure calling a coroutine works") { +TEST_CASE("threading/new-thread-coroutines", "ensure calling a coroutine works when the work is put on a different thread") { const auto& script = R"(counter = 20 function loop() while counter ~= 30 do - print("Sending " .. counter); coroutine.yield(counter); counter = counter + 1; end - print("Sending " .. counter); return counter end )"; @@ -1462,8 +1463,8 @@ end int counter; for (counter = 20; counter < 31 && cr; ++counter) { - int x = cr(); - if (counter != x) { + int value = cr(); + if (counter != value) { throw std::logic_error("fuck"); } }