// sol3 // The MIT License (MIT) // Copyright (c) 2013-2020 Rapptz, ThePhD and contributors // Permission is hereby granted, free of charge, to any person obtaining a copy of // this software and associated documentation files (the "Software"), to deal in // the Software without restriction, including without limitation the rights to // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of // the Software, and to permit persons to whom the Software is furnished to do so, // subject to the following conditions: // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS // FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR // COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. #include "sol_test.hpp" #include #include inline namespace sol2_test_coroutines { struct coroutine_thread_runner_state { sol::state_view* prunner_thread_state; coroutine_thread_runner_state(sol::state_view& ref_) : prunner_thread_state(&ref_) { } auto operator()() const { sol::state_view& runner_thread_state = *prunner_thread_state; sol::coroutine cr = runner_thread_state["loop"]; sol::stack::push(runner_thread_state, 50); sol::stack::push(runner_thread_state, 25); int r = cr(); return r; } }; struct coroutine_thread_runner { sol::thread* prunner_thread; coroutine_thread_runner(sol::thread& ref_) : prunner_thread(&ref_) { } auto operator()() const { sol::thread& runner_thread = *prunner_thread; sol::state_view th_state = runner_thread.state(); sol::coroutine cr = th_state["loop"]; int r = cr(); return r; } }; struct coroutine_storage { coroutine_storage(sol::state& luaContext) { mLuaState = &luaContext; mThread = sol::thread::create(luaContext); sol::state_view luaThreadState = mThread.state(); mThreadEnvironment = sol::environment(luaThreadState, sol::create, luaThreadState.globals()); sol::set_environment(mThreadEnvironment, mThread); sol::optional actionTable = luaThreadState["aTable"]; if (actionTable) { mThreadTick = actionTable.value()["Tick"]; } } coroutine_storage& operator=(coroutine_storage&& r) noexcept { mThread = std::move(r.mThread); mThreadEnvironment = std::move(r.mThreadEnvironment); mThreadTick = std::move(r.mThreadTick); return *this; } coroutine_storage(coroutine_storage&& r) noexcept : mLuaState(std::move(r.mLuaState)) , mThread(std::move(r.mThread)) , mThreadEnvironment(std::move(r.mThreadEnvironment)) , mThreadTick(std::move(r.mThreadTick)) { } coroutine_storage& operator=(const coroutine_storage& r) = delete; coroutine_storage(const coroutine_storage& r) = delete; sol::state* mLuaState; sol::thread mThread; sol::environment mThreadEnvironment; sol::coroutine mThreadTick; }; struct coro_h { int x = 500; int func() { x += 1; return x; } }; struct coro_test { std::string identifier; sol::reference obj; coro_test(sol::this_state L, std::string id) : identifier(id), obj(L, sol::lua_nil) { } void store(sol::table ref) { // must be explicit obj = sol::reference(obj.lua_state(), ref); } void copy_store(sol::table ref) { // must be explicit obj = sol::reference(obj.lua_state(), ref); } sol::reference get() { return obj; } ~coro_test() { } }; struct coro_test_implicit { std::string identifier; sol::main_reference obj; coro_test_implicit(sol::this_state L, std::string id) : identifier(id), obj(L, sol::lua_nil) { } void store(sol::table ref) { // main_reference does the state shift implicitly obj = std::move(ref); lua_State* Lmain = sol::main_thread(ref.lua_state()); REQUIRE(obj.lua_state() == Lmain); } void copy_store(sol::table ref) { // main_reference does the state shift implicitly obj = ref; lua_State* Lmain = sol::main_thread(ref.lua_state()); REQUIRE(obj.lua_state() == Lmain); } sol::reference get() { return obj; } ~coro_test_implicit() { } }; } // namespace sol2_test_coroutines TEST_CASE("coroutines/coroutine.yield", "ensure calling a coroutine works") { const auto& script = R"(counter = 20 function loop() while counter ~= 30 do coroutine.yield(counter); counter = counter + 1; end return counter end )"; sol::state lua; sol::stack_guard luasg(lua); lua.open_libraries(sol::lib::base, sol::lib::coroutine); auto result1 = lua.safe_script(script); REQUIRE(result1.valid()); sol::coroutine cr = lua["loop"]; int counter; for (counter = 20; counter < 31 && cr; ++counter) { int value = cr(); REQUIRE(counter == value); } counter -= 1; REQUIRE(counter == 30); } TEST_CASE("coroutines/new thread coroutines", "ensure calling a coroutine works when the work is put on a different thread") { const auto& code = R"(counter = 20 function loop() while counter ~= 30 do coroutine.yield(counter); counter = counter + 1; end return counter end )"; sol::state lua; sol::stack_guard luasg(lua); lua.open_libraries(sol::lib::base, sol::lib::coroutine); auto result = lua.safe_script(code, sol::script_pass_on_error); REQUIRE(result.valid()); sol::thread runner = sol::thread::create(lua.lua_state()); sol::state_view runnerstate = runner.state(); sol::coroutine cr = runnerstate["loop"]; int counter; for (counter = 20; counter < 31 && cr; ++counter) { int value = cr(); REQUIRE(counter == value); } counter -= 1; REQUIRE(counter == 30); } TEST_CASE("coroutines/transfer", "test that things created inside of a coroutine can have their state transferred using lua_xmove constructors") { for (std::size_t tries = 0; tries < 200; ++tries) { sol::state lua; sol::stack_guard luasg(lua); lua.open_libraries(); { sol::function f2; lua["f"] = [&lua, &f2](sol::object t) { f2 = sol::function(lua, t); }; { auto code = R"( i = 0 function INIT() co = coroutine.create( function() local g = function() i = i + 1 end f(g) g = nil collectgarbage() end ) coroutine.resume(co) co = nil collectgarbage() end )"; auto result = lua.safe_script(code, sol::script_pass_on_error); REQUIRE(result.valid()); } sol::function f3; sol::function f1; { auto code = "INIT()"; auto result = lua.safe_script(code, sol::script_pass_on_error); REQUIRE(result.valid()); } f2(); auto updatecode = "return function() collectgarbage() end"; auto pfr = lua.safe_script(updatecode); REQUIRE(pfr.valid()); sol::function update = pfr.get(); update(); f3 = f2; f3(); update(); f1 = f2; f1(); update(); int i = lua["i"]; REQUIRE(i == 3); } } } TEST_CASE("coroutines/explicit transfer", "check that the xmove constructors shift things around appropriately") { const std::string code = R"( -- main thread - L1 -- co - L2 -- co2 - L3 x = coro_test.new("x") local co = coroutine.wrap( function() local t = coro_test.new("t") local co2 = coroutine.wrap( function() local t2 = { "SOME_TABLE" } t:copy_store(t2) -- t2 = [L3], t.obj = [L2] end ) co2() co2 = nil collectgarbage() -- t2 ref in t remains valid! x:store(t:get()) -- t.obj = [L2], x.obj = [L1] end ) co() collectgarbage() collectgarbage() co = nil )"; sol::state lua; lua.open_libraries(sol::lib::coroutine, sol::lib::base); lua.new_usertype("coro_test", sol::constructors(), "store", &coro_test::store, "copy_store", &coro_test::copy_store, "get", &coro_test::get); auto r = lua.safe_script(code, sol::script_pass_on_error); REQUIRE(r.valid()); coro_test& ct = lua["x"]; lua_State* Lmain1 = lua.lua_state(); lua_State* Lmain2 = sol::main_thread(lua); lua_State* Lmain3 = ct.get().lua_state(); REQUIRE(Lmain1 == Lmain2); REQUIRE(Lmain1 == Lmain3); sol::table t = ct.get(); REQUIRE(t.size() == 1); std::string s = t[1]; REQUIRE(s == "SOME_TABLE"); } TEST_CASE("coroutines/implicit transfer", "check that copy and move assignment constructors implicitly shift things around") { const std::string code = R"( -- main thread - L1 -- co - L2 -- co2 - L3 x = coro_test.new("x") local co = coroutine.wrap( function() local t = coro_test.new("t") local co2 = coroutine.wrap( function() local t2 = { "SOME_TABLE" } t:copy_store(t2) -- t2 = [L3], t.obj = [L2] end ) co2() co2 = nil collectgarbage() -- t2 ref in t remains valid! x:store(t:get()) -- t.obj = [L2], x.obj = [L1] end ) co() collectgarbage() collectgarbage() co = nil )"; struct coro_test_implicit { std::string identifier; sol::reference obj; coro_test_implicit(sol::this_state L, std::string id) : identifier(id), obj(L, sol::lua_nil) { } void store(sol::table ref) { // must be explicit obj = std::move(ref); } void copy_store(sol::table ref) { // must be explicit obj = ref; } sol::reference get() { return obj; } ~coro_test_implicit() { } }; sol::state lua; lua.open_libraries(sol::lib::coroutine, sol::lib::base); lua.new_usertype("coro_test", sol::constructors(), "store", &coro_test_implicit::store, "copy_store", &coro_test_implicit::copy_store, "get", &coro_test_implicit::get); auto r = lua.safe_script(code, sol::script_pass_on_error); REQUIRE(r.valid()); coro_test_implicit& ct = lua["x"]; lua_State* Lmain1 = lua.lua_state(); lua_State* Lmain2 = sol::main_thread(lua); lua_State* Lmain3 = ct.get().lua_state(); REQUIRE(Lmain1 == Lmain2); REQUIRE(Lmain1 == Lmain3); sol::table t = ct.get(); REQUIRE(t.size() == 1); std::string s = t[1]; REQUIRE(s == "SOME_TABLE"); } TEST_CASE("coroutines/main transfer", "check that copy and move assignment constructors using main-forced types work") { const std::string code = R"( -- main thread - L1 -- co - L2 -- co2 - L3 x = coro_test.new("x") local co = coroutine.wrap( function() local t = coro_test.new("t") local co2 = coroutine.wrap( function() local t2 = { "SOME_TABLE" } t:copy_store(t2) -- t2 = [L3], t.obj = [L2] end ) co2() co2 = nil collectgarbage() -- t2 ref in t remains valid! x:store(t:get()) -- t.obj = [L2], x.obj = [L1] end ) co() co = nil collectgarbage() )"; sol::state lua; lua.open_libraries(sol::lib::coroutine, sol::lib::base); lua.new_usertype("coro_test", sol::constructors(), "store", &coro_test_implicit::store, "copy_store", &coro_test_implicit::copy_store, "get", &coro_test_implicit::get); auto r = lua.safe_script(code, sol::script_pass_on_error); REQUIRE(r.valid()); coro_test_implicit& ct = lua["x"]; lua_State* Lmain1 = lua.lua_state(); lua_State* Lmain2 = sol::main_thread(lua); lua_State* Lmain3 = ct.get().lua_state(); REQUIRE(Lmain1 == Lmain2); REQUIRE(Lmain1 == Lmain3); sol::table t = ct.get(); REQUIRE(t.size() == 1); std::string s = t[1]; REQUIRE(s == "SOME_TABLE"); } TEST_CASE("coroutines/coroutine.create protection", "ensure that a thread picked up from coroutine.create does not throw off the lua stack entirely when called from C++") { sol::state lua; lua.open_libraries(sol::lib::base, sol::lib::coroutine); auto code = R"( function loop() local i = 0 while true do print("pre-yield in loop") coroutine.yield(i) print("post-yield in loop") i = i+1 end end loop_th = coroutine.create(loop) )"; auto r = lua.safe_script(code, sol::script_pass_on_error); REQUIRE(r.valid()); sol::thread runner_thread = lua["loop_th"]; coroutine_thread_runner test_resume(runner_thread); lua.set_function("test_resume", std::ref(test_resume)); int v0 = test_resume(); int v1 = test_resume(); int v2, v3; { auto r2 = lua.safe_script("return test_resume()", sol::script_pass_on_error); REQUIRE(r2.valid()); auto r3 = lua.safe_script("return test_resume()", sol::script_pass_on_error); REQUIRE(r3.valid()); v2 = r2; v3 = r3; } REQUIRE(v0 == 0); REQUIRE(v1 == 1); REQUIRE(v2 == 2); REQUIRE(v3 == 3); } TEST_CASE("coroutines/stack-check", "check that resumed functions consume the entire execution stack") { sol::state lua; lua.open_libraries(sol::lib::base, sol::lib::table, sol::lib::coroutine); { auto code = R"( unpack = unpack or table.unpack function loop() local i = 0 while true do print("pre-yield in loop") coroutine.yield(i) print("post-yield in loop") i = i+1 end end loop_th = coroutine.create(loop) loop_res = function(...) returns = { coroutine.resume(loop_th, ...) } return unpack(returns, 2) end )"; auto result = lua.safe_script(code, sol::script_pass_on_error); REQUIRE(result.valid()); } // Resume from lua via thread and coroutine sol::thread runner_thread = lua["loop_th"]; sol::state_view runner_thread_state = runner_thread.state(); coroutine_thread_runner_state test_resume(runner_thread_state); lua.set_function("test_resume", std::ref(test_resume)); // Resume via getting a sol::function from the state sol::function test_resume_lua = lua["loop_res"]; // Resume via passing a sol::function object auto test_resume_func = [](sol::function f) { int r = f(); return r; }; lua.set_function("test_resume_func", std::ref(test_resume_func)); int v0 = test_resume(); int s0 = runner_thread_state.stack_top(); int v1 = test_resume(); int s1 = runner_thread_state.stack_top(); int v2; { auto result = lua.safe_script("return test_resume()", sol::script_pass_on_error); REQUIRE(result.valid()); v2 = result; } int s2 = runner_thread_state.stack_top(); int v3; { auto result = lua.safe_script("return test_resume()", sol::script_pass_on_error); REQUIRE(result.valid()); v3 = result; } int s3 = runner_thread_state.stack_top(); int v4 = test_resume_lua(); int s4 = runner_thread_state.stack_top(); int v5; { auto result = lua.safe_script("return test_resume_func(loop_res)", sol::script_pass_on_error); REQUIRE(result.valid()); v5 = result; } int s5 = runner_thread_state.stack_top(); REQUIRE(v0 == 0); REQUIRE(v1 == 1); REQUIRE(v2 == 2); REQUIRE(v3 == 3); REQUIRE(v4 == 4); REQUIRE(v5 == 5); REQUIRE(s0 == 0); REQUIRE(s1 == 0); REQUIRE(s2 == 0); REQUIRE(s3 == 0); REQUIRE(s4 == 0); REQUIRE(s5 == 0); } TEST_CASE("coroutines/yielding", "test that a sol3 bound function can yield when marked yieldable") { SECTION("regular functions") { sol::state lua; lua.open_libraries(sol::lib::base, sol::lib::coroutine); int i = 0; auto func = [&i]() { ++i; return i; }; coro_h hobj {}; lua["f"] = sol::yielding(func); lua["g"] = sol::yielding([]() { return 300; }); lua["h"] = sol::yielding(&coro_h::func); lua["hobj"] = &hobj; sol::string_view code = R"( co1 = coroutine.create(function () return f() end) success1, value1 = coroutine.resume(co1) co2 = coroutine.create(function () return g() end) success2, value2 = coroutine.resume(co2) co3 = coroutine.create(function() h(hobj) end) success3, value3 = coroutine.resume(co3) )"; auto result = lua.safe_script(code, sol::script_pass_on_error); REQUIRE(result.valid()); bool success1 = lua["success1"]; int value1 = lua["value1"]; REQUIRE(success1); REQUIRE(value1 == 1); bool success2 = lua["success2"]; int value2 = lua["value2"]; REQUIRE(success2); REQUIRE(value2 == 300); bool success3 = lua["success3"]; int value3 = lua["value3"]; REQUIRE(success3); REQUIRE(value3 == 501); REQUIRE(hobj.x == 501); } SECTION("usertypes") { sol::state lua; lua.open_libraries(sol::lib::base, sol::lib::coroutine); coro_h hobj; lua["hobj"] = &hobj; lua.new_usertype("coro_h", "h", sol::yielding(&coro_h::func)); sol::string_view code = R"( co4 = coroutine.create(function() hobj:h() hobj.h(hobj) coro_h.h(hobj) end) success4, value4 = coroutine.resume(co4) success5, value5 = coroutine.resume(co4) success6, value6 = coroutine.resume(co4) )"; auto result = lua.safe_script(code, sol::script_pass_on_error); REQUIRE(result.valid()); bool success4 = lua["success4"]; int value4 = lua["value4"]; REQUIRE(success4); REQUIRE(value4 == 501); bool success5 = lua["success5"]; int value5 = lua["value5"]; REQUIRE(success5); REQUIRE(value5 == 502); bool success6 = lua["success6"]; int value6 = lua["value6"]; REQUIRE(success6); REQUIRE(value6 == 503); REQUIRE(hobj.x == 503); } } TEST_CASE("coroutines/error_handler_state_transfer", "test that sol3 coroutines with their error handlers are properly sourced") { sol::state lua; lua.open_libraries(sol::lib::base, sol::lib::coroutine); sol::optional maybe_result = lua.safe_script(R"( aTable = {} aTable["Tick"] = function() coroutine.yield() end )"); REQUIRE_FALSE(maybe_result.has_value()); int begintop = 0, endtop = 0; { test_stack_guard g(lua.lua_state(), begintop, endtop); std::vector luaCoroutines; luaCoroutines.emplace_back(lua); luaCoroutines.emplace_back(lua); luaCoroutines[0] = std::move(luaCoroutines.back()); luaCoroutines.pop_back(); } }