From 96f231a183410eb29bf8c4843c7bb88fd2cca2f0 Mon Sep 17 00:00:00 2001 From: ThePhD Date: Wed, 18 May 2016 21:29:17 -0400 Subject: [PATCH] We do have a cheap char type now, but figuring out that codecvt is super busted makes me sad Safety macros are now in place. Usertype ones turned on by default Closes #93 Closes #94 --- sol/compatibility/version.hpp | 15 ++++ sol/function_types_overload.hpp | 7 +- sol/function_types_usertype.hpp | 12 +++- sol/stack_check_get.hpp | 2 +- sol/stack_get.hpp | 83 +++++++++++++++++++++ sol/stack_push.hpp | 124 ++++++++++++++++++++++++++++++++ sol/types.hpp | 36 ++++++++++ test_strings.cpp | 71 ++++++++++++++++++ tests.cpp | 16 +++++ 9 files changed, 363 insertions(+), 3 deletions(-) create mode 100644 test_strings.cpp diff --git a/sol/compatibility/version.hpp b/sol/compatibility/version.hpp index db26f6e8..8d590413 100644 --- a/sol/compatibility/version.hpp +++ b/sol/compatibility/version.hpp @@ -44,10 +44,15 @@ #ifdef _MSC_VER #ifdef _DEBUG +#ifndef NDEBUG #ifndef SOL_CHECK_ARGUMENTS // Do not define by default: let user turn it on //#define SOL_CHECK_ARGUMENTS #endif // Check Arguments +#ifndef SOL_SAFE_USERTYPE +#define SOL_SAFE_USERTYPE +#endif // Safe Usertypes +#endif // NDEBUG #endif // Debug #ifndef _CPPUNWIND @@ -69,7 +74,11 @@ #ifndef SOL_CHECK_ARGUMENTS // Do not define by default: let user choose //#define SOL_CHECK_ARGUMENTS +// But do check userdata by default: #endif // Check Arguments +#ifndef SOL_SAFE_USERTYPE +#define SOL_SAFE_USERTYPE +#endif // Safe Usertypes #endif // g++ optimizer flag #endif // Not Debug @@ -87,4 +96,10 @@ #endif // vc++ || clang++/g++ +#ifndef SOL_SAFE_USERTYPE +#ifdef SOL_CHECK_ARGUMENTS +#define SOL_SAFE_USERTYPE +#endif // Turn on Safety for all +#endif // Safe Usertypes + #endif // SOL_VERSION_HPP diff --git a/sol/function_types_overload.hpp b/sol/function_types_overload.hpp index 655b79e3..ca8d77b1 100644 --- a/sol/function_types_overload.hpp +++ b/sol/function_types_overload.hpp @@ -119,7 +119,12 @@ struct usertype_overloaded_function : base_function { template > = 0> int call(types, index_value, types r, types a, lua_State* L, int, int start) { auto& func = std::get(overloads); - func.item = detail::ptr(stack::get(L, 1)); + func.item = stack::get*>(L, 1); +#ifdef SOL_SAFE_USERTYPE + if (func.item == nullptr) { + return luaL_error(L, "sol: received null for 'self' argument (use ':' for accessing member functions)"); + } +#endif // Safety return stack::call_into_lua<0, false>(r, a, L, start, func); } diff --git a/sol/function_types_usertype.hpp b/sol/function_types_usertype.hpp index 8c0d6f43..4845549f 100644 --- a/sol/function_types_usertype.hpp +++ b/sol/function_types_usertype.hpp @@ -60,7 +60,12 @@ struct usertype_function : public usertype_function_core { usertype_function(Args&&... args): base_t(std::forward(args)...) {} int prelude(lua_State* L) { - this->fx.item = detail::ptr(stack::get(L, 1)); + this->fx.item = stack::get*>(L, 1); +#ifdef SOL_SAFE_USERTYPE + if (this->fx.item == nullptr) { + return luaL_error(L, "sol: received null for 'self' argument (use ':' for accessing member functions)"); + } +#endif // Safety return static_cast(*this)(meta::tuple_types(), args_type(), index_value<2>(), L); } @@ -122,6 +127,11 @@ struct usertype_variable_function : public usertype_function_core int prelude(lua_State* L) { int argcount = lua_gettop(L); this->fx.item = stack::get(L, 1); +#ifdef SOL_SAFE_USERTYPE + if (this->fx.item == nullptr) { + return luaL_error(L, "sol: received null for 'self' argument (use ':' for accessing member functions)"); + } +#endif // Safety switch(argcount) { case 2: return get_variable(can_read(), L); diff --git a/sol/stack_check_get.hpp b/sol/stack_check_get.hpp index 30cae43e..a98e98d0 100644 --- a/sol/stack_check_get.hpp +++ b/sol/stack_check_get.hpp @@ -51,7 +51,7 @@ struct check_getter> { }; template -struct check_getter::value && !std::is_same::value>> { +struct check_getter::value && !lua_type_of::value == type::number>> { template static optional get( lua_State* L, int index, Handler&& handler) { int isnum = 0; diff --git a/sol/stack_get.hpp b/sol/stack_get.hpp index 10ee66c9..47a8d852 100644 --- a/sol/stack_get.hpp +++ b/sol/stack_get.hpp @@ -30,6 +30,7 @@ #include #include #include +#include namespace sol { namespace stack { @@ -113,6 +114,88 @@ struct getter { } }; +template<> +struct getter { + static char get(lua_State* L, int index = -1) { + size_t len; + auto str = lua_tolstring(L, index, &len); + return len > 0 ? str[0] : '\0'; + } +}; + +template<> +struct getter { + static std::wstring get(lua_State* L, int index = -1) { + size_t len; + auto str = lua_tolstring(L, index, &len); + typedef std::codecvt_utf8 convert; + std::wstring_convert conv; + return conv.from_bytes(str, str + len); + } +}; + +template<> +struct getter { + static std::u16string get(lua_State* L, int index = -1) { + size_t len; + auto str = lua_tolstring(L, index, &len); +#ifdef _MSC_VER // https://connect.microsoft.com/VisualStudio/feedback/details/1348277/link-error-when-using-std-codecvt-utf8-utf16-char16-t + typedef uint16_t T; + typedef std::codecvt_utf8 convert; + std::wstring_convert conv; + std::basic_string shitty = conv.from_bytes(str, str + len); + return std::u16string(shitty.cbegin(), shitty.cend()); // fuck you VC++ +#else + typedef std::codecvt_utf8 convert; + std::wstring_convert conv; + return conv.from_bytes(str, str + len); +#endif // VC++ + } +}; + +template<> +struct getter { + static std::u32string get(lua_State* L, int index = -1) { + size_t len; + auto str = lua_tolstring(L, index, &len); +#ifdef _MSC_VER // https://connect.microsoft.com/VisualStudio/feedback/details/1348277/link-error-when-using-std-codecvt-utf8-utf16-char16-t + typedef int32_t T; + typedef std::codecvt_utf8 convert; + std::wstring_convert conv; + std::basic_string shitty = conv.from_bytes(str, str + len); + return std::u32string(shitty.cbegin(), shitty.cend()); // fuck you VC++ +#else + typedef std::codecvt_utf8 convert; + std::wstring_convert conv; + return conv.from_bytes(str, str + len); +#endif // VC++ + } +}; + +template<> +struct getter { + static wchar_t get(lua_State* L, int index = -1) { + auto str = getter{}.get(L, index); + return str.size() > 0 ? str[0] : '\0'; + } +}; + +template<> +struct getter { + static char get(lua_State* L, int index = -1) { + auto str = getter{}.get(L, index); + return str.size() > 0 ? str[0] : '\0'; + } +}; + +template<> +struct getter { + static char32_t get(lua_State* L, int index = -1) { + auto str = getter{}.get(L, index); + return str.size() > 0 ? str[0] : '\0'; + } +}; + template<> struct getter { static nil_t get(lua_State*, int = -1) { diff --git a/sol/stack_push.hpp b/sol/stack_push.hpp index 9fcb6cca..2536f285 100644 --- a/sol/stack_push.hpp +++ b/sol/stack_push.hpp @@ -26,6 +26,7 @@ #include "raii.hpp" #include "optional.hpp" #include +#include namespace sol { namespace stack { @@ -252,6 +253,55 @@ struct pusher { } }; +template<> +struct pusher { + static int push(lua_State* L, const wchar_t* wstr) { + return push(L, wstr, wstr + std::char_traits::length(wstr)); + } + static int push(lua_State* L, const wchar_t* wstrb, const wchar_t* wstre) { + typedef std::codecvt_utf8 convert; + std::wstring_convert conv; + std::string str = conv.to_bytes( wstrb, wstre ); + return stack::push(L, str); + } +}; + +template<> +struct pusher { + static int push(lua_State* L, const char16_t* u16str) { + return push(L, u16str, u16str + std::char_traits::length(u16str)); + } + static int push(lua_State* L, const char16_t* u16strb, const char16_t* u16stre) { +#ifdef _MSC_VER // https://connect.microsoft.com/VisualStudio/feedback/details/1348277/link-error-when-using-std-codecvt-utf8-utf16-char16-t + typedef uint16_t T; +#else + typedef char16_t T; +#endif // VC++ + typedef std::codecvt_utf8 convert; + std::wstring_convert conv; + std::string str = conv.to_bytes( reinterpret_cast(u16strb), reinterpret_cast(u16stre) ); + return stack::push(L, str); + } +}; + +template<> +struct pusher { + static int push(lua_State* L, const char32_t* u32str) { + return push(L, u32str, u32str + std::char_traits::length(u32str)); + } + static int push(lua_State* L, const char32_t* u32strb, const char32_t* u32stre) { +#ifdef _MSC_VER // https://connect.microsoft.com/VisualStudio/feedback/details/1348277/link-error-when-using-std-codecvt-utf8-utf16-char16-t + typedef uint32_t T; +#else + typedef char32_t T; +#endif // VC++ + typedef std::codecvt_utf8 convert; + std::wstring_convert conv; + std::string str = conv.to_bytes( reinterpret_cast(u32strb), reinterpret_cast(u32stre) ); + return stack::push(L, str); + } +}; + template struct pusher { static int push(lua_State* L, const char (&str)[N]) { @@ -260,6 +310,59 @@ struct pusher { } }; +template +struct pusher { + static int push(lua_State* L, const wchar_t (&str)[N]) { + return stack::push(L, str, str + N - 1); + } +}; + +template +struct pusher { + static int push(lua_State* L, const char16_t (&str)[N]) { + return stack::push(L, str, str + N - 1); + } +}; + +template +struct pusher { + static int push(lua_State* L, const char32_t (&str)[N]) { + return stack::push(L, str, str + N - 1); + } +}; + +template <> +struct pusher { + static int push(lua_State* L, char c) { + const char str[2] = { c, '\0'}; + return stack::push(L, str); + } +}; + +template <> +struct pusher { + static int push(lua_State* L, wchar_t c) { + const wchar_t str[2] = { c, '\0'}; + return stack::push(L, str); + } +}; + +template <> +struct pusher { + static int push(lua_State* L, char16_t c) { + const char16_t str[2] = { c, '\0'}; + return stack::push(L, str); + } +}; + +template <> +struct pusher { + static int push(lua_State* L, char32_t c) { + const char32_t str[2] = { c, '\0'}; + return stack::push(L, str); + } +}; + template<> struct pusher { static int push(lua_State* L, const std::string& str) { @@ -268,6 +371,27 @@ struct pusher { } }; +template<> +struct pusher { + static int push(lua_State* L, const std::wstring& wstr) { + return stack::push(L, wstr.data(), wstr.data() + wstr.size()); + } +}; + +template<> +struct pusher { + static int push(lua_State* L, const std::u16string& u16str) { + return stack::push(L, u16str.data(), u16str.data() + u16str.size()); + } +}; + +template<> +struct pusher { + static int push(lua_State* L, const std::u32string& u32str) { + return stack::push(L, u32str.data(), u32str.data() + u32str.size()); + } +}; + template struct pusher> { template diff --git a/sol/types.hpp b/sol/types.hpp index ccda2d58..94fd2409 100644 --- a/sol/types.hpp +++ b/sol/types.hpp @@ -379,12 +379,48 @@ struct lua_type_of : std::integral_constant {}; template <> struct lua_type_of : std::integral_constant {}; +template <> +struct lua_type_of : std::integral_constant {}; + +template <> +struct lua_type_of : std::integral_constant {}; + +template <> +struct lua_type_of : std::integral_constant {}; + template struct lua_type_of : std::integral_constant {}; +template +struct lua_type_of : std::integral_constant {}; + +template +struct lua_type_of : std::integral_constant {}; + +template +struct lua_type_of : std::integral_constant {}; + +template <> +struct lua_type_of : std::integral_constant {}; + +template <> +struct lua_type_of : std::integral_constant {}; + +template <> +struct lua_type_of : std::integral_constant {}; + +template <> +struct lua_type_of : std::integral_constant {}; + template <> struct lua_type_of : std::integral_constant {}; +template <> +struct lua_type_of : std::integral_constant {}; + +template <> +struct lua_type_of : std::integral_constant {}; + template <> struct lua_type_of : std::integral_constant {}; diff --git a/test_strings.cpp b/test_strings.cpp new file mode 100644 index 00000000..4241aa9e --- /dev/null +++ b/test_strings.cpp @@ -0,0 +1,71 @@ +#define SOL_CHECK_ARGUMENTS + +#include +#include + +// There isn't a single library roundtripping with codecvt works on. We'll do the nitty-gritty of it later... +#if 0 +TEST_CASE("stack/strings", "test that strings can be roundtripped") { + sol::state lua; + + static const char utf8str[] = "\xF0\x9F\x8D\x8C\x20\xE6\x99\xA5\x20\x46\x6F\x6F\x20\xC2\xA9\x20\x62\x61\x72\x20\xF0\x9D\x8C\x86\x20\x62\x61\x7A\x20\xE2\x98\x83\x20\x71\x75\x78"; + static const char16_t utf16str[] = { 0xD83C, 0xDF4C, 0x20, 0x6665, 0x20, 0x46, 0x6F, 0x6F, 0x20, 0xA9, 0x20, 0x62, 0x61, 0x72, 0x20, 0xD834, 0xDF06, 0x20, 0x62, 0x61, 0x7A, 0x20, 0x2603, 0x20, 0x71, 0x75, 0x78, 0x00 }; + static const char32_t utf32str[] = { 0x1F34C, 0x0020, 0x6665, 0x0020, 0x0046, 0x006F, 0x006F, 0x0020, 0x00A9, 0x0020, 0x0062, 0x0061, 0x0072, 0x0020, 0x1D306, 0x0020, 0x0062, 0x0061, 0x007A, 0x0020, 0x2603, 0x0020, 0x0071, 0x0075, 0x0078, 0x00 }; + static const wchar_t widestr[] = L"Fuck these shitty compilers"; + + lua["utf8"] = utf8str; + lua["utf16"] = utf16str; + lua["utf32"] = utf32str; + lua["wide"] = widestr; + + std::string utf8_to_utf8 = lua["utf8"]; + std::string utf16_to_utf8 = lua["utf16"]; + std::string utf32_to_utf8 = lua["utf32"]; + std::string wide_to_utf8 = lua["wide"]; + + std::wstring utf8_to_wide = lua["utf8"]; + std::wstring utf16_to_wide = lua["utf16"]; + std::wstring utf32_to_wide = lua["utf32"]; + std::wstring wide_to_wide = lua["wide"]; + + std::u16string utf8_to_utf16 = lua["utf8"]; + std::u16string utf16_to_utf16 = lua["utf16"]; + std::u16string utf32_to_utf16 = lua["utf32"]; + std::u16string wide_to_utf16 = lua["wide"]; + + std::u32string utf8_to_utf32 = lua["utf8"]; + std::u32string utf16_to_utf32 = lua["utf16"]; + std::u32string utf32_to_utf32 = lua["utf32"]; + std::u32string wide_to_utf32 = lua["wide"]; + + REQUIRE(utf8_to_utf8 == utf8str); + REQUIRE(utf16_to_utf8 == utf8str); + REQUIRE(utf32_to_utf8 == utf8str); + REQUIRE(wide_to_utf8 == utf8str); + + REQUIRE(utf8_to_utf16 == utf16str); + REQUIRE(utf16_to_utf16 == utf16str); + REQUIRE(utf32_to_utf16 == utf16str); + REQUIRE(wide_to_utf16 == utf16str); + + REQUIRE(utf8_to_utf32 == utf32str); + REQUIRE(utf16_to_utf32 == utf32str); + REQUIRE(utf32_to_utf32 == utf32str); + REQUIRE(wide_to_utf32 == utf32str); + + REQUIRE(utf8_to_wide == widestr); + REQUIRE(utf16_to_wide == widestr); + REQUIRE(utf32_to_wide == widestr); + REQUIRE(wide_to_wide == widestr); + + char32_t utf8_to_char32 = lua["utf8"]; + char32_t utf16_to_char32 = lua["utf16"]; + char32_t utf32_to_char32 = lua["utf32"]; + char32_t wide_to_char32 = lua["wide"]; + + REQUIRE(utf8_to_char32 == utf32str[0]); + REQUIRE(utf16_to_char32 == utf32str[0]); + REQUIRE(utf32_to_char32 == utf32str[0]); + REQUIRE(wide_to_char32 == utf32str[0]); +} +#endif // Shit C++ diff --git a/tests.cpp b/tests.cpp index 588ceec2..12a477d4 100644 --- a/tests.cpp +++ b/tests.cpp @@ -1186,3 +1186,19 @@ TEST_CASE("object/conversions", "make sure all basic reference types can be made REQUIRE(os.get_type() == sol::type::string); REQUIRE(omn.get_type() == sol::type::nil); } + +TEST_CASE("usertype/safety", "crash with an exception -- not a segfault -- on bad userdata calls") { + class Test { + public: + void sayHello() { std::cout << "Hey\n"; } + }; + + sol::state lua; + lua.new_usertype("Test", "sayHello", &Test::sayHello); + static const std::string code = R"( + local t = Test.new() + t:sayHello() --Works fine + t.sayHello() --Uh oh. + )"; + REQUIRE_THROWS(lua.script(code)); +}