Proper trampolines to allow luajit to play nice with all the other kids in sol.

This commit is contained in:
ThePhD 2016-02-27 07:56:28 -05:00
parent e57ac87868
commit d42efd7fdf
11 changed files with 123 additions and 84 deletions

View File

@ -125,7 +125,10 @@ if 'linux' in sys.platform:
builddir = 'bin'
objdir = 'obj'
tests = os.path.join(builddir, 'tests')
if 'win32' in sys.platform:
tests = os.path.join(builddir, 'tests.exe')
else:
tests = os.path.join(builddir, 'tests')
# ninja file
ninja = ninja_syntax.Writer(open('build.ninja', 'w'))

View File

@ -170,6 +170,12 @@ private:
returncount = poststacksize - firstreturn;
}
// Handle C++ errors thrown from C++ functions bound inside of lua
catch (const char* error) {
h.stackindex = 0;
stack::push(lua_state(), error);
firstreturn = lua_gettop(lua_state());
return protected_function_result(lua_state(), firstreturn, 0, 1, call_status::runtime);
}
catch (const std::exception& error) {
h.stackindex = 0;
stack::push(lua_state(), error.what());
@ -296,7 +302,7 @@ struct pusher<function_sig<Sigs...>> {
dFx memfxptr(std::forward<Fx>(fx));
auto userptr = detail::ptr(obj);
void* userobjdata = static_cast<void*>(userptr);
lua_CFunction freefunc = &function_detail::static_member_function<std::decay_t<decltype(*userptr)>, uFx>::call;
lua_CFunction freefunc = &function_detail::upvalue_member_function<std::decay_t<decltype(*userptr)>, uFx>::call;
int upvalues = stack::stack_detail::push_as_upvalues(L, memfxptr);
upvalues += stack::push(L, userobjdata);
@ -307,7 +313,7 @@ struct pusher<function_sig<Sigs...>> {
template<typename Fx>
static void set_fx(std::false_type, lua_State* L, Fx&& fx) {
std::decay_t<Fx> target(std::forward<Fx>(fx));
lua_CFunction freefunc = &function_detail::static_function<Fx>::call;
lua_CFunction freefunc = &function_detail::upvalue_free_function<Fx>::call;
int upvalues = stack::stack_detail::push_as_upvalues(L, target);
stack::push(L, freefunc, upvalues);

View File

@ -75,9 +75,9 @@ inline int construct(lua_State* L) {
luaL_getmetatable(L, &meta[0]);
if (stack::get<type>(L) == type::nil) {
lua_pop(L, 1);
std::string err = "unable to get usertype metatable for ";
err += meta;
throw error(err);
std::string err = "sol: unable to get usertype metatable for ";
err += usertype_traits<T>::name;
return luaL_error(L, err.c_str());
}
lua_setmetatable(L, -2);
@ -126,9 +126,9 @@ struct usertype_constructor_function : base_function {
luaL_getmetatable(L, &meta[0]);
if (stack::get<type>(L) == type::nil) {
lua_pop(L, 1);
std::string err = "unable to get usertype metatable for ";
err += meta;
throw error(err);
std::string err = "sol: unable to get usertype metatable for ";
err += usertype_traits<T>::name;
return luaL_error(L, err.c_str());
}
lua_setmetatable(L, -2);

View File

@ -169,8 +169,8 @@ public:
};
struct base_function {
virtual int operator()(lua_State*) {
throw error("failure to call specialized wrapped C++ function from Lua");
virtual int operator()(lua_State* L) {
return luaL_error(L, "sol: failure to call specialized wrapped C++ function from Lua");
}
virtual ~base_function() {}
@ -178,18 +178,17 @@ struct base_function {
static int base_call(lua_State* L, void* inheritancedata) {
if (inheritancedata == nullptr) {
throw error("call from Lua to C++ function has null data");
return luaL_error(L, "sol: call from Lua to C++ function has null data");
}
base_function* pfx = static_cast<base_function*>(inheritancedata);
base_function& fx = *pfx;
int r = fx(L);
return r;
return detail::trampoline(L, fx);
}
static int base_gc(lua_State*, void* udata) {
static int base_gc(lua_State* L, void* udata) {
if (udata == nullptr) {
throw error("call from lua to C++ gc function with null data");
return luaL_error(L, "sol: call from lua to C++ gc function with null data");
}
base_function* ptr = static_cast<base_function*>(udata);

View File

@ -40,8 +40,8 @@ struct overload_traits<functor<T, Func, X>> {
};
template <std::size_t... M, typename Match, typename... Args>
inline int overload_match_arity(types<>, std::index_sequence<>, std::index_sequence<M...>, Match&&, lua_State*, int, int, Args&&...) {
throw error("no matching function call takes this number of arguments and the specified types");
inline int overload_match_arity(types<>, std::index_sequence<>, std::index_sequence<M...>, Match&&, lua_State* L, int, int, Args&&...) {
return luaL_error(L, "sol: no matching function call takes this number of arguments and the specified types");
}
template <typename Fx, typename... Fxs, std::size_t I, std::size_t... In, std::size_t... M, typename Match, typename... Args>
@ -124,7 +124,6 @@ struct usertype_overloaded_function : base_function {
auto mfx = [&](auto&&... args){ return this->call(std::forward<decltype(args)>(args)...); };
return overload_match<functor<T, std::remove_pointer_t<std::decay_t<Functions>>>...>(mfx, L, 2);
}
};
} // function_detail
} // sol

View File

@ -27,28 +27,32 @@
namespace sol {
namespace function_detail {
template<typename Function>
struct static_function {
struct upvalue_free_function {
typedef std::remove_pointer_t<std::decay_t<Function>> function_type;
typedef meta::function_traits<function_type> traits_type;
static int call(lua_State* L) {
static int real_call(lua_State* L) {
auto udata = stack::stack_detail::get_as_upvalues<function_type*>(L);
function_type* fx = udata.first;
int r = stack::call_into_lua(meta::tuple_types<typename traits_type::return_type>(), typename traits_type::args_type(), fx, L, 1);
return r;
}
static int call (lua_State* L) {
return detail::static_trampoline<&real_call>(L);
}
int operator()(lua_State* L) {
return call(L);
}
};
template<typename T, typename Function>
struct static_member_function {
struct upvalue_member_function {
typedef std::remove_pointer_t<std::decay_t<Function>> function_type;
typedef meta::function_traits<function_type> traits_type;
static int call(lua_State* L) {
static int real_call(lua_State* L) {
auto memberdata = stack::stack_detail::get_as_upvalues<function_type>(L, 1);
auto objdata = stack::stack_detail::get_as_upvalues<T*>(L, memberdata.second);
function_type& memfx = memberdata.first;
@ -57,6 +61,10 @@ struct static_member_function {
return stack::call_into_lua(meta::tuple_types<typename traits_type::return_type>(), typename traits_type::args_type(), fx, L, 1);
}
static int call (lua_State* L) {
return detail::static_trampoline<&real_call>(L);
}
int operator()(lua_State* L) {
return call(L);
}

View File

@ -95,7 +95,7 @@ struct usertype_function : public usertype_function_core<Function, Tp> {
int prelude(lua_State* L) {
this->fx.item = detail::ptr(stack::get<T>(L, 1));
if(this->fx.item == nullptr) {
throw error("userdata for function call is null: are you using the wrong syntax? (use item:function/variable(...) syntax)");
return luaL_error(L, "sol: userdata for function call is null: are you using the wrong syntax? (use item:function/variable(...) syntax)");
}
return static_cast<base_t&>(*this)(meta::tuple_types<return_type>(), args_type(), Index<2>(), L);
}
@ -120,7 +120,7 @@ struct usertype_variable_function : public usertype_function_core<Function, Tp>
int argcount = lua_gettop(L);
this->fx.item = stack::get<T*>(L, 1);
if(this->fx.item == nullptr) {
throw error("userdata for member variable is null");
return luaL_error(L, "sol: userdata for member variable is null");
}
switch(argcount) {
case 2:
@ -128,7 +128,7 @@ struct usertype_variable_function : public usertype_function_core<Function, Tp>
case 3:
return static_cast<base_t&>(*this)(meta::tuple_types<void>(), args_type(), Index<3>(), L);
default:
throw error("cannot get/set userdata member variable with inappropriate number of arguments");
return luaL_error(L, "sol: cannot get/set userdata member variable with inappropriate number of arguments");
}
}

View File

@ -315,7 +315,7 @@ template<>
struct getter<nil_t> {
static nil_t get(lua_State* L, int index = -1) {
if(lua_isnil(L, index) == 0) {
throw sol::error("not nil");
throw error("not nil");
}
return nil_t{ };
}
@ -796,22 +796,6 @@ inline void call(types<void>, types<Args...> ta, std::index_sequence<I...> tai,
check_arguments<checkargs>{}.check(ta, tai, L, start);
fx(std::forward<FxArgs>(args)..., stack::get<Args>(L, start + I)...);
}
inline int luajit_exception_jump (lua_State* L, lua_CFunction func) {
try {
return func(L);
}
catch (const char *s) { // Catch and convert exceptions.
lua_pushstring(L, s);
}
catch (const std::exception& e) {
lua_pushstring(L, e.what());
}
catch (...) {
lua_pushstring(L, "caught (...)");
}
return lua_error(L); // Rethrow as a Lua error.
}
} // stack_detail
inline void remove( lua_State* L, int index, int count ) {
@ -895,7 +879,7 @@ inline call_syntax get_call_syntax(lua_State* L, const std::string& meta) {
return call_syntax::dot;
}
inline void luajit_exception_handler(lua_State* L, int(*handler)(lua_State*, lua_CFunction) = stack_detail::luajit_exception_jump) {
inline void luajit_exception_handler(lua_State* L, int(*handler)(lua_State*, lua_CFunction) = detail::c_trampoline) {
#ifdef SOL_LUAJIT
lua_pushlightuserdata(L, (void*)handler);
luaJIT_setmode(L, -1, LUAJIT_MODE_WRAPCFUNC | LUAJIT_MODE_ON);

View File

@ -32,7 +32,6 @@ public:
state(lua_CFunction panic = detail::atpanic) : unique_base(luaL_newstate(), lua_close),
state_view(unique_base::get()) {
set_panic(panic);
sol::stack::luajit_exception_handler(unique_base::get());
}
using state_view::get;

View File

@ -27,6 +27,46 @@
#include <string>
namespace sol {
namespace detail {
template <lua_CFunction f>
inline int static_trampoline (lua_State* L) {
try {
return f(L);
}
catch (const char *s) { // Catch and convert exceptions.
lua_pushstring(L, s);
}
catch (const std::exception& e) {
lua_pushstring(L, e.what());
}
catch (...) {
lua_pushstring(L, "caught (...) exception");
}
return lua_error(L);
}
template <typename Fx>
inline int trampoline(lua_State* L, Fx&& f) {
try {
return f(L);
}
catch (const char *s) { // Catch and convert exceptions.
lua_pushstring(L, s);
}
catch (const std::exception& e) {
lua_pushstring(L, e.what());
}
catch (...) {
lua_pushstring(L, "caught (...) exception");
}
return lua_error(L);
}
inline int c_trampoline(lua_State* L, lua_CFunction f) {
return trampoline(L, f);
}
}
struct nil_t {};
const nil_t nil {};
inline bool operator==(nil_t, nil_t) { return true; }

View File

@ -1165,48 +1165,53 @@ TEST_CASE("interop/null-to-nil-and-back", "nil should be the given type when a p
TEST_CASE( "functions/function_result-protected_function_result", "Function result should be the beefy return type for sol::function that allows for error checking and error handlers" ) {
sol::state lua;
lua.open_libraries( sol::lib::base, sol::lib::debug );
static const char errormessage1[] = "true error message";
static const char errormessage2[] = "doodle";
static const char unhandlederrormessage[] = "true error message";
static const char handlederrormessage[] = "doodle";
// Some function; just using a lambda to be cheap
auto doom = []() {
// Bypasses handler function: puts information directly into lua error
throw std::runtime_error( errormessage1 );
auto doomfx = []() {
std::cout << "doomfx called" << std::endl;
throw std::runtime_error( unhandlederrormessage );
};
auto luadoom = [&lua]() {
auto luadoomfx = [&lua]() {
std::cout << "luadoomfx called" << std::endl;
// Does not bypass error function, will call it
luaL_error( lua.lua_state(), "BIG ERROR MESSAGES!" );
};
auto specialhandler = []( std::string ) {
return errormessage2;
luaL_error( lua.lua_state(), unhandlederrormessage );
};
lua.set_function("doom", doomfx);
lua.set_function("luadoom", luadoomfx);
lua.set_function( "doom", doom );
lua.set_function( "luadoom", luadoom );
lua.set_function( "cpphandler", specialhandler );
auto cpphandlerfx = []( std::string x ) {
std::cout << "c++ handler called with: " << x << std::endl;
return handlederrormessage;
};
lua.set_function( "cpphandler", cpphandlerfx );
lua.script(
std::string( "function handler ( message )" )
+ " return '" + errormessage2 + "'"
std::string( "function luahandler ( message )" )
+ " print('lua handler called with: ' .. message)"
+ " return '" + handlederrormessage + "'"
+ "end"
);
sol::protected_function func = lua[ "doom" ];
sol::protected_function luafunc = lua[ "luadoom" ];
sol::function luahandler = lua[ "handler" ];
sol::protected_function doom = lua[ "doom" ];
sol::protected_function luadoom = lua[ "luadoom" ];
sol::function luahandler = lua[ "luahandler" ];
sol::function cpphandler = lua[ "cpphandler" ];
func.error_handler = luahandler;
luafunc.error_handler = cpphandler;
sol::protected_function_result result1 = func();
int test = lua_gettop(lua.lua_state());
REQUIRE(!result1.valid());
std::string errorstring = result1;
REQUIRE(errorstring == errormessage1);
doom.error_handler = luahandler;
luadoom.error_handler = cpphandler;
sol::protected_function_result result2 = luafunc();
REQUIRE(!result2.valid());
errorstring = result2;
REQUIRE(errorstring == errormessage2);
{
sol::protected_function_result result = doom();
REQUIRE(!result.valid());
std::string errorstring = result;
REQUIRE(errorstring == handlederrormessage);
}
{
sol::protected_function_result result = luadoom();
REQUIRE(!result.valid());
std::string errorstring = result;
REQUIRE(errorstring == handlederrormessage);
}
}
TEST_CASE("functions/destructor-tests", "Show that proper copies / destruction happens") {
@ -1413,11 +1418,9 @@ TEST_CASE("threading/coroutines", "ensure calling a coroutine works") {
function loop()
while counter ~= 30
do
print("Sending " .. counter);
coroutine.yield(counter);
counter = counter + 1;
end
print("Sending " .. counter);
return counter
end
)";
@ -1429,8 +1432,8 @@ end
int counter;
for (counter = 20; counter < 31 && cr; ++counter) {
int x = cr();
if (counter != x) {
int value = cr();
if (counter != value) {
throw std::logic_error("fuck");
}
}
@ -1438,17 +1441,15 @@ end
REQUIRE(counter == 30);
}
TEST_CASE("threading/new-thread-coroutines", "ensure calling a coroutine works") {
TEST_CASE("threading/new-thread-coroutines", "ensure calling a coroutine works when the work is put on a different thread") {
const auto& script = R"(counter = 20
function loop()
while counter ~= 30
do
print("Sending " .. counter);
coroutine.yield(counter);
counter = counter + 1;
end
print("Sending " .. counter);
return counter
end
)";
@ -1462,8 +1463,8 @@ end
int counter;
for (counter = 20; counter < 31 && cr; ++counter) {
int x = cr();
if (counter != x) {
int value = cr();
if (counter != value) {
throw std::logic_error("fuck");
}
}