From dcff5cdaa07985c00c703c408ff4cd98ee38f324 Mon Sep 17 00:00:00 2001
From: ThePhD <phdofthehouse@gmail.com>
Date: Wed, 13 Sep 2017 10:20:24 -0400
Subject: [PATCH] re-implement xmove implicit transfers and hope it works
 proper this time

---
 single/sol/sol.hpp        | 106 ++++++++++++++++++++++++++------------
 sol/reference.hpp         |  89 +++++++++++++++++++++-----------
 sol/stack_reference.hpp   |  13 ++++-
 tests/test_coroutines.cpp |  89 +++++++++++++++++++++++++++++++-
 4 files changed, 234 insertions(+), 63 deletions(-)

diff --git a/single/sol/sol.hpp b/single/sol/sol.hpp
index f10a72f2..eec16bae 100644
--- a/single/sol/sol.hpp
+++ b/single/sol/sol.hpp
@@ -20,8 +20,8 @@
 // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
 // This file was generated with a script.
-// Generated 2017-09-13 06:45:53.519071 UTC
-// This header was generated with sol v2.18.2 (revision 66eb025)
+// Generated 2017-09-13 14:18:42.960702 UTC
+// This header was generated with sol v2.18.2 (revision 5816c6c)
 // https://github.com/ThePhD/sol2
 
 #ifndef SOL_SINGLE_INCLUDE_HPP
@@ -5672,6 +5672,17 @@ namespace sol {
 // beginning of sol/stack_reference.hpp
 
 namespace sol {
+	namespace detail {
+		inline bool xmovable(lua_State* leftL, lua_State* rightL) {
+			if (rightL == nullptr || leftL == nullptr || leftL == rightL) {
+				return false;
+			}
+			const void* leftregistry = lua_topointer(leftL, LUA_REGISTRYINDEX);
+			const void* rightregistry = lua_topointer(rightL, LUA_REGISTRYINDEX);
+			return leftregistry == rightregistry;
+		}
+	} // namespace detail
+
 	class stack_reference {
 	private:
 		lua_State* luastate = nullptr;
@@ -5707,7 +5718,7 @@ namespace sol {
 				return;
 			}
 			int i = r.stack_index();
-			if (r.lua_state() != luastate) {
+			if (detail::xmovable(lua_state(), r.lua_state())) {
 				lua_pushvalue(r.lua_state(), r.index);
 				lua_xmove(r.lua_state(), luastate, 1);
 				i = absolute_index(luastate, -1);
@@ -5931,40 +5942,54 @@ namespace sol {
 		}
 		reference(lua_State* L, const reference& r) noexcept
 		: luastate(L) {
-			if (r.ref == LUA_NOREF) {
+			if (r.ref == LUA_REFNIL) {
+				ref = LUA_REFNIL;
+				return;
+			}
+			if (r.ref == LUA_NOREF || lua_state() == nullptr) {
 				ref = LUA_NOREF;
 				return;
 			}
-			int p = r.push();
-			if (r.lua_state() != luastate) {
-				lua_xmove(r.lua_state(), L, p);
+			if (detail::xmovable(lua_state(), r.lua_state())) {
+				r.push(lua_state());
+				ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
+				return;
 			}
-			ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
+			ref = r.copy();
 		}
 		reference(lua_State* L, reference&& r) noexcept
 		: luastate(L) {
-			if (r.ref == LUA_NOREF) {
+			if (r.ref == LUA_REFNIL) {
+				ref = LUA_REFNIL;
+				return;
+			}
+			if (r.ref == LUA_NOREF || lua_state() == nullptr) {
 				ref = LUA_NOREF;
 				return;
 			}
-			if (r.lua_state() != luastate) {
-				int p = r.push();
-				lua_xmove(r.lua_state(), L, p);
+			if (detail::xmovable(lua_state(), r.lua_state())) {
+				r.push(lua_state());
 				ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
+				return;
 			}
-			else {
-				ref = r.ref;
-				r.luastate = nullptr;
-				r.ref = LUA_NOREF;
-			}
+			ref = r.ref;
+			r.ref = LUA_NOREF;
+			r.luastate = nullptr;
 		}
 		reference(lua_State* L, const stack_reference& r) noexcept
 		: luastate(L) {
-			if (!r.valid()) {
+			if (lua_state() == nullptr || r.lua_state() == nullptr || r.get_type() == type::none) {
 				ref = LUA_NOREF;
 				return;
 			}
-			r.push(luastate);
+			if (r.get_type() == type::nil) {
+				ref = LUA_REFNIL;
+				return;
+			}
+			if (lua_state() != r.lua_state() && !detail::xmovable(lua_state(), r.lua_state())) {
+				return;
+			}
+			r.push(lua_state());
 			ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
 		}
 		reference(lua_State* L, int index = -1) noexcept
@@ -5995,26 +6020,43 @@ namespace sol {
 			o.ref = LUA_NOREF;
 		}
 
-		reference& operator=(reference&& o) noexcept {
-			if (valid()) {
-				deref();
+		reference& operator=(reference&& r) noexcept {
+			if (r.ref == LUA_REFNIL) {
+				ref = LUA_REFNIL;
+				return *this;
+			}
+			if (r.ref == LUA_NOREF || lua_state() == nullptr) {
+				ref = LUA_NOREF;
+				return *this;
+			}
+			if (detail::xmovable(lua_state(), r.lua_state())) {
+				r.push(lua_state());
+				ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
+				return *this;
 			}
-			luastate = o.luastate;
-			ref = o.ref;
-
-			o.luastate = nullptr;
-			o.ref = LUA_NOREF;
 
+			ref = r.ref;
+			r.ref = LUA_NOREF;
+			r.luastate = nullptr;
 			return *this;
 		}
 
-		reference& operator=(const reference& o) noexcept {
-			if (valid()) {
-				deref();
+		reference& operator=(const reference& r) noexcept {
+			if (r.ref == LUA_REFNIL) {
+				ref = LUA_REFNIL;
+				return *this;
+			}
+			if (r.ref == LUA_NOREF || lua_state() == nullptr) {
+				ref = LUA_NOREF;
+				return *this;
+			}
+			if (detail::xmovable(lua_state(), r.lua_state())) {
+				r.push(lua_state());
+				ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
+				return *this;
 			}
 
-			luastate = o.luastate;
-			ref = o.copy();
+			ref = r.copy();
 			return *this;
 		}
 
diff --git a/sol/reference.hpp b/sol/reference.hpp
index 4ca91446..9ed8e815 100644
--- a/sol/reference.hpp
+++ b/sol/reference.hpp
@@ -167,40 +167,54 @@ namespace sol {
 		}
 		reference(lua_State* L, const reference& r) noexcept
 		: luastate(L) {
-			if (r.ref == LUA_NOREF) {
+			if (r.ref == LUA_REFNIL) {
+				ref = LUA_REFNIL;
+				return;
+			}
+			if (r.ref == LUA_NOREF || lua_state() == nullptr) {
 				ref = LUA_NOREF;
 				return;
 			}
-			int p = r.push();
-			if (r.lua_state() != luastate) {
-				lua_xmove(r.lua_state(), L, p);
+			if (detail::xmovable(lua_state(), r.lua_state())) {
+				r.push(lua_state());
+				ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
+				return;
 			}
-			ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
+			ref = r.copy();
 		}
 		reference(lua_State* L, reference&& r) noexcept
 		: luastate(L) {
-			if (r.ref == LUA_NOREF) {
+			if (r.ref == LUA_REFNIL) {
+				ref = LUA_REFNIL;
+				return;
+			}
+			if (r.ref == LUA_NOREF || lua_state() == nullptr) {
 				ref = LUA_NOREF;
 				return;
 			}
-			if (r.lua_state() != luastate) {
-				int p = r.push();
-				lua_xmove(r.lua_state(), L, p);
+			if (detail::xmovable(lua_state(), r.lua_state())) {
+				r.push(lua_state());
 				ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
+				return;
 			}
-			else {
-				ref = r.ref;
-				r.luastate = nullptr;
-				r.ref = LUA_NOREF;
-			}
+			ref = r.ref;
+			r.ref = LUA_NOREF;
+			r.luastate = nullptr;
 		}
 		reference(lua_State* L, const stack_reference& r) noexcept
 		: luastate(L) {
-			if (!r.valid()) {
+			if (lua_state() == nullptr || r.lua_state() == nullptr || r.get_type() == type::none) {
 				ref = LUA_NOREF;
 				return;
 			}
-			r.push(luastate);
+			if (r.get_type() == type::nil) {
+				ref = LUA_REFNIL;
+				return;
+			}
+			if (lua_state() != r.lua_state() && !detail::xmovable(lua_state(), r.lua_state())) {
+				return;
+			}
+			r.push(lua_state());
 			ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
 		}
 		reference(lua_State* L, int index = -1) noexcept
@@ -231,26 +245,43 @@ namespace sol {
 			o.ref = LUA_NOREF;
 		}
 
-		reference& operator=(reference&& o) noexcept {
-			if (valid()) {
-				deref();
+		reference& operator=(reference&& r) noexcept {
+			if (r.ref == LUA_REFNIL) {
+				ref = LUA_REFNIL;
+				return *this;
+			}
+			if (r.ref == LUA_NOREF || lua_state() == nullptr) {
+				ref = LUA_NOREF;
+				return *this;
+			}
+			if (detail::xmovable(lua_state(), r.lua_state())) {
+				r.push(lua_state());
+				ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
+				return *this;
 			}
-			luastate = o.luastate;
-			ref = o.ref;
-
-			o.luastate = nullptr;
-			o.ref = LUA_NOREF;
 
+			ref = r.ref;
+			r.ref = LUA_NOREF;
+			r.luastate = nullptr;
 			return *this;
 		}
 
-		reference& operator=(const reference& o) noexcept {
-			if (valid()) {
-				deref();
+		reference& operator=(const reference& r) noexcept {
+			if (r.ref == LUA_REFNIL) {
+				ref = LUA_REFNIL;
+				return *this;
+			}
+			if (r.ref == LUA_NOREF || lua_state() == nullptr) {
+				ref = LUA_NOREF;
+				return *this;
+			}
+			if (detail::xmovable(lua_state(), r.lua_state())) {
+				r.push(lua_state());
+				ref = luaL_ref(lua_state(), LUA_REGISTRYINDEX);
+				return *this;
 			}
 
-			luastate = o.luastate;
-			ref = o.copy();
+			ref = r.copy();
 			return *this;
 		}
 
diff --git a/sol/stack_reference.hpp b/sol/stack_reference.hpp
index 555d36d5..206238d1 100644
--- a/sol/stack_reference.hpp
+++ b/sol/stack_reference.hpp
@@ -25,6 +25,17 @@
 #include "types.hpp"
 
 namespace sol {
+	namespace detail {
+		inline bool xmovable(lua_State* leftL, lua_State* rightL) {
+			if (rightL == nullptr || leftL == nullptr || leftL == rightL) {
+				return false;
+			}
+			const void* leftregistry = lua_topointer(leftL, LUA_REGISTRYINDEX);
+			const void* rightregistry = lua_topointer(rightL, LUA_REGISTRYINDEX);
+			return leftregistry == rightregistry;
+		}
+	} // namespace detail
+
 	class stack_reference {
 	private:
 		lua_State* luastate = nullptr;
@@ -60,7 +71,7 @@ namespace sol {
 				return;
 			}
 			int i = r.stack_index();
-			if (r.lua_state() != luastate) {
+			if (detail::xmovable(lua_state(), r.lua_state())) {
 				lua_pushvalue(r.lua_state(), r.index);
 				lua_xmove(r.lua_state(), luastate, 1);
 				i = absolute_index(luastate, -1);
diff --git a/tests/test_coroutines.cpp b/tests/test_coroutines.cpp
index 117ce99e..782a77f1 100644
--- a/tests/test_coroutines.cpp
+++ b/tests/test_coroutines.cpp
@@ -108,7 +108,7 @@ end
 	}
 }
 
-TEST_CASE("coroutines/implicit transfer", "check that copy and move assignment constructors implicitly shift things around") {
+TEST_CASE("coroutines/explicit transfer", "check that the xmove constructors shift things around appropriately") {
 	const std::string code = R"(
 -- main thread - L1
 -- co - L2
@@ -194,3 +194,90 @@ co = nil
 	std::string s = t[1];
 	REQUIRE(s == "SOME_TABLE");
 }
+
+TEST_CASE("coroutines/implicit transfer", "check that copy and move assignment constructors implicitly shift things around") {
+	const std::string code = R"(
+-- main thread - L1
+-- co - L2
+-- co2 - L3
+
+x = co_test.new("x")
+local co = coroutine.wrap(
+	function()
+		local t = co_test.new("t")
+		local co2 = coroutine.wrap(
+			function()
+				local t2 = { "SOME_TABLE" }
+				t:copy_store(t2) -- t2 = [L3], t.obj = [L2]
+			end
+		)
+
+		co2()
+		co2 = nil
+
+		collectgarbage() -- t2 ref in t remains valid!
+
+		x:store(t:get()) -- t.obj = [L2], x.obj = [L1]
+    end
+)
+
+co()
+collectgarbage()
+collectgarbage()
+co = nil
+)";
+
+	struct co_test_implicit {
+		std::string identifier;
+		sol::reference obj;
+
+		co_test_implicit(sol::this_state L, std::string id) : identifier(id), obj(L, sol::lua_nil) {
+
+		}
+
+		void store(sol::table ref) {
+			// must be explicit
+			obj = std::move(ref);
+		}
+
+		void copy_store(sol::table ref) {
+			// must be explicit
+			obj = ref;
+		}
+
+		sol::reference get() {
+			return obj;
+		}
+
+		~co_test_implicit() {
+
+		}
+	};
+
+	sol::state lua;
+	lua.open_libraries(sol::lib::coroutine, sol::lib::base);
+
+	lua.new_usertype<co_test_implicit>("co_test",
+		sol::constructors<co_test_implicit(sol::this_state, std::string)>(),
+		"store", &co_test_implicit::store,
+		"copy_store", &co_test_implicit::copy_store,
+		"get", &co_test_implicit::get
+		);
+
+
+	auto r = lua.safe_script(code);
+	REQUIRE(r.valid());
+
+	co_test_implicit& ct = lua["x"];
+
+	lua_State* Lmain1 = lua.lua_state();
+	lua_State* Lmain2 = sol::main_thread(lua);
+	lua_State* Lmain3 = ct.get().lua_state();
+	REQUIRE(Lmain1 == Lmain2);
+	REQUIRE(Lmain1 == Lmain3);
+
+	sol::table t = ct.get();
+	REQUIRE(t.size() == 1);
+	std::string s = t[1];
+	REQUIRE(s == "SOME_TABLE");
+}