diff --git a/sol/stack.hpp b/sol/stack.hpp index 55c1cacc..160a2515 100644 --- a/sol/stack.hpp +++ b/sol/stack.hpp @@ -127,6 +127,22 @@ struct getter { } }; +template +struct getter { + static T& get(lua_State* L, int index = -1) { + void* udata = lua_touserdata(L, index); + T** obj = static_cast(udata); + return **obj; + } +}; + +template +struct getter> { + static T& get(lua_State* L, int index = -1) { + return getter{}.get(L, index); + } +}; + template<> struct getter { static type get(lua_State *L, int index){ @@ -243,7 +259,7 @@ struct pusher { template, EnableIf>, Not>, Not>, Not>> = 0> static void push(lua_State* L, T& t) { - pusher{}.push(L, std::addressof(t)); + detail::push_userdata(L, userdata_traits::metatable, t); } template, EnableIf>, Not>, Not>, Not>> = 0> @@ -259,6 +275,13 @@ struct pusher { } }; +template +struct pusher> { + static void push(lua_State* L, const std::reference_wrapper& t) { + pusher{}.push(L, std::addressof(t.get())); + } +}; + template<> struct pusher { static void push(lua_State* L, const bool& b) { diff --git a/sol/tuple.hpp b/sol/tuple.hpp index 9d903ce0..670359fd 100644 --- a/sol/tuple.hpp +++ b/sol/tuple.hpp @@ -93,6 +93,9 @@ struct constructors {}; const auto default_constructor = constructors>{}; +template +using ref = std::reference_wrapper; + } // sol #endif // SOL_TUPLE_HPP diff --git a/tests.cpp b/tests.cpp index 3215036a..3bc3aeb0 100644 --- a/tests.cpp +++ b/tests.cpp @@ -84,6 +84,18 @@ struct self_test { } }; +struct vars { + vars () { + + } + + int boop = 0; + + ~vars () { + + } +}; + struct object { std::string operator() () { std::cout << "member_test()" << std::endl; @@ -469,7 +481,7 @@ TEST_CASE("functions/return_order_and_multi_get", "Check if return order is in t auto tluaget = lua.get("x", "y", "z"); std::cout << "cpp: " << std::get<0>(tcpp) << ',' << std::get<1>(tcpp) << ',' << std::get<2>(tcpp) << std::endl; std::cout << "lua: " << std::get<0>(tlua) << ',' << std::get<1>(tlua) << ',' << std::get<2>(tlua) << std::endl; - std::cout << "lua.xyz: " << lua.get("x") << ',' << lua.get("y") << ',' << lua.get("z") << std::endl; + std::cout << "lua xyz: " << lua.get("x") << ',' << lua.get("y") << ',' << lua.get("z") << std::endl; REQUIRE(tcpp == triple); REQUIRE(tlua == triple); REQUIRE(tluaget == triple); @@ -857,10 +869,6 @@ TEST_CASE("userdata/nonmember functions implement functionality", "let users set } TEST_CASE("regressions/one", "issue number 48") { - struct vars { - int boop = 0; - }; - sol::state lua; lua.new_userdata("vars", "boop", &vars::boop); REQUIRE_NOTHROW(lua.script("beep = vars.new()\n" @@ -871,3 +879,25 @@ TEST_CASE("regressions/one", "issue number 48") { auto* ptr = &my_var; REQUIRE(ptr->boop == 1); } + +TEST_CASE("references/get-set", "properly get and set with std::ref semantics. Note that to get, we must not use Unqualified on the type...") { + sol::state lua; + + lua.new_userdata("vars", + "boop", &vars::boop); + + vars var{}; + vars rvar{}; + lua.set("beep", var); + lua.set("rbeep", std::ref(rvar)); + auto& my_var = lua.get("beep"); + auto& ref_var = lua.get>("rbeep"); + + var.boop = 2; + rvar.boop = 5; + + REQUIRE((my_var.boop == 0)); + REQUIRE(var.boop != my_var.boop); + // Reference should point back to the same type. + REQUIRE(rvar.boop == ref_var.boop); +}