Proper trampolines to allow luajit to play nice with all the other kids in sol.

This commit is contained in:
ThePhD 2016-02-27 07:56:28 -05:00
parent e57ac87868
commit d42efd7fdf
11 changed files with 123 additions and 84 deletions

View File

@ -125,6 +125,9 @@ if 'linux' in sys.platform:
builddir = 'bin' builddir = 'bin'
objdir = 'obj' objdir = 'obj'
if 'win32' in sys.platform:
tests = os.path.join(builddir, 'tests.exe')
else:
tests = os.path.join(builddir, 'tests') tests = os.path.join(builddir, 'tests')
# ninja file # ninja file

View File

@ -170,6 +170,12 @@ private:
returncount = poststacksize - firstreturn; returncount = poststacksize - firstreturn;
} }
// Handle C++ errors thrown from C++ functions bound inside of lua // Handle C++ errors thrown from C++ functions bound inside of lua
catch (const char* error) {
h.stackindex = 0;
stack::push(lua_state(), error);
firstreturn = lua_gettop(lua_state());
return protected_function_result(lua_state(), firstreturn, 0, 1, call_status::runtime);
}
catch (const std::exception& error) { catch (const std::exception& error) {
h.stackindex = 0; h.stackindex = 0;
stack::push(lua_state(), error.what()); stack::push(lua_state(), error.what());
@ -296,7 +302,7 @@ struct pusher<function_sig<Sigs...>> {
dFx memfxptr(std::forward<Fx>(fx)); dFx memfxptr(std::forward<Fx>(fx));
auto userptr = detail::ptr(obj); auto userptr = detail::ptr(obj);
void* userobjdata = static_cast<void*>(userptr); void* userobjdata = static_cast<void*>(userptr);
lua_CFunction freefunc = &function_detail::static_member_function<std::decay_t<decltype(*userptr)>, uFx>::call; lua_CFunction freefunc = &function_detail::upvalue_member_function<std::decay_t<decltype(*userptr)>, uFx>::call;
int upvalues = stack::stack_detail::push_as_upvalues(L, memfxptr); int upvalues = stack::stack_detail::push_as_upvalues(L, memfxptr);
upvalues += stack::push(L, userobjdata); upvalues += stack::push(L, userobjdata);
@ -307,7 +313,7 @@ struct pusher<function_sig<Sigs...>> {
template<typename Fx> template<typename Fx>
static void set_fx(std::false_type, lua_State* L, Fx&& fx) { static void set_fx(std::false_type, lua_State* L, Fx&& fx) {
std::decay_t<Fx> target(std::forward<Fx>(fx)); std::decay_t<Fx> target(std::forward<Fx>(fx));
lua_CFunction freefunc = &function_detail::static_function<Fx>::call; lua_CFunction freefunc = &function_detail::upvalue_free_function<Fx>::call;
int upvalues = stack::stack_detail::push_as_upvalues(L, target); int upvalues = stack::stack_detail::push_as_upvalues(L, target);
stack::push(L, freefunc, upvalues); stack::push(L, freefunc, upvalues);

View File

@ -75,9 +75,9 @@ inline int construct(lua_State* L) {
luaL_getmetatable(L, &meta[0]); luaL_getmetatable(L, &meta[0]);
if (stack::get<type>(L) == type::nil) { if (stack::get<type>(L) == type::nil) {
lua_pop(L, 1); lua_pop(L, 1);
std::string err = "unable to get usertype metatable for "; std::string err = "sol: unable to get usertype metatable for ";
err += meta; err += usertype_traits<T>::name;
throw error(err); return luaL_error(L, err.c_str());
} }
lua_setmetatable(L, -2); lua_setmetatable(L, -2);
@ -126,9 +126,9 @@ struct usertype_constructor_function : base_function {
luaL_getmetatable(L, &meta[0]); luaL_getmetatable(L, &meta[0]);
if (stack::get<type>(L) == type::nil) { if (stack::get<type>(L) == type::nil) {
lua_pop(L, 1); lua_pop(L, 1);
std::string err = "unable to get usertype metatable for "; std::string err = "sol: unable to get usertype metatable for ";
err += meta; err += usertype_traits<T>::name;
throw error(err); return luaL_error(L, err.c_str());
} }
lua_setmetatable(L, -2); lua_setmetatable(L, -2);

View File

@ -169,8 +169,8 @@ public:
}; };
struct base_function { struct base_function {
virtual int operator()(lua_State*) { virtual int operator()(lua_State* L) {
throw error("failure to call specialized wrapped C++ function from Lua"); return luaL_error(L, "sol: failure to call specialized wrapped C++ function from Lua");
} }
virtual ~base_function() {} virtual ~base_function() {}
@ -178,18 +178,17 @@ struct base_function {
static int base_call(lua_State* L, void* inheritancedata) { static int base_call(lua_State* L, void* inheritancedata) {
if (inheritancedata == nullptr) { if (inheritancedata == nullptr) {
throw error("call from Lua to C++ function has null data"); return luaL_error(L, "sol: call from Lua to C++ function has null data");
} }
base_function* pfx = static_cast<base_function*>(inheritancedata); base_function* pfx = static_cast<base_function*>(inheritancedata);
base_function& fx = *pfx; base_function& fx = *pfx;
int r = fx(L); return detail::trampoline(L, fx);
return r;
} }
static int base_gc(lua_State*, void* udata) { static int base_gc(lua_State* L, void* udata) {
if (udata == nullptr) { if (udata == nullptr) {
throw error("call from lua to C++ gc function with null data"); return luaL_error(L, "sol: call from lua to C++ gc function with null data");
} }
base_function* ptr = static_cast<base_function*>(udata); base_function* ptr = static_cast<base_function*>(udata);

View File

@ -40,8 +40,8 @@ struct overload_traits<functor<T, Func, X>> {
}; };
template <std::size_t... M, typename Match, typename... Args> template <std::size_t... M, typename Match, typename... Args>
inline int overload_match_arity(types<>, std::index_sequence<>, std::index_sequence<M...>, Match&&, lua_State*, int, int, Args&&...) { inline int overload_match_arity(types<>, std::index_sequence<>, std::index_sequence<M...>, Match&&, lua_State* L, int, int, Args&&...) {
throw error("no matching function call takes this number of arguments and the specified types"); return luaL_error(L, "sol: no matching function call takes this number of arguments and the specified types");
} }
template <typename Fx, typename... Fxs, std::size_t I, std::size_t... In, std::size_t... M, typename Match, typename... Args> template <typename Fx, typename... Fxs, std::size_t I, std::size_t... In, std::size_t... M, typename Match, typename... Args>
@ -124,7 +124,6 @@ struct usertype_overloaded_function : base_function {
auto mfx = [&](auto&&... args){ return this->call(std::forward<decltype(args)>(args)...); }; auto mfx = [&](auto&&... args){ return this->call(std::forward<decltype(args)>(args)...); };
return overload_match<functor<T, std::remove_pointer_t<std::decay_t<Functions>>>...>(mfx, L, 2); return overload_match<functor<T, std::remove_pointer_t<std::decay_t<Functions>>>...>(mfx, L, 2);
} }
}; };
} // function_detail } // function_detail
} // sol } // sol

View File

@ -27,28 +27,32 @@
namespace sol { namespace sol {
namespace function_detail { namespace function_detail {
template<typename Function> template<typename Function>
struct static_function { struct upvalue_free_function {
typedef std::remove_pointer_t<std::decay_t<Function>> function_type; typedef std::remove_pointer_t<std::decay_t<Function>> function_type;
typedef meta::function_traits<function_type> traits_type; typedef meta::function_traits<function_type> traits_type;
static int call(lua_State* L) { static int real_call(lua_State* L) {
auto udata = stack::stack_detail::get_as_upvalues<function_type*>(L); auto udata = stack::stack_detail::get_as_upvalues<function_type*>(L);
function_type* fx = udata.first; function_type* fx = udata.first;
int r = stack::call_into_lua(meta::tuple_types<typename traits_type::return_type>(), typename traits_type::args_type(), fx, L, 1); int r = stack::call_into_lua(meta::tuple_types<typename traits_type::return_type>(), typename traits_type::args_type(), fx, L, 1);
return r; return r;
} }
static int call (lua_State* L) {
return detail::static_trampoline<&real_call>(L);
}
int operator()(lua_State* L) { int operator()(lua_State* L) {
return call(L); return call(L);
} }
}; };
template<typename T, typename Function> template<typename T, typename Function>
struct static_member_function { struct upvalue_member_function {
typedef std::remove_pointer_t<std::decay_t<Function>> function_type; typedef std::remove_pointer_t<std::decay_t<Function>> function_type;
typedef meta::function_traits<function_type> traits_type; typedef meta::function_traits<function_type> traits_type;
static int call(lua_State* L) { static int real_call(lua_State* L) {
auto memberdata = stack::stack_detail::get_as_upvalues<function_type>(L, 1); auto memberdata = stack::stack_detail::get_as_upvalues<function_type>(L, 1);
auto objdata = stack::stack_detail::get_as_upvalues<T*>(L, memberdata.second); auto objdata = stack::stack_detail::get_as_upvalues<T*>(L, memberdata.second);
function_type& memfx = memberdata.first; function_type& memfx = memberdata.first;
@ -57,6 +61,10 @@ struct static_member_function {
return stack::call_into_lua(meta::tuple_types<typename traits_type::return_type>(), typename traits_type::args_type(), fx, L, 1); return stack::call_into_lua(meta::tuple_types<typename traits_type::return_type>(), typename traits_type::args_type(), fx, L, 1);
} }
static int call (lua_State* L) {
return detail::static_trampoline<&real_call>(L);
}
int operator()(lua_State* L) { int operator()(lua_State* L) {
return call(L); return call(L);
} }

View File

@ -95,7 +95,7 @@ struct usertype_function : public usertype_function_core<Function, Tp> {
int prelude(lua_State* L) { int prelude(lua_State* L) {
this->fx.item = detail::ptr(stack::get<T>(L, 1)); this->fx.item = detail::ptr(stack::get<T>(L, 1));
if(this->fx.item == nullptr) { if(this->fx.item == nullptr) {
throw error("userdata for function call is null: are you using the wrong syntax? (use item:function/variable(...) syntax)"); return luaL_error(L, "sol: userdata for function call is null: are you using the wrong syntax? (use item:function/variable(...) syntax)");
} }
return static_cast<base_t&>(*this)(meta::tuple_types<return_type>(), args_type(), Index<2>(), L); return static_cast<base_t&>(*this)(meta::tuple_types<return_type>(), args_type(), Index<2>(), L);
} }
@ -120,7 +120,7 @@ struct usertype_variable_function : public usertype_function_core<Function, Tp>
int argcount = lua_gettop(L); int argcount = lua_gettop(L);
this->fx.item = stack::get<T*>(L, 1); this->fx.item = stack::get<T*>(L, 1);
if(this->fx.item == nullptr) { if(this->fx.item == nullptr) {
throw error("userdata for member variable is null"); return luaL_error(L, "sol: userdata for member variable is null");
} }
switch(argcount) { switch(argcount) {
case 2: case 2:
@ -128,7 +128,7 @@ struct usertype_variable_function : public usertype_function_core<Function, Tp>
case 3: case 3:
return static_cast<base_t&>(*this)(meta::tuple_types<void>(), args_type(), Index<3>(), L); return static_cast<base_t&>(*this)(meta::tuple_types<void>(), args_type(), Index<3>(), L);
default: default:
throw error("cannot get/set userdata member variable with inappropriate number of arguments"); return luaL_error(L, "sol: cannot get/set userdata member variable with inappropriate number of arguments");
} }
} }

View File

@ -315,7 +315,7 @@ template<>
struct getter<nil_t> { struct getter<nil_t> {
static nil_t get(lua_State* L, int index = -1) { static nil_t get(lua_State* L, int index = -1) {
if(lua_isnil(L, index) == 0) { if(lua_isnil(L, index) == 0) {
throw sol::error("not nil"); throw error("not nil");
} }
return nil_t{ }; return nil_t{ };
} }
@ -796,22 +796,6 @@ inline void call(types<void>, types<Args...> ta, std::index_sequence<I...> tai,
check_arguments<checkargs>{}.check(ta, tai, L, start); check_arguments<checkargs>{}.check(ta, tai, L, start);
fx(std::forward<FxArgs>(args)..., stack::get<Args>(L, start + I)...); fx(std::forward<FxArgs>(args)..., stack::get<Args>(L, start + I)...);
} }
inline int luajit_exception_jump (lua_State* L, lua_CFunction func) {
try {
return func(L);
}
catch (const char *s) { // Catch and convert exceptions.
lua_pushstring(L, s);
}
catch (const std::exception& e) {
lua_pushstring(L, e.what());
}
catch (...) {
lua_pushstring(L, "caught (...)");
}
return lua_error(L); // Rethrow as a Lua error.
}
} // stack_detail } // stack_detail
inline void remove( lua_State* L, int index, int count ) { inline void remove( lua_State* L, int index, int count ) {
@ -895,7 +879,7 @@ inline call_syntax get_call_syntax(lua_State* L, const std::string& meta) {
return call_syntax::dot; return call_syntax::dot;
} }
inline void luajit_exception_handler(lua_State* L, int(*handler)(lua_State*, lua_CFunction) = stack_detail::luajit_exception_jump) { inline void luajit_exception_handler(lua_State* L, int(*handler)(lua_State*, lua_CFunction) = detail::c_trampoline) {
#ifdef SOL_LUAJIT #ifdef SOL_LUAJIT
lua_pushlightuserdata(L, (void*)handler); lua_pushlightuserdata(L, (void*)handler);
luaJIT_setmode(L, -1, LUAJIT_MODE_WRAPCFUNC | LUAJIT_MODE_ON); luaJIT_setmode(L, -1, LUAJIT_MODE_WRAPCFUNC | LUAJIT_MODE_ON);

View File

@ -32,7 +32,6 @@ public:
state(lua_CFunction panic = detail::atpanic) : unique_base(luaL_newstate(), lua_close), state(lua_CFunction panic = detail::atpanic) : unique_base(luaL_newstate(), lua_close),
state_view(unique_base::get()) { state_view(unique_base::get()) {
set_panic(panic); set_panic(panic);
sol::stack::luajit_exception_handler(unique_base::get());
} }
using state_view::get; using state_view::get;

View File

@ -27,6 +27,46 @@
#include <string> #include <string>
namespace sol { namespace sol {
namespace detail {
template <lua_CFunction f>
inline int static_trampoline (lua_State* L) {
try {
return f(L);
}
catch (const char *s) { // Catch and convert exceptions.
lua_pushstring(L, s);
}
catch (const std::exception& e) {
lua_pushstring(L, e.what());
}
catch (...) {
lua_pushstring(L, "caught (...) exception");
}
return lua_error(L);
}
template <typename Fx>
inline int trampoline(lua_State* L, Fx&& f) {
try {
return f(L);
}
catch (const char *s) { // Catch and convert exceptions.
lua_pushstring(L, s);
}
catch (const std::exception& e) {
lua_pushstring(L, e.what());
}
catch (...) {
lua_pushstring(L, "caught (...) exception");
}
return lua_error(L);
}
inline int c_trampoline(lua_State* L, lua_CFunction f) {
return trampoline(L, f);
}
}
struct nil_t {}; struct nil_t {};
const nil_t nil {}; const nil_t nil {};
inline bool operator==(nil_t, nil_t) { return true; } inline bool operator==(nil_t, nil_t) { return true; }

View File

@ -1165,48 +1165,53 @@ TEST_CASE("interop/null-to-nil-and-back", "nil should be the given type when a p
TEST_CASE( "functions/function_result-protected_function_result", "Function result should be the beefy return type for sol::function that allows for error checking and error handlers" ) { TEST_CASE( "functions/function_result-protected_function_result", "Function result should be the beefy return type for sol::function that allows for error checking and error handlers" ) {
sol::state lua; sol::state lua;
lua.open_libraries( sol::lib::base, sol::lib::debug ); lua.open_libraries( sol::lib::base, sol::lib::debug );
static const char errormessage1[] = "true error message"; static const char unhandlederrormessage[] = "true error message";
static const char errormessage2[] = "doodle"; static const char handlederrormessage[] = "doodle";
// Some function; just using a lambda to be cheap // Some function; just using a lambda to be cheap
auto doom = []() { auto doomfx = []() {
// Bypasses handler function: puts information directly into lua error std::cout << "doomfx called" << std::endl;
throw std::runtime_error( errormessage1 ); throw std::runtime_error( unhandlederrormessage );
}; };
auto luadoom = [&lua]() { auto luadoomfx = [&lua]() {
std::cout << "luadoomfx called" << std::endl;
// Does not bypass error function, will call it // Does not bypass error function, will call it
luaL_error( lua.lua_state(), "BIG ERROR MESSAGES!" ); luaL_error( lua.lua_state(), unhandlederrormessage );
};
auto specialhandler = []( std::string ) {
return errormessage2;
}; };
lua.set_function("doom", doomfx);
lua.set_function("luadoom", luadoomfx);
lua.set_function( "doom", doom ); auto cpphandlerfx = []( std::string x ) {
lua.set_function( "luadoom", luadoom ); std::cout << "c++ handler called with: " << x << std::endl;
lua.set_function( "cpphandler", specialhandler ); return handlederrormessage;
};
lua.set_function( "cpphandler", cpphandlerfx );
lua.script( lua.script(
std::string( "function handler ( message )" ) std::string( "function luahandler ( message )" )
+ " return '" + errormessage2 + "'" + " print('lua handler called with: ' .. message)"
+ " return '" + handlederrormessage + "'"
+ "end" + "end"
); );
sol::protected_function func = lua[ "doom" ]; sol::protected_function doom = lua[ "doom" ];
sol::protected_function luafunc = lua[ "luadoom" ]; sol::protected_function luadoom = lua[ "luadoom" ];
sol::function luahandler = lua[ "handler" ]; sol::function luahandler = lua[ "luahandler" ];
sol::function cpphandler = lua[ "cpphandler" ]; sol::function cpphandler = lua[ "cpphandler" ];
func.error_handler = luahandler; doom.error_handler = luahandler;
luafunc.error_handler = cpphandler; luadoom.error_handler = cpphandler;
sol::protected_function_result result1 = func(); {
int test = lua_gettop(lua.lua_state()); sol::protected_function_result result = doom();
REQUIRE(!result1.valid()); REQUIRE(!result.valid());
std::string errorstring = result1; std::string errorstring = result;
REQUIRE(errorstring == errormessage1); REQUIRE(errorstring == handlederrormessage);
}
sol::protected_function_result result2 = luafunc(); {
REQUIRE(!result2.valid()); sol::protected_function_result result = luadoom();
errorstring = result2; REQUIRE(!result.valid());
REQUIRE(errorstring == errormessage2); std::string errorstring = result;
REQUIRE(errorstring == handlederrormessage);
}
} }
TEST_CASE("functions/destructor-tests", "Show that proper copies / destruction happens") { TEST_CASE("functions/destructor-tests", "Show that proper copies / destruction happens") {
@ -1413,11 +1418,9 @@ TEST_CASE("threading/coroutines", "ensure calling a coroutine works") {
function loop() function loop()
while counter ~= 30 while counter ~= 30
do do
print("Sending " .. counter);
coroutine.yield(counter); coroutine.yield(counter);
counter = counter + 1; counter = counter + 1;
end end
print("Sending " .. counter);
return counter return counter
end end
)"; )";
@ -1429,8 +1432,8 @@ end
int counter; int counter;
for (counter = 20; counter < 31 && cr; ++counter) { for (counter = 20; counter < 31 && cr; ++counter) {
int x = cr(); int value = cr();
if (counter != x) { if (counter != value) {
throw std::logic_error("fuck"); throw std::logic_error("fuck");
} }
} }
@ -1438,17 +1441,15 @@ end
REQUIRE(counter == 30); REQUIRE(counter == 30);
} }
TEST_CASE("threading/new-thread-coroutines", "ensure calling a coroutine works") { TEST_CASE("threading/new-thread-coroutines", "ensure calling a coroutine works when the work is put on a different thread") {
const auto& script = R"(counter = 20 const auto& script = R"(counter = 20
function loop() function loop()
while counter ~= 30 while counter ~= 30
do do
print("Sending " .. counter);
coroutine.yield(counter); coroutine.yield(counter);
counter = counter + 1; counter = counter + 1;
end end
print("Sending " .. counter);
return counter return counter
end end
)"; )";
@ -1462,8 +1463,8 @@ end
int counter; int counter;
for (counter = 20; counter < 31 && cr; ++counter) { for (counter = 20; counter < 31 && cr; ++counter) {
int x = cr(); int value = cr();
if (counter != x) { if (counter != value) {
throw std::logic_error("fuck"); throw std::logic_error("fuck");
} }
} }