From 683d1393d7af32302e42444cea0af8c3dbc557d3 Mon Sep 17 00:00:00 2001 From: ThePhD Date: Tue, 27 Aug 2019 18:36:45 -0400 Subject: [PATCH] std function can be empty fixes #862 --- include/sol/function_types.hpp | 16 +- single/include/sol/forward.hpp | 4 +- single/include/sol/sol.hpp | 20 +- tests/runtime_tests/source/functions.cpp | 328 ++++--------------- tests/runtime_tests/source/functions.std.cpp | 229 +++++++++++++ 5 files changed, 323 insertions(+), 274 deletions(-) create mode 100644 tests/runtime_tests/source/functions.std.cpp diff --git a/include/sol/function_types.hpp b/include/sol/function_types.hpp index cad99097..763153ba 100644 --- a/include/sol/function_types.hpp +++ b/include/sol/function_types.hpp @@ -129,7 +129,7 @@ namespace sol { dFx memfxptr(std::forward(fx)); auto userptr = detail::ptr(std::forward(args)...); lua_CFunction freefunc - = &function_detail::upvalue_member_variable, meta::unqualified_t, is_yielding>::call; + = &function_detail::upvalue_member_variable, meta::unqualified_t, is_yielding>::call; int upvalues = 0; upvalues += stack::push(L, nullptr); @@ -296,13 +296,19 @@ namespace sol { template struct unqualified_pusher> { static int push(lua_State* L, const std::function& fx) { - function_detail::select(L, fx); - return 1; + if (fx) { + function_detail::select(L, fx); + return 1; + } + return stack::push(L, lua_nil); } static int push(lua_State* L, std::function&& fx) { - function_detail::select(L, std::move(fx)); - return 1; + if (fx) { + function_detail::select(L, std::move(fx)); + return 1; + } + return stack::push(L, lua_nil); } }; diff --git a/single/include/sol/forward.hpp b/single/include/sol/forward.hpp index c7fa132e..67d4d9c2 100644 --- a/single/include/sol/forward.hpp +++ b/single/include/sol/forward.hpp @@ -20,8 +20,8 @@ // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. // This file was generated with a script. -// Generated 2019-08-15 12:13:47.368988 UTC -// This header was generated with sol v3.0.3 (revision c3c08df) +// Generated 2019-08-27 22:31:46.530255 UTC +// This header was generated with sol v3.0.3 (revision 242990a) // https://github.com/ThePhD/sol2 #ifndef SOL_SINGLE_INCLUDE_FORWARD_HPP diff --git a/single/include/sol/sol.hpp b/single/include/sol/sol.hpp index 8c9402d9..526889fb 100644 --- a/single/include/sol/sol.hpp +++ b/single/include/sol/sol.hpp @@ -20,8 +20,8 @@ // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. // This file was generated with a script. -// Generated 2019-08-15 12:13:46.408546 UTC -// This header was generated with sol v3.0.3 (revision c3c08df) +// Generated 2019-08-27 22:31:43.964254 UTC +// This header was generated with sol v3.0.3 (revision 242990a) // https://github.com/ThePhD/sol2 #ifndef SOL_SINGLE_INCLUDE_HPP @@ -17736,7 +17736,7 @@ namespace sol { dFx memfxptr(std::forward(fx)); auto userptr = detail::ptr(std::forward(args)...); lua_CFunction freefunc - = &function_detail::upvalue_member_variable, meta::unqualified_t, is_yielding>::call; + = &function_detail::upvalue_member_variable, meta::unqualified_t, is_yielding>::call; int upvalues = 0; upvalues += stack::push(L, nullptr); @@ -17903,13 +17903,19 @@ namespace sol { template struct unqualified_pusher> { static int push(lua_State* L, const std::function& fx) { - function_detail::select(L, fx); - return 1; + if (fx) { + function_detail::select(L, fx); + return 1; + } + return stack::push(L, lua_nil); } static int push(lua_State* L, std::function&& fx) { - function_detail::select(L, std::move(fx)); - return 1; + if (fx) { + function_detail::select(L, std::move(fx)); + return 1; + } + return stack::push(L, lua_nil); } }; diff --git a/tests/runtime_tests/source/functions.cpp b/tests/runtime_tests/source/functions.cpp index d3fa61a8..c68bbd83 100644 --- a/tests/runtime_tests/source/functions.cpp +++ b/tests/runtime_tests/source/functions.cpp @@ -1,4 +1,4 @@ -// sol3 +// sol3 // The MIT License (MIT) @@ -41,18 +41,6 @@ T va_func(sol::variadic_args va, T first) { return s; } -std::function makefn() { - auto fx = []() -> int { - return 0x1456789; - }; - return fx; -} - -void takefn(std::function purr) { - if (purr() != 0x1456789) - throw 0; -} - struct A { int a = 0xA; int bark() { @@ -64,15 +52,6 @@ std::tuple bark(int num_value, A* a) { return std::tuple(num_value * 2, a->bark()); } -void test_free_func(std::function f) { - f(); -} - -void test_free_func2(std::function f, int arg1) { - int val = f(arg1); - REQUIRE(val == arg1); -} - int overloaded(int x) { INFO(x); return 3; @@ -191,11 +170,9 @@ TEST_CASE("functions/return order and multi get", "Check if return order is in t const static std::tuple paired = std::make_tuple(10, 10.f); sol::state lua; sol::stack_guard luasg(lua); - + lua.set_function("f", [] { return std::make_tuple(10, 11, 12); }); - lua.set_function("h", []() { - return std::make_tuple(10, 10.0f); - }); + lua.set_function("h", []() { return std::make_tuple(10, 10.0f); }); auto result1 = lua.safe_script("function g() return 10, 11, 12 end\nx,y,z = g()", sol::script_pass_on_error); REQUIRE(result1.valid()); @@ -221,9 +198,7 @@ TEST_CASE("functions/deducing return order and multi get", "Check if return orde f_string_result = f_string(); REQUIRE(f_string_result == "this is a string!"); - lua.set_function("f", [] { - return std::make_tuple(10, 11, 12); - }); + lua.set_function("f", [] { return std::make_tuple(10, 11, 12); }); auto result1 = lua.safe_script("function g() return 10, 11, 12 end\nx,y,z = g()", sol::script_pass_on_error); REQUIRE(result1.valid()); @@ -245,7 +220,8 @@ TEST_CASE("functions/optional values", "check if optionals can be passed in to b sol::state lua; auto result1 = lua.safe_script(R"( function f (a) return a -end )", sol::script_pass_on_error); +end )", + sol::script_pass_on_error); REQUIRE(result1.valid()); sol::function lua_bark = lua["f"]; @@ -256,14 +232,17 @@ end )", sol::script_pass_on_error); REQUIRE_FALSE((bool)testn); REQUIRE(testv.value() == 29); sol::optional v = lua_bark(sol::optional(thing{ 29 })); - REQUIRE_NOTHROW([&] {sol::lua_nil_t n = lua_bark(sol::nullopt); return n; }()); + REQUIRE_NOTHROW([&] { + sol::lua_nil_t n = lua_bark(sol::nullopt); + return n; + }()); REQUIRE(v->v == 29); } TEST_CASE("functions/pair and tuple and proxy tests", "Check if sol::reference and sol::proxy can be passed to functions as arguments") { sol::state lua; sol::stack_guard luasg(lua); - + lua.new_usertype("A", "bark", &A::bark); auto result1 = lua.safe_script(R"( function f (num_value, a) return num_value * 2, a:bark() @@ -271,7 +250,8 @@ end function h (num_value, a, b) return num_value * 2, a:bark(), b * 3 end -nested = { variables = { no = { problem = 10 } } } )", sol::script_pass_on_error); +nested = { variables = { no = { problem = 10 } } } )", + sol::script_pass_on_error); REQUIRE(result1.valid()); lua.set_function("g", bark); @@ -303,44 +283,8 @@ nested = { variables = { no = { problem = 10 } } } )", sol::script_pass_on_error REQUIRE(abc == abcdesired); } -TEST_CASE("functions/sol::function to std::function", "check if conversion to std::function works properly and calls with correct arguments") { - sol::state lua; - sol::stack_guard luasg(lua); - - lua.open_libraries(sol::lib::base); - - lua.set_function("testFunc", test_free_func); - lua.set_function("testFunc2", test_free_func2); - auto result1 = lua.safe_script("testFunc(function() print(\"hello std::function\") end)", sol::script_pass_on_error); - REQUIRE(result1.valid()); - { - auto result = lua.safe_script( - "function m(a)\n" - " print(\"hello std::function with arg \", a)\n" - " return a\n" - "end\n" - "\n" - "testFunc2(m, 1)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } -} - -TEST_CASE("functions/returning functions from C++", "check to see if returning a functor and getting a functor from lua is possible") { - sol::state lua; - lua.open_libraries(sol::lib::base); - - lua.set_function("makefn", makefn); - lua.set_function("takefn", takefn); - { - auto result = lua.safe_script( - "afx = makefn()\n" - "print(afx())\n" - "takefn(afx)\n", sol::script_pass_on_error); - REQUIRE(result.valid()); - } -} - -TEST_CASE("functions/function_result and protected_function_result", "Function result should be the beefy return type for sol::function that allows for error checking and error handlers") { +TEST_CASE("functions/function_result and protected_function_result", + "Function result should be the beefy return type for sol::function that allows for error checking and error handlers") { static const char unhandlederrormessage[] = "true error message"; static const char handlederrormessage[] = "doodle"; static const std::string handlederrormessage_s = handlederrormessage; @@ -350,9 +294,7 @@ TEST_CASE("functions/function_result and protected_function_result", "Function r lua.open_libraries(sol::lib::base, sol::lib::debug); // Some function; just using a lambda to be cheap - auto doomfx = []() { - throw std::runtime_error(unhandlederrormessage); - }; + auto doomfx = []() { throw std::runtime_error(unhandlederrormessage); }; lua.set_function("doom", doomfx); auto cpphandlerfx = [](std::string x) { @@ -360,20 +302,16 @@ TEST_CASE("functions/function_result and protected_function_result", "Function r return handlederrormessage; }; lua.set_function("cpphandler", cpphandlerfx); - - auto result1 = lua.safe_script( - std::string("function luahandler ( message )") - + " print('lua handler called with: ' .. message)" - + " return '" + handlederrormessage + "'" - + "end", sol::script_pass_on_error); + + auto result1 = lua.safe_script(std::string("function luahandler ( message )") + " print('lua handler called with: ' .. message)" + " return '" + + handlederrormessage + "'" + "end", + sol::script_pass_on_error); REQUIRE(result1.valid()); - - auto nontrampolinefx = [](lua_State* L) -> int { - return luaL_error(L, "x"); - }; + + auto nontrampolinefx = [](lua_State* L) -> int { return luaL_error(L, "x"); }; lua_CFunction c_nontrampolinefx = nontrampolinefx; lua.set("nontrampoline", c_nontrampolinefx); - + lua.set_function("bark", []() -> int { return 100; }); sol::function luahandler = lua["luahandler"]; @@ -435,8 +373,9 @@ TEST_CASE("functions/function_result and protected_function_result", "Function r } } -#if !defined(SOL2_CI) && !(SOL2_CI) && ((!defined(_M_IX86) || defined(_M_IA64)) || (defined(_WIN64)) || (defined(__LLP64__) || defined(__LP64__)) ) -TEST_CASE("functions/safe protected_function_result handlers", "These tests will (hopefully) not destroy the stack since they are supposed to be mildly safe. Still, run with caution.") { +#if !defined(SOL2_CI) && !(SOL2_CI) && ((!defined(_M_IX86) || defined(_M_IA64)) || (defined(_WIN64)) || (defined(__LLP64__) || defined(__LP64__))) +TEST_CASE("functions/safe protected_function_result handlers", + "These tests will (hopefully) not destroy the stack since they are supposed to be mildly safe. Still, run with caution.") { sol::state lua; lua.open_libraries(sol::lib::base, sol::lib::debug); static const char unhandlederrormessage[] = "true error message"; @@ -567,10 +506,8 @@ TEST_CASE("functions/all kinds", "Register all kinds of functions, make sure the auto c = [&]() { return 502; }; auto d = []() { return 503; }; - lua.new_usertype("test_1", - "bark", sol::c_call); - lua.new_usertype("test_2", - "bark", sol::c_call); + lua.new_usertype("test_1", "bark", sol::c_call); + lua.new_usertype("test_2", "bark", sol::c_call); test_2 t2; lua.set_function("a", a); @@ -590,7 +527,8 @@ TEST_CASE("functions/all kinds", "Register all kinds of functions, make sure the auto result1 = lua.safe_script(R"( o1 = test_1.new() o2 = test_2.new() -)", sol::script_pass_on_error); +)", + sol::script_pass_on_error); REQUIRE(result1.valid()); auto result2 = lua.safe_script(R"( @@ -605,46 +543,52 @@ G0, G1 = g(2, o1) H = h(o1) I = i(o1) I = i(o1) -)", sol::script_pass_on_error); +)", + sol::script_pass_on_error); REQUIRE(result2.valid()); auto result3 = lua.safe_script(R"( J0 = j() j(24) J1 = j() - )", sol::script_pass_on_error); + )", + sol::script_pass_on_error); REQUIRE(result3.valid()); auto result4 = lua.safe_script(R"( K0 = k(o2) k(o2, 1024) K1 = k(o2) - )", sol::script_pass_on_error); + )", + sol::script_pass_on_error); REQUIRE(result4.valid()); auto result5 = lua.safe_script(R"( L0 = l(o1) l(o1, 678) L1 = l(o1) - )", sol::script_pass_on_error); + )", + sol::script_pass_on_error); REQUIRE(result5.valid()); auto result6 = lua.safe_script(R"( M0 = m() m(256) M1 = m() - )", sol::script_pass_on_error); + )", + sol::script_pass_on_error); REQUIRE(result6.valid()); auto result7 = lua.safe_script(R"( N = n(1, 2, 3) - )", sol::script_pass_on_error); + )", + sol::script_pass_on_error); REQUIRE(result7.valid()); int ob, A, B, C, D, F, G0, G1, H, I, J0, J1, K0, K1, L0, L1, M0, M1, N; std::tie(ob, A, B, C, D, F, G0, G1, H, I, J0, J1, K0, K1, L0, L1, M0, M1, N) - = lua.get( - "ob", "A", "B", "C", "D", "F", "G0", "G1", "H", "I", "J0", "J1", "K0", "K1", "L0", "L1", "M0", "M1", "N"); + = lua.get( + "ob", "A", "B", "C", "D", "F", "G0", "G1", "H", "I", "J0", "J1", "K0", "K1", "L0", "L1", "M0", "M1", "N"); REQUIRE(ob == 0xA); @@ -674,8 +618,8 @@ N = n(1, 2, 3) REQUIRE(N == 13); sol::tie(ob, A, B, C, D, F, G0, G1, H, I, J0, J1, K0, K1, L0, L1, M0, M1, N) - = lua.get( - "ob", "A", "B", "C", "D", "F", "G0", "G1", "H", "I", "J0", "J1", "K0", "K1", "L0", "L1", "M0", "M1", "N"); + = lua.get( + "ob", "A", "B", "C", "D", "F", "G0", "G1", "H", "I", "J0", "J1", "K0", "K1", "L0", "L1", "M0", "M1", "N"); REQUIRE(ob == 0xA); @@ -756,9 +700,7 @@ TEST_CASE("simple/call with parameters", "Lua function is called with a few para REQUIRE(result.valid()); } auto fvoid = lua.get("my_nothing"); - REQUIRE_NOTHROW([&]() { - fvoid(1, 2, 3); - }()); + REQUIRE_NOTHROW([&]() { fvoid(1, 2, 3); }()); REQUIRE_NOTHROW([&]() { int a = f.call(1, 2, 3); REQUIRE(a == 6); @@ -908,7 +850,8 @@ TEST_CASE("functions/tie", "make sure advanced syntax with 'tie' works") { auto result1 = lua.safe_script(R"(function f () return 1, 2, 3 -end)", sol::script_pass_on_error); +end)", + sol::script_pass_on_error); REQUIRE(result1.valid()); sol::function f = lua["f"]; @@ -930,10 +873,11 @@ TEST_CASE("functions/overloading", "Check if overloading works properly for regu { auto result = lua.safe_script( - "a = func(1)\n" - "b = func('bark')\n" - "c = func(1,2)\n" - "func(1,2,3)\n", sol::script_pass_on_error); + "a = func(1)\n" + "b = func('bark')\n" + "c = func(1,2)\n" + "func(1,2,3)\n", + sol::script_pass_on_error); REQUIRE(result.valid()); } @@ -974,7 +918,7 @@ TEST_CASE("overloading/c_call", "Make sure that overloading works with c_call fu } TEST_CASE("functions/stack atomic", "make sure functions don't impede on the stack") { - //setup sol/lua + // setup sol/lua sol::state lua; lua.open_libraries(sol::lib::base, sol::lib::string); @@ -988,7 +932,7 @@ TEST_CASE("functions/stack atomic", "make sure functions don't impede on the sta INFO("Back in C++, direct call result is : " << str); } - //test protected_function + // test protected_function sol::protected_function Stringtest(lua["stringtest"]); Stringtest.error_handler = lua["ErrorHandler"]; sol::stack_guard luasg(lua); @@ -1000,7 +944,7 @@ TEST_CASE("functions/stack atomic", "make sure functions don't impede on the sta } REQUIRE(luasg.check_stack()); - //test optional + // test optional { sol::stack_guard opsg(lua); sol::optional opt_result = Stringtest("optional test"); @@ -1091,9 +1035,7 @@ TEST_CASE("functions/function_result as arguments", "ensure that function_result REQUIRE(c == 3); REQUIRE(d == 4); REQUIRE(e == 5); - REQUIRE_NOTHROW([&]() { - lua["g"](pf()); - }()); + REQUIRE_NOTHROW([&]() { lua["g"](pf()); }()); } double f = sol::stack::pop(lua); REQUIRE(f == 256.78); @@ -1120,9 +1062,7 @@ TEST_CASE("functions/protected_function_result as arguments", "ensure that prote REQUIRE(c == 3); REQUIRE(d == 4); REQUIRE(e == 5); - REQUIRE_NOTHROW([&]() { - lua["g"](pf()); - }()); + REQUIRE_NOTHROW([&]() { lua["g"](pf()); }()); } double f = sol::stack::pop(lua); REQUIRE(f == 256.78); @@ -1171,127 +1111,6 @@ TEST_CASE("functions/sectioning variadic", "make sure variadics can bite off chu lua.safe_script("print(x3) assert(x3 == 18)"); } -TEST_CASE("functions/set_function already wrapped", "setting a function returned from Lua code that is already wrapped into a sol::function or similar") { - SECTION("test different types") { - sol::state lua; - lua.open_libraries(sol::lib::base); - sol::function fn = lua.safe_script("return function() return 5 end"); - sol::protected_function pfn = fn; - std::function sfn = fn; - - lua.set_function("test", fn); - lua.set_function("test2", pfn); - lua.set_function("test3", sfn); - - { - auto result = lua.safe_script("assert(type(test) == 'function')", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("assert(test() ~= nil)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("assert(test() == 5)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - - { - auto result = lua.safe_script("assert(type(test2) == 'function')", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("assert(test2() ~= nil)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("assert(test2() == 5)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - - { - auto result = lua.safe_script("assert(type(test3) == 'function')", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("assert(test3() ~= nil)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("assert(test3() == 5)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - } - - SECTION("getting the value from C++") { - sol::state lua; - lua.open_libraries(sol::lib::base); - sol::function fn = lua.safe_script("return function() return 5 end"); - - int result = fn(); - REQUIRE(result == 5); - } - - SECTION("setting the function directly") { - sol::state lua; - lua.open_libraries(sol::lib::base); - sol::function fn = lua.safe_script("return function() return 5 end"); - - lua.set_function("test", fn); - - { - auto result = lua.safe_script("assert(type(test) == 'function')", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("assert(test() ~= nil)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("assert(test() == 5)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - } - - SECTION("does the function actually get executed?") { - sol::state lua; - lua.open_libraries(sol::lib::base); - - sol::function fn2 = lua.safe_script("return function() print('this was executed') end"); - lua.set_function("test", fn2); - - { - auto result = lua.safe_script("assert(type(test) == 'function')", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("test()", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - } - - SECTION("setting the function indirectly, with the return value cast explicitly") { - sol::state lua; - lua.open_libraries(sol::lib::base); - sol::function fn = lua.safe_script("return function() return 5 end"); - - lua.set_function("test", [&fn]() { return fn.call(); }); - - { - auto result = lua.safe_script("assert(type(test) == 'function')", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("assert(test() ~= nil)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - { - auto result = lua.safe_script("assert(test() == 5)", sol::script_pass_on_error); - REQUIRE(result.valid()); - } - } -} - TEST_CASE("functions/pointer nullptr + nil", "ensure specific semantics for handling pointer-nils passed through sol") { struct nil_test { @@ -1469,29 +1288,17 @@ TEST_CASE("functions/pointer nullptr + nil", "ensure specific semantics for hand TEST_CASE("functions/unique_usertype overloading", "make sure overloading can work with ptr vs. specifically asking for a unique_usertype") { struct test { int special_value = 17; - test() - : special_value(17) { + test() : special_value(17) { } - test(int special_value) - : special_value(special_value) { + test(int special_value) : special_value(special_value) { } }; - auto print_up_test = [](std::unique_ptr& x) { - REQUIRE(x->special_value == 21); - }; - auto print_up_2_test = [](int, std::unique_ptr& x) { - REQUIRE(x->special_value == 21); - }; - auto print_sp_test = [](std::shared_ptr& x) { - REQUIRE(x->special_value == 44); - }; - auto print_ptr_test = [](test* x) { - REQUIRE(x->special_value == 17); - }; + auto print_up_test = [](std::unique_ptr& x) { REQUIRE(x->special_value == 21); }; + auto print_up_2_test = [](int, std::unique_ptr& x) { REQUIRE(x->special_value == 21); }; + auto print_sp_test = [](std::shared_ptr& x) { REQUIRE(x->special_value == 44); }; + auto print_ptr_test = [](test* x) { REQUIRE(x->special_value == 17); }; auto print_ref_test = [](test& x) { - bool is_any = x.special_value == 17 - || x.special_value == 21 - || x.special_value == 44; + bool is_any = x.special_value == 17 || x.special_value == 21 || x.special_value == 44; REQUIRE(is_any); }; using f_t = void(test&); @@ -1615,7 +1422,8 @@ TEST_CASE("functions/lua style default arguments", "allow default arguments usin v1d, v1nd = f1(), f1(1) v2d, v2nd = f2(), f2(1) v3d, v3nd = f3(), f3(1) - )", sol::script_pass_on_error); + )", + sol::script_pass_on_error); REQUIRE(result.valid()); int v1d = lua["v1d"]; int v1nd = lua["v1nd"]; diff --git a/tests/runtime_tests/source/functions.std.cpp b/tests/runtime_tests/source/functions.std.cpp new file mode 100644 index 00000000..b5edf4aa --- /dev/null +++ b/tests/runtime_tests/source/functions.std.cpp @@ -0,0 +1,229 @@ +// sol3 + +// The MIT License (MIT) + +// Copyright (c) 2013-2019 Rapptz, ThePhD and contributors + +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#include "sol_test.hpp" + +#include + +#include + +void test_free_func(std::function f) { + f(); +} + +void test_free_func2(std::function f, int arg1) { + int val = f(arg1); + REQUIRE(val == arg1); +} + +std::function makefn() { + auto fx = []() -> int { return 0x1456789; }; + return fx; +} + +void takefn(std::function purr) { + if (purr() != 0x1456789) + throw 0; +} + +TEST_CASE("functions/empty std functions", "std::function is allowed to be empty, so it should be serialized to nil") { + sol::state lua; + std::function foo = nullptr; + sol::function bar; + + lua["Foo"] = foo; + lua["Bar"] = bar; + + sol::optional result = lua.script(R"SCR( + if Bar ~= nil + then + Bar() + end + + if Foo ~= nil or type(Foo) ~= 'function' + then + Foo() + end + )SCR"); + REQUIRE_FALSE(result.has_value()); +} + +TEST_CASE("functions/sol::function to std::function", "check if conversion to std::function works properly and calls with correct arguments") { + sol::state lua; + sol::stack_guard luasg(lua); + + lua.open_libraries(sol::lib::base); + + lua.set_function("testFunc", test_free_func); + lua.set_function("testFunc2", test_free_func2); + auto result1 = lua.safe_script("testFunc(function() print(\"hello std::function\") end)", sol::script_pass_on_error); + REQUIRE(result1.valid()); + { + auto result = lua.safe_script( + "function m(a)\n" + " print(\"hello std::function with arg \", a)\n" + " return a\n" + "end\n" + "\n" + "testFunc2(m, 1)", + sol::script_pass_on_error); + REQUIRE(result.valid()); + } +} + +TEST_CASE("functions/returning functions from C++", "check to see if returning a functor and getting a functor from lua is possible") { + sol::state lua; + lua.open_libraries(sol::lib::base); + + lua.set_function("makefn", makefn); + lua.set_function("takefn", takefn); + { + auto result = lua.safe_script( + "afx = makefn()\n" + "print(afx())\n" + "takefn(afx)\n", + sol::script_pass_on_error); + REQUIRE(result.valid()); + } +} + +TEST_CASE("functions/set_function already wrapped", "setting a function returned from Lua code that is already wrapped into a sol::function or similar") { + SECTION("test different types") { + sol::state lua; + lua.open_libraries(sol::lib::base); + sol::function fn = lua.safe_script("return function() return 5 end"); + sol::protected_function pfn = fn; + std::function sfn = fn; + + lua.set_function("test", fn); + lua.set_function("test2", pfn); + lua.set_function("test3", sfn); + + { + auto result = lua.safe_script("assert(type(test) == 'function')", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("assert(test() ~= nil)", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("assert(test() == 5)", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + + { + auto result = lua.safe_script("assert(type(test2) == 'function')", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("assert(test2() ~= nil)", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("assert(test2() == 5)", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + + { + auto result = lua.safe_script("assert(type(test3) == 'function')", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("assert(test3() ~= nil)", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("assert(test3() == 5)", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + } + + SECTION("getting the value from C++") { + sol::state lua; + lua.open_libraries(sol::lib::base); + sol::function fn = lua.safe_script("return function() return 5 end"); + + int result = fn(); + REQUIRE(result == 5); + } + + SECTION("setting the function directly") { + sol::state lua; + lua.open_libraries(sol::lib::base); + sol::function fn = lua.safe_script("return function() return 5 end"); + + lua.set_function("test", fn); + + { + auto result = lua.safe_script("assert(type(test) == 'function')", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("assert(test() ~= nil)", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("assert(test() == 5)", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + } + + SECTION("does the function actually get executed?") { + sol::state lua; + lua.open_libraries(sol::lib::base); + + sol::function fn2 = lua.safe_script("return function() print('this was executed') end"); + lua.set_function("test", fn2); + + { + auto result = lua.safe_script("assert(type(test) == 'function')", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("test()", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + } + + SECTION("setting the function indirectly, with the return value cast explicitly") { + sol::state lua; + lua.open_libraries(sol::lib::base); + sol::function fn = lua.safe_script("return function() return 5 end"); + + lua.set_function("test", [&fn]() { return fn.call(); }); + + { + auto result = lua.safe_script("assert(type(test) == 'function')", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("assert(test() ~= nil)", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + { + auto result = lua.safe_script("assert(test() == 5)", sol::script_pass_on_error); + REQUIRE(result.valid()); + } + } +}