add find method to containers to make fast lookups into set/unordered_set

This commit is contained in:
Carlos Carrasco 2017-05-03 22:31:43 +02:00 committed by The Phantom Derpstorm
parent fc3e7c40f3
commit e13711ed84
3 changed files with 56 additions and 3 deletions

View File

@ -76,7 +76,7 @@ Based on the type pushed, a few additional functions are added as "member functi
* ``my_container:clear()``: This will call the underlying containers ``clear`` function.
* ``my_container:add( key, value )`` or ``my_container:add( value )``: this will add to the end of the container, or if it is an associative or ordered container, simply put in an expected key-value pair into it.
* ``my_contaner:insert( where, value )`` or ``my_contaner:insert( key, value )``: similar to add, but it only takes two arguments. In the case of ``std::vector`` and the like, the first argument is a ``where`` integer index. The second argument is the value. For associative containers, a key and value argument are expected.
* ``my_container:find( value )``: This will call the underlying containers ``find`` function if it exists, or in case of associative containers, it will work just like an index call. This is meant to give a fast membership check for ``std::set`` and ``std::unordered_set`` containers.
.. _container-detection:

View File

@ -34,7 +34,7 @@ namespace sol {
typedef std::array<char, 1> one;
typedef std::array<char, 2> two;
template <typename C> static one test(decltype(&C::find));
template <typename C> static one test(decltype(std::declval<C>().find(std::declval<std::add_rvalue_reference_t<typename C::value_type>>()))*);
template <typename C> static two test(...);
public:
@ -159,6 +159,9 @@ namespace sol {
else if (name == "clear") {
return stack::push(L, &clear_call);
}
else if (name == "find") {
return stack::push(L, &find_call);
}
}
}
return stack::push(L, lua_nil);
@ -191,6 +194,9 @@ namespace sol {
else if (name == "clear") {
return stack::push(L, &clear_call);
}
else if (name == "find") {
return stack::push(L, &find_call);
}
}
}
@ -392,6 +398,36 @@ namespace sol {
return real_clear_call_capable(std::integral_constant<bool, detail::has_clear<T>::value>(), L);
}
static int real_find_call_capable(std::false_type, std::false_type, lua_State*L) {
static const std::string& s = detail::demangle<T>();
return luaL_error(L, "sol: cannot call find on type %s", s.c_str());
}
static int real_find_call_capable(std::false_type, std::true_type, lua_State*L) {
return real_index_call(L);
}
static int real_find_call_capable(std::true_type, std::false_type, lua_State* L) {
auto k = stack::check_get<V>(L, 2);
if (k) {
auto& src = get_src(L);
auto it = src.find(*k);
if (it != src.end()) {
auto& v = *it;
return stack::push_reference(L, v);
}
}
return stack::push(L, lua_nil);
}
static int real_find_call_capable(std::true_type, std::true_type, lua_State* L) {
return real_index_call(L);
}
static int real_find_call(lua_State*L) {
return real_find_call_capable(std::integral_constant<bool, detail::has_find<T>::value>(), is_associative(), L);
}
static int add_call(lua_State*L) {
return detail::static_trampoline<(&real_add_call)>(L);
}
@ -404,6 +440,10 @@ namespace sol {
return detail::static_trampoline<(&real_clear_call)>(L);
}
static int find_call(lua_State*L) {
return detail::static_trampoline<(&real_find_call)>(L);
}
static int length_call(lua_State*L) {
return detail::static_trampoline<(&real_length_call)>(L);
}
@ -430,7 +470,7 @@ namespace sol {
template <typename T>
inline auto container_metatable() {
typedef container_usertype_metatable<std::remove_pointer_t<T>> meta_cumt;
std::array<luaL_Reg, 10> reg = { {
std::array<luaL_Reg, 11> reg = { {
{ "__index", &meta_cumt::index_call },
{ "__newindex", &meta_cumt::new_index_call },
{ "__pairs", &meta_cumt::pairs_call },
@ -439,6 +479,7 @@ namespace sol {
{ "clear", &meta_cumt::clear_call },
{ "insert", &meta_cumt::insert_call },
{ "add", &meta_cumt::add_call },
{ "find", &meta_cumt::find_call },
std::is_pointer<T>::value ? luaL_Reg{ nullptr, nullptr } : luaL_Reg{ "__gc", &detail::usertype_alloc_destroy<T> },
{ nullptr, nullptr }
} };

View File

@ -289,6 +289,11 @@ end
function i (x)
x:clear()
end
function sf (x,v)
return x:find(v)
end
)");
// Have the function we
@ -296,6 +301,7 @@ end
sol::function g = lua["g"];
sol::function h = lua["h"];
sol::function i = lua["i"];
sol::function sf = lua["sf"];
// Set a global variable called
// "arr" to be a vector of 5 lements
@ -316,6 +322,12 @@ end
REQUIRE(map.size() == 6);
REQUIRE(set.size() == 6);
int r = sf(set, 8);
REQUIRE(r == 8);
sol::object rn = sf(set, 9);
REQUIRE(rn == sol::nil);
i(lua["arr"]);
i(lua["map"]);
i(lua["set"]);