diff --git a/docs/source/api/containers.rst b/docs/source/api/containers.rst index 7110d443..36821a34 100644 --- a/docs/source/api/containers.rst +++ b/docs/source/api/containers.rst @@ -76,7 +76,7 @@ Based on the type pushed, a few additional functions are added as "member functi * ``my_container:clear()``: This will call the underlying containers ``clear`` function. * ``my_container:add( key, value )`` or ``my_container:add( value )``: this will add to the end of the container, or if it is an associative or ordered container, simply put in an expected key-value pair into it. * ``my_contaner:insert( where, value )`` or ``my_contaner:insert( key, value )``: similar to add, but it only takes two arguments. In the case of ``std::vector`` and the like, the first argument is a ``where`` integer index. The second argument is the value. For associative containers, a key and value argument are expected. - +* ``my_container:find( value )``: This will call the underlying containers ``find`` function if it exists, or in case of associative containers, it will work just like an index call. This is meant to give a fast membership check for ``std::set`` and ``std::unordered_set`` containers. .. _container-detection: diff --git a/sol/container_usertype_metatable.hpp b/sol/container_usertype_metatable.hpp index 80f4d8ce..8aec6e36 100644 --- a/sol/container_usertype_metatable.hpp +++ b/sol/container_usertype_metatable.hpp @@ -34,7 +34,7 @@ namespace sol { typedef std::array one; typedef std::array two; - template static one test(decltype(&C::find)); + template static one test(decltype(std::declval().find(std::declval>()))*); template static two test(...); public: @@ -159,6 +159,9 @@ namespace sol { else if (name == "clear") { return stack::push(L, &clear_call); } + else if (name == "find") { + return stack::push(L, &find_call); + } } } return stack::push(L, lua_nil); @@ -191,6 +194,9 @@ namespace sol { else if (name == "clear") { return stack::push(L, &clear_call); } + else if (name == "find") { + return stack::push(L, &find_call); + } } } @@ -392,6 +398,36 @@ namespace sol { return real_clear_call_capable(std::integral_constant::value>(), L); } + static int real_find_call_capable(std::false_type, std::false_type, lua_State*L) { + static const std::string& s = detail::demangle(); + return luaL_error(L, "sol: cannot call find on type %s", s.c_str()); + } + + static int real_find_call_capable(std::false_type, std::true_type, lua_State*L) { + return real_index_call(L); + } + + static int real_find_call_capable(std::true_type, std::false_type, lua_State* L) { + auto k = stack::check_get(L, 2); + if (k) { + auto& src = get_src(L); + auto it = src.find(*k); + if (it != src.end()) { + auto& v = *it; + return stack::push_reference(L, v); + } + } + return stack::push(L, lua_nil); + } + + static int real_find_call_capable(std::true_type, std::true_type, lua_State* L) { + return real_index_call(L); + } + + static int real_find_call(lua_State*L) { + return real_find_call_capable(std::integral_constant::value>(), is_associative(), L); + } + static int add_call(lua_State*L) { return detail::static_trampoline<(&real_add_call)>(L); } @@ -404,6 +440,10 @@ namespace sol { return detail::static_trampoline<(&real_clear_call)>(L); } + static int find_call(lua_State*L) { + return detail::static_trampoline<(&real_find_call)>(L); + } + static int length_call(lua_State*L) { return detail::static_trampoline<(&real_length_call)>(L); } @@ -430,7 +470,7 @@ namespace sol { template inline auto container_metatable() { typedef container_usertype_metatable> meta_cumt; - std::array reg = { { + std::array reg = { { { "__index", &meta_cumt::index_call }, { "__newindex", &meta_cumt::new_index_call }, { "__pairs", &meta_cumt::pairs_call }, @@ -439,6 +479,7 @@ namespace sol { { "clear", &meta_cumt::clear_call }, { "insert", &meta_cumt::insert_call }, { "add", &meta_cumt::add_call }, + { "find", &meta_cumt::find_call }, std::is_pointer::value ? luaL_Reg{ nullptr, nullptr } : luaL_Reg{ "__gc", &detail::usertype_alloc_destroy }, { nullptr, nullptr } } }; diff --git a/test_containers.cpp b/test_containers.cpp index 6ba36641..4bddbc7f 100644 --- a/test_containers.cpp +++ b/test_containers.cpp @@ -289,6 +289,11 @@ end function i (x) x:clear() end + +function sf (x,v) + return x:find(v) +end + )"); // Have the function we @@ -296,6 +301,7 @@ end sol::function g = lua["g"]; sol::function h = lua["h"]; sol::function i = lua["i"]; + sol::function sf = lua["sf"]; // Set a global variable called // "arr" to be a vector of 5 lements @@ -316,6 +322,12 @@ end REQUIRE(map.size() == 6); REQUIRE(set.size() == 6); + int r = sf(set, 8); + REQUIRE(r == 8); + + sol::object rn = sf(set, 9); + REQUIRE(rn == sol::nil); + i(lua["arr"]); i(lua["map"]); i(lua["set"]);