diff --git a/sol/stack_check.hpp b/sol/stack_check.hpp index b92bca89..936c2ece 100644 --- a/sol/stack_check.hpp +++ b/sol/stack_check.hpp @@ -429,8 +429,16 @@ namespace sol { return true; } int metatableindex = lua_gettop(L); - if (stack_detail::check_metatable>(L, metatableindex)) - return true; + if (stack_detail::check_metatable>(L, metatableindex)) { + void* memory = lua_touserdata(L, 1); + T** pointerpointer = static_cast(memory); + detail::unique_destructor& pdx = *static_cast(static_cast(pointerpointer + 1)); + bool success = &detail::usertype_unique_alloc_destroy == pdx; + if (!success) { + handler(L, index, type::userdata, indextype); + } + return success; + } lua_pop(L, 1); handler(L, index, type::userdata, indextype); return false; diff --git a/sol/stack_core.hpp b/sol/stack_core.hpp index 40ae9051..40c349f2 100644 --- a/sol/stack_core.hpp +++ b/sol/stack_core.hpp @@ -40,25 +40,26 @@ namespace sol { template struct as_value_tag {}; - using special_destruct_func = void(*)(void*); - - template - inline void special_destruct(void* memory) { - T** pointerpointer = static_cast(memory); - special_destruct_func* dx = static_cast(static_cast(pointerpointer + 1)); - Real* target = static_cast(static_cast(dx + 1)); - target->~Real(); - } + using unique_destructor = void(*)(void*); template inline int unique_destruct(lua_State* L) { void* memory = lua_touserdata(L, 1); T** pointerpointer = static_cast(memory); - special_destruct_func& dx = *static_cast(static_cast(pointerpointer + 1)); + unique_destructor& dx = *static_cast(static_cast(pointerpointer + 1)); (dx)(memory); return 0; } + template + inline void usertype_unique_alloc_destroy(void* memory) { + T** pointerpointer = static_cast(memory); + unique_destructor* dx = static_cast(static_cast(pointerpointer + 1)); + Real* target = static_cast(static_cast(dx + 1)); + std::allocator alloc; + alloc.destroy(target); + } + template inline int user_alloc_destroy(lua_State* L) { void* rawdata = lua_touserdata(L, 1); diff --git a/sol/stack_get.hpp b/sol/stack_get.hpp index 6bad25a3..e0394d9e 100644 --- a/sol/stack_get.hpp +++ b/sol/stack_get.hpp @@ -567,7 +567,7 @@ namespace sol { static Real& get(lua_State* L, int index, record& tracking) { tracking.use(1); P** pref = static_cast(lua_touserdata(L, index)); - detail::special_destruct_func* fx = static_cast(static_cast(pref + 1)); + detail::unique_destructor* fx = static_cast(static_cast(pref + 1)); Real* mem = static_cast(static_cast(fx + 1)); return *mem; } diff --git a/sol/stack_push.hpp b/sol/stack_push.hpp index 17909eb6..caafceda 100644 --- a/sol/stack_push.hpp +++ b/sol/stack_push.hpp @@ -153,10 +153,10 @@ namespace sol { template static int push_deep(lua_State* L, Args&&... args) { - P** pref = static_cast(lua_newuserdata(L, sizeof(P*) + sizeof(detail::special_destruct_func) + sizeof(Real))); - detail::special_destruct_func* fx = static_cast(static_cast(pref + 1)); + P** pref = static_cast(lua_newuserdata(L, sizeof(P*) + sizeof(detail::unique_destructor) + sizeof(Real))); + detail::unique_destructor* fx = static_cast(static_cast(pref + 1)); Real* mem = static_cast(static_cast(fx + 1)); - *fx = detail::special_destruct; + *fx = detail::usertype_unique_alloc_destroy; detail::default_construct::construct(mem, std::forward(args)...); *pref = unique_usertype_traits::get(*mem); if (luaL_newmetatable(L, &usertype_traits>::metatable()[0]) == 1) { diff --git a/sol/usertype_metatable.hpp b/sol/usertype_metatable.hpp index a314f93f..196f3a4c 100644 --- a/sol/usertype_metatable.hpp +++ b/sol/usertype_metatable.hpp @@ -746,8 +746,8 @@ namespace sol { bool hasdestructor = !value_table.empty() && to_string(meta_function::garbage_collect) == value_table[lastreg - 1].name; if (hasdestructor) { ref_table[lastreg - 1] = { nullptr, nullptr }; - unique_table[lastreg - 1] = { value_table[lastreg - 1].name, detail::unique_destruct }; } + unique_table[lastreg - 1] = { value_table[lastreg - 1].name, detail::unique_destruct }; // Now use um const bool& mustindex = umc.mustindex; diff --git a/test_functions.cpp b/test_functions.cpp index aa287a97..5f06777e 100644 --- a/test_functions.cpp +++ b/test_functions.cpp @@ -1137,35 +1137,88 @@ TEST_CASE("functions/set_function-already-wrapped", "setting a function returned } TEST_CASE("functions/unique-overloading", "make sure overloading can work with ptr vs. specifically asking for a unique usertype") { - sol::state lua; - - struct test { int special_value = 17; }; - auto print_up_test = [](std::unique_ptr& x) { - REQUIRE(x->special_value == 17); + struct test { + int special_value = 17; + test() : special_value(17) {} + test(int special_value) : special_value(special_value) {} + }; + auto print_up_test = [](std::unique_ptr& x) { + REQUIRE(x->special_value == 21); + }; + auto print_sp_test = [](std::shared_ptr& x) { + REQUIRE(x->special_value == 44); }; - auto print_ptr_test = [](test* x) { REQUIRE(x->special_value == 17); }; + auto print_ref_test = [](test& x) { + bool is_any = x.special_value == 17 + || x.special_value == 21 + || x. special_value == 44; + REQUIRE(is_any); + }; + using f_t = void(test&); + f_t* fptr = print_ref_test; + + std::unique_ptr ut = std::make_unique(17); + SECTION("working") { + sol::state lua; - lua.set_function("f", print_up_test); - lua.set_function("g", sol::overload( - std::ref(print_up_test), - print_ptr_test - )); + lua.set_function("f", print_up_test); + lua.set_function("g", sol::overload( + std::move(print_sp_test), + print_up_test, + std::ref(print_ptr_test) + )); + lua.set_function("h", std::ref(fptr)); - lua["v1"] = std::make_unique(); - lua["v2"] = test{}; - REQUIRE_NOTHROW([&]() { - lua.script("g(v1)"); - }()); - REQUIRE_NOTHROW([&]() { - lua.script("g(v2)"); - }()); - REQUIRE_NOTHROW([&]() { - lua.script("f(v1)"); - }()); - REQUIRE_THROWS([&]() { - lua.script("f(v2)"); - }()); + lua["v1"] = std::make_unique(21); + lua["v2"] = std::make_shared(44); + lua["v3"] = test(17); + lua["v4"] = ut.get(); + + REQUIRE_NOTHROW([&]() { + lua.script("f(v1)"); + lua.script("g(v1)"); + lua.script("g(v2)"); + lua.script("g(v3)"); + lua.script("g(v4)"); + lua.script("h(v1)"); + lua.script("h(v2)"); + lua.script("h(v3)"); + lua.script("h(v4)"); + }()); + }; + // LuaJIT segfaults hard on some Linux machines + // and it breaks all the tests... + SECTION("throws-value") { + sol::state lua; + + lua.set_function("f", print_up_test); + lua["v3"] = test(17); + + REQUIRE_THROWS([&]() { + lua.script("f(v3)"); + }()); + }; + SECTION("throws-shared_ptr") { + sol::state lua; + + lua.set_function("f", print_up_test); + lua["v2"] = std::make_shared(44); + + REQUIRE_THROWS([&]() { + lua.script("f(v2)"); + }()); + }; + SECTION("throws-ptr") { + sol::state lua; + + lua.set_function("f", print_up_test); + lua["v4"] = ut.get(); + + REQUIRE_THROWS([&]() { + lua.script("f(v4)"); + }()); + }; }