diff --git a/tests.cpp b/tests.cpp index 89de8659..c6ac269a 100644 --- a/tests.cpp +++ b/tests.cpp @@ -50,6 +50,27 @@ int plop_xyz(int x, int y, std::string z) { return 11; } +class Base { +public: + Base(int a_num) : m_num(a_num) { } + + int get_num() { + return m_num; + } + +protected: + int m_num; +}; + +class Derived : public Base { +public: + Derived(int a_num) : Base(a_num) { } + + int get_num_10() { + return 10 * m_num; + } +}; + TEST_CASE("simple/set_global", "Check if the set_global works properly.") { sol::state lua; @@ -254,7 +275,7 @@ TEST_CASE("tables/functions_variables", "Check if tables and function calls work std::cout << "stateless lambda()" << std::endl; return "test"; } - ); + ); REQUIRE_NOTHROW(run_script(lua)); lua.get("os").set_function("fun", &free_function); @@ -272,7 +293,7 @@ TEST_CASE("tables/functions_variables", "Check if tables and function calls work std::cout << "stateless lambda()" << std::endl; return "test"; } - ); + ); REQUIRE_NOTHROW(run_script(lua)); // r-value, cannot optimise @@ -430,8 +451,8 @@ TEST_CASE("tables/userdata utility", "Show internal management of classes regist lua.new_userdata("fuser", "add", &fuser::add, "add2", &fuser::add2); lua.script("a = fuser.new()\n" - "b = a:add(1)\n" - "c = a:add2(1)\n"); + "b = a:add(1)\n" + "c = a:add2(1)\n"); sol::object a = lua.get("a"); sol::object b = lua.get("b"); @@ -448,3 +469,27 @@ TEST_CASE("tables/userdata utility", "Show internal management of classes regist REQUIRE(bresult == 1); REQUIRE(cresult == 3); } + +TEST_CASE("tables/userdata utility derived", "userdata classes must play nice when a derived class does not overload a publically visible base function") { + sol::state lua; + lua.open_libraries(sol::lib::base); + sol::constructors> basector; + sol::userdata baseuserdata("Base", basector, "get_num", &Base::get_num); + + lua.set_userdata(baseuserdata); + + lua.script("base = Base.new(5)"); + lua.script("print(base:get_num())"); + + sol::constructors> derivedctor; + sol::userdata deriveduserdata("Derived", derivedctor, "get_num", &Derived::get_num, "get_num_10", &Derived::get_num_10); + + lua.set_userdata(deriveduserdata); + + lua.script("derived = Derived.new(7)"); + lua.script("dgn10 = derived:get_num_10()\nprint(dgn10)"); + lua.script("dgn = derived:get_num()\nprint(dgn)"); + + REQUIRE((lua.get("dgn10") == 70)); + REQUIRE((lua.get("dgn") == 7)); +}