diff --git a/sol/function_types_overload.hpp b/sol/function_types_overload.hpp index 7b429178..4acf3d70 100644 --- a/sol/function_types_overload.hpp +++ b/sol/function_types_overload.hpp @@ -64,7 +64,7 @@ struct overloaded_function : base_function { if (!detail::check_types(args_type(), args_type(), L)) { return match_arity(L, x, indices()); } - return stack::typed_call(return_type(), args_type(), func, L); + return stack::typed_call(return_type(), args_type(), func, L); } int match_arity(lua_State* L) { @@ -120,7 +120,7 @@ struct usertype_overloaded_function : base_function { return match_arity(L, x, indices()); } func.item = ptr(stack::get(L, 1)); - return stack::typed_call(return_type(), args_type(), func, L); + return stack::typed_call(return_type(), args_type(), func, L); } int match_arity(lua_State* L) { @@ -175,7 +175,7 @@ struct usertype_indexing_function, T> : base_function return match_arity(L, x, indices()); } func.item = ptr(stack::get(L, 1)); - return stack::typed_call(return_type(), args_type(), func, L); + return stack::typed_call(return_type(), args_type(), func, L); } int match_arity(lua_State* L) { diff --git a/sol/stack.hpp b/sol/stack.hpp index fcb3e648..557be067 100644 --- a/sol/stack.hpp +++ b/sol/stack.hpp @@ -203,18 +203,46 @@ struct checker { } }; -template -struct checker { +template +struct checker { template static bool check (lua_State* L, int index, const Handler& handler) { const type indextype = type_of(L, index); // Allow nil to be transformed to nullptr - bool success = expected == indextype || indextype == type::nil; - if (!success) { - // expected type, actual type - handler(L, index, expected, indextype); + if (indextype == type::nil) { + return true; } - return success; + return checker{}.check(L, indextype, index, handler); + } +}; + +template +struct checker { + template + static bool check (lua_State* L, type indextype, int index, const Handler& handler) { + if (indextype != type::userdata) { + handler(L, index, type::userdata, indextype); + return false; + } + if (lua_getmetatable(L, index) == 0) { + handler(L, index, type::userdata, indextype); + return false; + } + const type expectedmetatabletype = static_cast(luaL_getmetatable(L, &usertype_traits::metatable[0])); + if (expectedmetatabletype == type::nil) { + lua_pop(L, 2); + handler(L, index, type::userdata, indextype); + return false; + } + bool success = lua_rawequal(L, -1, -2) == 1; + lua_pop(L, 2); + return success; + } + + template + static bool check (lua_State* L, int index, const Handler& handler) { + const type indextype = type_of(L, index); + return check(L, indextype, index, handler); } }; @@ -628,17 +656,17 @@ inline void call(lua_State* L, types tr, types ta, Fx&& fx, FxArg call(L, 0, ta, tr, ta, std::forward(fx), std::forward(args)...); } -template +template inline int typed_call(types tr, types ta, Fx&& fx, lua_State* L) { - stack::call(L, 0, tr, ta, fx); + stack::call(L, 0, tr, ta, fx); int nargs = static_cast(sizeof...(Args)); lua_pop(L, nargs); return 0; } -template +template inline int typed_call(types tr, types ta, Fx&& fx, lua_State* L) { - decltype(auto) r = stack::call(L, 0, tr, ta, fx); + decltype(auto) r = stack::call(L, 0, tr, ta, fx); int nargs = static_cast(sizeof...(Args)); lua_pop(L, nargs); return stack::push(L, std::forward(r));