diff --git a/irr/include/vector3d.h b/irr/include/vector3d.h index fd788e734..562efb2d6 100644 --- a/irr/include/vector3d.h +++ b/irr/include/vector3d.h @@ -7,6 +7,7 @@ #include "irrMath.h" #include +#include namespace irr { @@ -32,6 +33,9 @@ public: //! Constructor with the same value for all elements explicit constexpr vector3d(T n) : X(n), Y(n), Z(n) {} + //! Array - vector conversion + constexpr vector3d(const std::array &arr) : + X(arr[0]), Y(arr[1]), Z(arr[2]) {} template constexpr static vector3d from(const vector3d &other) @@ -187,6 +191,8 @@ public: return *this; } + std::array toArray() const { return {X, Y, Z}; } + //! Get length of the vector. T getLength() const { return core::squareroot(X * X + Y * Y + Z * Z); } diff --git a/src/activeobjectmgr.h b/src/activeobjectmgr.h index a9b007018..952d812ef 100644 --- a/src/activeobjectmgr.h +++ b/src/activeobjectmgr.h @@ -39,7 +39,7 @@ public: for (auto &it : m_active_objects.iter()) { if (!it.second) continue; - m_active_objects.remove(it.first); + removeObject(it.first); } } while (!m_active_objects.empty()); } diff --git a/src/benchmark/benchmark_activeobjectmgr.cpp b/src/benchmark/benchmark_activeobjectmgr.cpp index d9036c632..23a712af6 100644 --- a/src/benchmark/benchmark_activeobjectmgr.cpp +++ b/src/benchmark/benchmark_activeobjectmgr.cpp @@ -105,7 +105,11 @@ void benchGetObjectsInArea(Catch::Benchmark::Chronometer &meter) TEST_CASE("ActiveObjectMgr") { BENCH_INSIDE_RADIUS(200) BENCH_INSIDE_RADIUS(1450) + BENCH_INSIDE_RADIUS(10000) BENCH_IN_AREA(200) BENCH_IN_AREA(1450) + BENCH_IN_AREA(10000) } + +// TODO benchmark active object manager update costs diff --git a/src/collision.cpp b/src/collision.cpp index 8f9cc788c..53d97553b 100644 --- a/src/collision.cpp +++ b/src/collision.cpp @@ -4,10 +4,12 @@ #include "collision.h" #include +#include "irr_aabb3d.h" #include "mapblock.h" #include "map.h" #include "nodedef.h" #include "gamedef.h" +#include "util/numeric.h" #if CHECK_CLIENT_BUILD() #include "client/clientenvironment.h" #include "client/localplayer.h" @@ -311,13 +313,14 @@ static void add_object_boxes(Environment *env, } }; - // Calculate distance by speed, add own extent and 1.5m of tolerance - const f32 distance = speed_f.getLength() * dtime + - box_0.getExtent().getLength() + 1.5f * BS; + constexpr f32 tolerance = 1.5f * BS; #if CHECK_CLIENT_BUILD() ClientEnvironment *c_env = dynamic_cast(env); if (c_env) { + // Calculate distance by speed, add own extent and tolerance + const f32 distance = speed_f.getLength() * dtime + + box_0.getExtent().getLength() + tolerance; std::vector clientobjects; c_env->getActiveObjects(pos_f, distance, clientobjects); @@ -356,9 +359,14 @@ static void add_object_boxes(Environment *env, return false; }; + // Calculate distance by speed, add own extent and tolerance + const v3f movement = speed_f * dtime; + const v3f min = pos_f + box_0.MinEdge - v3f(tolerance) + componentwise_min(movement, v3f()); + const v3f max = pos_f + box_0.MaxEdge + v3f(tolerance) + componentwise_max(movement, v3f()); + // nothing is put into this vector std::vector s_objects; - s_env->getObjectsInsideRadius(s_objects, pos_f, distance, include_obj_cb); + s_env->getObjectsInArea(s_objects, aabb3f(min, max), include_obj_cb); } } } diff --git a/src/server/activeobjectmgr.cpp b/src/server/activeobjectmgr.cpp index 155cf50fb..452017786 100644 --- a/src/server/activeobjectmgr.cpp +++ b/src/server/activeobjectmgr.cpp @@ -26,7 +26,7 @@ void ActiveObjectMgr::clearIf(const std::function obj) return false; } - if (objectpos_over_limit(obj->getBasePosition())) { - v3f p = obj->getBasePosition(); + const v3f pos = obj->getBasePosition(); + if (objectpos_over_limit(pos)) { warningstream << "Server::ActiveObjectMgr::addActiveObjectRaw(): " - << "object position (" << p.X << "," << p.Y << "," << p.Z + << "object position (" << pos.X << "," << pos.Y << "," << pos.Z << ") outside maximum range" << std::endl; return false; } auto obj_id = obj->getId(); m_active_objects.put(obj_id, std::move(obj)); + m_spatial_index.insert(pos.toArray(), obj_id); auto new_size = m_active_objects.size(); verbosestream << "Server::ActiveObjectMgr::addActiveObjectRaw(): " @@ -100,6 +101,8 @@ void ActiveObjectMgr::removeObject(u16 id) if (!ok) { infostream << "Server::ActiveObjectMgr::removeObject(): " << "id=" << id << " not found" << std::endl; + } else { + m_spatial_index.remove(id); } } @@ -113,43 +116,47 @@ void ActiveObjectMgr::invalidateActiveObjectObserverCaches() } } -void ActiveObjectMgr::getObjectsInsideRadius(const v3f &pos, float radius, +void ActiveObjectMgr::updateObjectPos(u16 id, v3f pos) +{ + // HACK defensively only update if we already know the object, + // otherwise we're still waiting to be inserted into the index + // (or have already been removed). + if (m_active_objects.get(id)) + m_spatial_index.update(pos.toArray(), id); +} + +void ActiveObjectMgr::getObjectsInsideRadius(v3f pos, float radius, std::vector &result, std::function include_obj_cb) { - float r2 = radius * radius; - for (auto &activeObject : m_active_objects.iter()) { - ServerActiveObject *obj = activeObject.second.get(); - if (!obj) - continue; - const v3f &objectpos = obj->getBasePosition(); - if (objectpos.getDistanceFromSQ(pos) > r2) - continue; + float r_squared = radius * radius; + m_spatial_index.rangeQuery((pos - v3f(radius)).toArray(), (pos + v3f(radius)).toArray(), [&](auto objPos, u16 id) { + if (v3f(objPos).getDistanceFromSQ(pos) > r_squared) + return; + auto obj = m_active_objects.get(id).get(); + if (!obj) + return; if (!include_obj_cb || include_obj_cb(obj)) result.push_back(obj); - } + }); } void ActiveObjectMgr::getObjectsInArea(const aabb3f &box, std::vector &result, std::function include_obj_cb) { - for (auto &activeObject : m_active_objects.iter()) { - ServerActiveObject *obj = activeObject.second.get(); + m_spatial_index.rangeQuery(box.MinEdge.toArray(), box.MaxEdge.toArray(), [&](auto _, u16 id) { + auto obj = m_active_objects.get(id).get(); if (!obj) - continue; - const v3f &objectpos = obj->getBasePosition(); - if (!box.isPointInside(objectpos)) - continue; - + return; if (!include_obj_cb || include_obj_cb(obj)) result.push_back(obj); - } + }); } void ActiveObjectMgr::getAddedActiveObjectsAroundPos( - const v3f &player_pos, const std::string &player_name, + v3f player_pos, const std::string &player_name, f32 radius, f32 player_radius, const std::set ¤t_objects, std::vector &added_objects) diff --git a/src/server/activeobjectmgr.h b/src/server/activeobjectmgr.h index 854a75b18..9c65ad514 100644 --- a/src/server/activeobjectmgr.h +++ b/src/server/activeobjectmgr.h @@ -8,6 +8,7 @@ #include #include "../activeobjectmgr.h" #include "serveractiveobject.h" +#include "util/k_d_tree.h" namespace server { @@ -25,16 +26,21 @@ public: void invalidateActiveObjectObserverCaches(); - void getObjectsInsideRadius(const v3f &pos, float radius, + void updateObjectPos(u16 id, v3f pos); + + void getObjectsInsideRadius(v3f pos, float radius, std::vector &result, std::function include_obj_cb); void getObjectsInArea(const aabb3f &box, std::vector &result, std::function include_obj_cb); void getAddedActiveObjectsAroundPos( - const v3f &player_pos, const std::string &player_name, + v3f player_pos, const std::string &player_name, f32 radius, f32 player_radius, const std::set ¤t_objects, std::vector &added_objects); + +private: + k_d_tree::DynamicKdTrees<3, f32, u16> m_spatial_index; }; } // namespace server diff --git a/src/server/luaentity_sao.cpp b/src/server/luaentity_sao.cpp index 5de0167d6..0ad3daba6 100644 --- a/src/server/luaentity_sao.cpp +++ b/src/server/luaentity_sao.cpp @@ -147,7 +147,7 @@ void LuaEntitySAO::step(float dtime, bool send_recommended) // Each frame, parent position is copied if the object is attached, otherwise it's calculated normally // If the object gets detached this comes into effect automatically from the last known origin if (auto *parent = getParent()) { - m_base_position = parent->getBasePosition(); + setBasePosition(parent->getBasePosition()); m_velocity = v3f(0,0,0); m_acceleration = v3f(0,0,0); } else { @@ -155,7 +155,7 @@ void LuaEntitySAO::step(float dtime, bool send_recommended) aabb3f box = m_prop.collisionbox; box.MinEdge *= BS; box.MaxEdge *= BS; - v3f p_pos = m_base_position; + v3f p_pos = getBasePosition(); v3f p_velocity = m_velocity; v3f p_acceleration = m_acceleration; moveresult = collisionMoveSimple(m_env, m_env->getGameDef(), @@ -165,11 +165,11 @@ void LuaEntitySAO::step(float dtime, bool send_recommended) moveresult_p = &moveresult; // Apply results - m_base_position = p_pos; + setBasePosition(p_pos); m_velocity = p_velocity; m_acceleration = p_acceleration; } else { - m_base_position += (m_velocity + m_acceleration * 0.5f * dtime) * dtime; + addPos((m_velocity + m_acceleration * 0.5f * dtime) * dtime); m_velocity += dtime * m_acceleration; } @@ -212,7 +212,7 @@ void LuaEntitySAO::step(float dtime, bool send_recommended) } else if(m_last_sent_position_timer > 0.2){ minchange = 0.05*BS; } - float move_d = m_base_position.getDistanceFrom(m_last_sent_position); + float move_d = getBasePosition().getDistanceFrom(m_last_sent_position); move_d += m_last_sent_move_precision; float vel_d = m_velocity.getDistanceFrom(m_last_sent_velocity); if (move_d > minchange || vel_d > minchange || @@ -236,7 +236,7 @@ std::string LuaEntitySAO::getClientInitializationData(u16 protocol_version) os << serializeString16(m_init_name); // name writeU8(os, 0); // is_player writeU16(os, getId()); //id - writeV3F32(os, m_base_position); + writeV3F32(os, getBasePosition()); writeV3F32(os, m_rotation); writeU16(os, m_hp); @@ -365,7 +365,7 @@ void LuaEntitySAO::setPos(const v3f &pos) { if(isAttached()) return; - m_base_position = pos; + setBasePosition(pos); sendPosition(false, true); } @@ -373,7 +373,7 @@ void LuaEntitySAO::moveTo(v3f pos, bool continuous) { if(isAttached()) return; - m_base_position = pos; + setBasePosition(pos); if(!continuous) sendPosition(true, true); } @@ -387,7 +387,7 @@ std::string LuaEntitySAO::getDescription() { std::ostringstream oss; oss << "LuaEntitySAO \"" << m_init_name << "\" "; - auto pos = floatToInt(m_base_position, BS); + auto pos = floatToInt(getBasePosition(), BS); oss << "at " << pos; return oss.str(); } @@ -503,10 +503,10 @@ void LuaEntitySAO::sendPosition(bool do_interpolate, bool is_movement_end) // Send attachment updates instantly to the client prior updating position sendOutdatedData(); - m_last_sent_move_precision = m_base_position.getDistanceFrom( + m_last_sent_move_precision = getBasePosition().getDistanceFrom( m_last_sent_position); m_last_sent_position_timer = 0; - m_last_sent_position = m_base_position; + m_last_sent_position = getBasePosition(); m_last_sent_velocity = m_velocity; //m_last_sent_acceleration = m_acceleration; m_last_sent_rotation = m_rotation; @@ -514,7 +514,7 @@ void LuaEntitySAO::sendPosition(bool do_interpolate, bool is_movement_end) float update_interval = m_env->getSendRecommendedInterval(); std::string str = generateUpdatePositionCommand( - m_base_position, + getBasePosition(), m_velocity, m_acceleration, m_rotation, @@ -534,8 +534,8 @@ bool LuaEntitySAO::getCollisionBox(aabb3f *toset) const toset->MinEdge = m_prop.collisionbox.MinEdge * BS; toset->MaxEdge = m_prop.collisionbox.MaxEdge * BS; - toset->MinEdge += m_base_position; - toset->MaxEdge += m_base_position; + toset->MinEdge += getBasePosition(); + toset->MaxEdge += getBasePosition(); return true; } diff --git a/src/server/player_sao.cpp b/src/server/player_sao.cpp index 11fc15597..068b2b29f 100644 --- a/src/server/player_sao.cpp +++ b/src/server/player_sao.cpp @@ -70,11 +70,10 @@ std::string PlayerSAO::getDescription() void PlayerSAO::addedToEnvironment(u32 dtime_s) { ServerActiveObject::addedToEnvironment(dtime_s); - ServerActiveObject::setBasePosition(m_base_position); m_player->setPlayerSAO(this); m_player->setPeerId(m_peer_id_initial); m_peer_id_initial = PEER_ID_INEXISTENT; // don't try to use it again. - m_last_good_position = m_base_position; + m_last_good_position = getBasePosition(); } // Called before removing from environment @@ -100,7 +99,7 @@ std::string PlayerSAO::getClientInitializationData(u16 protocol_version) os << serializeString16(m_player->getName()); // name writeU8(os, 1); // is_player writeS16(os, getId()); // id - writeV3F32(os, m_base_position); + writeV3F32(os, getBasePosition()); writeV3F32(os, m_rotation); writeU16(os, getHP()); @@ -184,7 +183,7 @@ void PlayerSAO::step(float dtime, bool send_recommended) // Sequence of damage points, starting 0.1 above feet and progressing // upwards in 1 node intervals, stopping below top damage point. for (float dam_height = 0.1f; dam_height < dam_top; dam_height++) { - v3s16 p = floatToInt(m_base_position + + v3s16 p = floatToInt(getBasePosition() + v3f(0.0f, dam_height * BS, 0.0f), BS); MapNode n = m_env->getMap().getNode(p); const ContentFeatures &c = m_env->getGameDef()->ndef()->get(n); @@ -196,7 +195,7 @@ void PlayerSAO::step(float dtime, bool send_recommended) } // Top damage point - v3s16 ptop = floatToInt(m_base_position + + v3s16 ptop = floatToInt(getBasePosition() + v3f(0.0f, dam_top * BS, 0.0f), BS); MapNode ntop = m_env->getMap().getNode(ptop); const ContentFeatures &c = m_env->getGameDef()->ndef()->get(ntop); @@ -273,7 +272,7 @@ void PlayerSAO::step(float dtime, bool send_recommended) if (isAttached()) pos = m_last_good_position; else - pos = m_base_position; + pos = getBasePosition(); std::string str = generateUpdatePositionCommand( pos, @@ -332,7 +331,7 @@ std::string PlayerSAO::generateUpdatePhysicsOverrideCommand() const void PlayerSAO::setBasePosition(v3f position) { - if (m_player && position != m_base_position) + if (m_player && position != getBasePosition()) m_player->setDirty(true); // This needs to be ran for attachments too @@ -636,7 +635,7 @@ bool PlayerSAO::checkMovementCheat() if (m_is_singleplayer || isAttached() || !(anticheat_flags & AC_MOVEMENT)) { - m_last_good_position = m_base_position; + m_last_good_position = getBasePosition(); return false; } @@ -701,7 +700,7 @@ bool PlayerSAO::checkMovementCheat() if (player_max_jump < 0.0001f) player_max_jump = 0.0001f; - v3f diff = (m_base_position - m_last_good_position); + v3f diff = (getBasePosition() - m_last_good_position); float d_vert = diff.Y; diff.Y = 0; float d_horiz = diff.getLength(); @@ -722,7 +721,7 @@ bool PlayerSAO::checkMovementCheat() required_time /= anticheat_movement_tolerance; if (m_move_pool.grab(required_time)) { - m_last_good_position = m_base_position; + m_last_good_position = getBasePosition(); } else { const float LAG_POOL_MIN = 5.0; float lag_pool_max = m_env->getMaxLagEstimate() * 2.0; @@ -744,8 +743,8 @@ bool PlayerSAO::getCollisionBox(aabb3f *toset) const toset->MinEdge = m_prop.collisionbox.MinEdge * BS; toset->MaxEdge = m_prop.collisionbox.MaxEdge * BS; - toset->MinEdge += m_base_position; - toset->MaxEdge += m_base_position; + toset->MinEdge += getBasePosition(); + toset->MaxEdge += getBasePosition(); return true; } diff --git a/src/server/player_sao.h b/src/server/player_sao.h index 0ce26f7cc..a19177a7e 100644 --- a/src/server/player_sao.h +++ b/src/server/player_sao.h @@ -170,7 +170,7 @@ public: void finalize(RemotePlayer *player, const std::set &privs); - v3f getEyePosition() const { return m_base_position + getEyeOffset(); } + v3f getEyePosition() const { return getBasePosition() + getEyeOffset(); } v3f getEyeOffset() const; float getZoomFOV() const; diff --git a/src/server/serveractiveobject.cpp b/src/server/serveractiveobject.cpp index 913c402ed..fa0c76a70 100644 --- a/src/server/serveractiveobject.cpp +++ b/src/server/serveractiveobject.cpp @@ -6,6 +6,7 @@ #include "inventory.h" #include "inventorymanager.h" #include "constants.h" // BS +#include "serverenvironment.h" ServerActiveObject::ServerActiveObject(ServerEnvironment *env, v3f pos): ActiveObject(0), @@ -14,6 +15,17 @@ ServerActiveObject::ServerActiveObject(ServerEnvironment *env, v3f pos): { } +void ServerActiveObject::setBasePosition(v3f pos) +{ + bool changed = m_base_position != pos; + m_base_position = pos; + if (changed && getEnv()) { + // getEnv() should never be null if the object is in an environment. + // It may however be null e.g. in tests or database migrations. + getEnv()->updateObjectPos(getId(), pos); + } +} + float ServerActiveObject::getMinimumSavedMovement() { return 2.0*BS; diff --git a/src/server/serveractiveobject.h b/src/server/serveractiveobject.h index 8b60bc5f8..da3dc17bd 100644 --- a/src/server/serveractiveobject.h +++ b/src/server/serveractiveobject.h @@ -63,7 +63,7 @@ public: Some simple getters/setters */ v3f getBasePosition() const { return m_base_position; } - void setBasePosition(v3f pos){ m_base_position = pos; } + void setBasePosition(v3f pos); ServerEnvironment* getEnv(){ return m_env; } /* @@ -245,7 +245,6 @@ protected: virtual void onMarkedForRemoval() {} ServerEnvironment *m_env; - v3f m_base_position; std::unordered_set m_attached_particle_spawners; /* @@ -273,4 +272,7 @@ protected: Queue of messages to be sent to the client */ std::queue m_messages_out; + +private: + v3f m_base_position; // setBasePosition updates index and MUST be called }; diff --git a/src/serverenvironment.cpp b/src/serverenvironment.cpp index 697b7b073..55306ee59 100644 --- a/src/serverenvironment.cpp +++ b/src/serverenvironment.cpp @@ -6,6 +6,7 @@ #include #include #include "serverenvironment.h" +#include "irr_aabb3d.h" #include "settings.h" #include "log.h" #include "mapblock.h" @@ -1399,10 +1400,14 @@ void ServerEnvironment::getSelectedActiveObjects( return false; }; + aabb3f search_area(shootline_on_map.start, shootline_on_map.end); + search_area.repair(); + search_area.MinEdge -= 5 * BS; + search_area.MaxEdge += 5 * BS; + // Use "logic in callback" pattern to avoid useless vector filling std::vector tmp; - getObjectsInsideRadius(tmp, shootline_on_map.getMiddle(), - 0.5 * shootline_on_map.getLength() + 5 * BS, process); + getObjectsInArea(tmp, search_area, process); } /* diff --git a/src/serverenvironment.h b/src/serverenvironment.h index c7396987a..04153e944 100644 --- a/src/serverenvironment.h +++ b/src/serverenvironment.h @@ -220,6 +220,11 @@ public: // Find the daylight value at pos with a Depth First Search u8 findSunlight(v3s16 pos) const; + void updateObjectPos(u16 id, v3f pos) + { + return m_ao_manager.updateObjectPos(id, pos); + } + // Find all active objects inside a radius around a point void getObjectsInsideRadius(std::vector &objects, const v3f &pos, float radius, std::function include_obj_cb) diff --git a/src/unittest/CMakeLists.txt b/src/unittest/CMakeLists.txt index 7417819bd..9ac275d7f 100644 --- a/src/unittest/CMakeLists.txt +++ b/src/unittest/CMakeLists.txt @@ -10,6 +10,7 @@ set (UNITTEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_connection.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_craft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_datastructures.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_k_d_tree.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_filesys.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_inventory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_irrptr.cpp diff --git a/src/unittest/test_k_d_tree.cpp b/src/unittest/test_k_d_tree.cpp new file mode 100644 index 000000000..9dbe0b545 --- /dev/null +++ b/src/unittest/test_k_d_tree.cpp @@ -0,0 +1,138 @@ +// Copyright (C) 2024 Lars Müller +// SPDX-License-Identifier: LGPL-2.1-or-later + +#include "catch.h" +#include "irrTypes.h" +#include "noise.h" +#include "util/k_d_tree.h" + +#include +#include + +template +class ObjectVector +{ +public: + using Point = std::array; + + void insert(const Point &p, Id id) + { + entries.push_back(Entry{p, id}); + } + + void remove(Id id) + { + const auto it = std::find_if(entries.begin(), entries.end(), [&](const auto &e) { + return e.id == id; + }); + assert(it != entries.end()); + entries.erase(it); + } + + void update(const Point &p, Id id) + { + remove(id); + insert(p, id); + } + + template + void rangeQuery(const Point &min, const Point &max, const F &cb) + { + for (const auto &e : entries) { + for (uint8_t d = 0; d < Dim; ++d) + if (e.point[d] < min[d] || e.point[d] > max[d]) + goto next; + cb(e.point, e.id); // TODO check + next: {} + } + } + +private: + struct Entry { + Point point; + Id id; + }; + std::vector entries; +}; + +TEST_CASE("k-d-tree") { + +SECTION("single update") { + k_d_tree::DynamicKdTrees<3, u16, u16> kds; + for (u16 i = 1; i <= 5; ++i) + kds.insert({i, i, i}, i); + for (u16 i = 1; i <= 5; ++i) { + u16 j = i - 1; + kds.update({j, j, j}, i); + } +} + +SECTION("random operations") { + PseudoRandom pr(Catch::getSeed()); + + ObjectVector<3, f32, u16> objvec; + k_d_tree::DynamicKdTrees<3, f32, u16> kds; + + const auto randPos = [&]() { + std::array point; + for (uint8_t d = 0; d < 3; ++d) + point[d] = pr.range(-1000, 1000); + return point; + }; + + const auto testRandomQuery = [&]() { + std::array min, max; + for (uint8_t d = 0; d < 3; ++d) { + min[d] = pr.range(-1500, 1500); + max[d] = min[d] + pr.range(1, 2500); + } + std::unordered_set expected_ids; + objvec.rangeQuery(min, max, [&](auto _, u16 id) { + CHECK(expected_ids.count(id) == 0); + expected_ids.insert(id); + }); + kds.rangeQuery(min, max, [&](auto point, u16 id) { + CHECK(expected_ids.count(id) == 1); + expected_ids.erase(id); + }); + CHECK(expected_ids.empty()); + }; + + for (u16 id = 1; id < 1000; ++id) { + const auto point = randPos(); + objvec.insert(point, id); + kds.insert(point, id); + testRandomQuery(); + } + + const auto testRandomQueries = [&]() { + for (int i = 0; i < 1000; ++i) { + testRandomQuery(); + } + }; + + testRandomQueries(); + + for (u16 id = 1; id < 800; ++id) { + objvec.remove(id); + kds.remove(id); + } + + testRandomQueries(); + + for (u16 id = 800; id < 1000; ++id) { + const auto point = randPos(); + objvec.update(point, id); + kds.update(point, id); + } + + testRandomQueries(); + + for (u16 id = 800; id < 1000; ++id) { + objvec.remove(id); + kds.remove(id); + testRandomQuery(); + } +} + +} diff --git a/src/unittest/test_serveractiveobjectmgr.cpp b/src/unittest/test_serveractiveobjectmgr.cpp index 0d370b5c5..861f5e06f 100644 --- a/src/unittest/test_serveractiveobjectmgr.cpp +++ b/src/unittest/test_serveractiveobjectmgr.cpp @@ -2,52 +2,162 @@ // SPDX-License-Identifier: LGPL-2.1-or-later // Copyright (C) 2018 nerzhul, Loic Blot -#include "test.h" +#include "activeobjectmgr.h" +#include "catch.h" +#include "irrTypes.h" +#include "irr_aabb3d.h" #include "mock_serveractiveobject.h" #include -#include +#include +#include +#include #include "server/activeobjectmgr.h" +#include "server/serveractiveobject.h" -#include "profiler.h" +class TestServerActiveObjectMgr { + server::ActiveObjectMgr saomgr; + std::vector ids; - -class TestServerActiveObjectMgr : public TestBase -{ public: - TestServerActiveObjectMgr() { TestManager::registerTestModule(this); } - const char *getName() { return "TestServerActiveObjectMgr"; } - void runTests(IGameDef *gamedef); + u16 getFreeId() const { return saomgr.getFreeId(); } - void testFreeID(); - void testRegisterObject(); - void testRemoveObject(); - void testGetObjectsInsideRadius(); - void testGetAddedActiveObjectsAroundPos(); + bool registerObject(std::unique_ptr obj) + { + auto *ptr = obj.get(); + if (!saomgr.registerObject(std::move(obj))) + return false; + ids.push_back(ptr->getId()); + return true; + } + + void removeObject(u16 id) + { + const auto it = std::find(ids.begin(), ids.end(), id); + REQUIRE(it != ids.end()); + ids.erase(it); + saomgr.removeObject(id); + } + + void updateObjectPos(u16 id, const v3f &pos) + { + auto *obj = saomgr.getActiveObject(id); + REQUIRE(obj != nullptr); + obj->setPos(pos); + saomgr.updateObjectPos(id, pos); // HACK work around m_env == nullptr + } + + void clear() + { + saomgr.clear(); + ids.clear(); + } + + ServerActiveObject *getActiveObject(u16 id) + { + return saomgr.getActiveObject(id); + } + + template + void getObjectsInsideRadius(T&& arg) + { + saomgr.getObjectsInsideRadius(std::forward(arg)); + } + + template + void getAddedActiveObjectsAroundPos(T&& arg) + { + saomgr.getAddedActiveObjectsAroundPos(std::forward(arg)); + } + + // Testing + + bool empty() { return ids.empty(); } + + template + u16 randomId(T &random) + { + REQUIRE(!ids.empty()); + std::uniform_int_distribution index(0, ids.size() - 1); + return ids[index(random)]; + } + + void getObjectsInsideRadiusNaive(const v3f &pos, float radius, + std::vector &result) + { + for (const auto &[id, obj] : saomgr.m_active_objects.iter()) { + if (obj->getBasePosition().getDistanceFromSQ(pos) <= radius * radius) { + result.push_back(obj.get()); + } + } + } + + void getObjectsInAreaNaive(const aabb3f &box, + std::vector &result) + { + for (const auto &[id, obj] : saomgr.m_active_objects.iter()) { + if (box.isPointInside(obj->getBasePosition())) { + result.push_back(obj.get()); + } + } + } + + constexpr static auto compare_by_id = [](auto *sao1, auto *sao2) -> bool { + return sao1->getId() < sao2->getId(); + }; + + static void sortById(std::vector &saos) + { + std::sort(saos.begin(), saos.end(), compare_by_id); + } + + void compareObjects(std::vector &actual, + std::vector &expected) + { + std::vector unexpected, missing; + sortById(actual); + sortById(expected); + + std::set_difference(actual.begin(), actual.end(), + expected.begin(), expected.end(), + std::back_inserter(unexpected), compare_by_id); + + assert(unexpected.empty()); + + std::set_difference(expected.begin(), expected.end(), + actual.begin(), actual.end(), + std::back_inserter(missing), compare_by_id); + assert(missing.empty()); + } + + void compareObjectsInsideRadius(const v3f &pos, float radius) + { + std::vector actual, expected; + saomgr.getObjectsInsideRadius(pos, radius, actual, nullptr); + getObjectsInsideRadiusNaive(pos, radius, expected); + compareObjects(actual, expected); + } + + void compareObjectsInArea(const aabb3f &box) + { + std::vector actual, expected; + saomgr.getObjectsInArea(box, actual, nullptr); + getObjectsInAreaNaive(box, expected); + compareObjects(actual, expected); + } }; -static TestServerActiveObjectMgr g_test_instance; -void TestServerActiveObjectMgr::runTests(IGameDef *gamedef) -{ - TEST(testFreeID); - TEST(testRegisterObject) - TEST(testRemoveObject) - TEST(testGetObjectsInsideRadius); - TEST(testGetAddedActiveObjectsAroundPos); -} +TEST_CASE("server active object manager") { -//////////////////////////////////////////////////////////////////////////////// - -void TestServerActiveObjectMgr::testFreeID() -{ - server::ActiveObjectMgr saomgr; +SECTION("free ID") { + TestServerActiveObjectMgr saomgr; std::vector aoids; u16 aoid = saomgr.getFreeId(); // Ensure it's not the same id - UASSERT(saomgr.getFreeId() != aoid); + REQUIRE(saomgr.getFreeId() != aoid); aoids.push_back(aoid); @@ -60,53 +170,50 @@ void TestServerActiveObjectMgr::testFreeID() aoids.push_back(sao->getId()); // Ensure next id is not in registered list - UASSERT(std::find(aoids.begin(), aoids.end(), saomgr.getFreeId()) == + REQUIRE(std::find(aoids.begin(), aoids.end(), saomgr.getFreeId()) == aoids.end()); } saomgr.clear(); } -void TestServerActiveObjectMgr::testRegisterObject() -{ - server::ActiveObjectMgr saomgr; +SECTION("register object") { + TestServerActiveObjectMgr saomgr; auto sao_u = std::make_unique(); auto sao = sao_u.get(); - UASSERT(saomgr.registerObject(std::move(sao_u))); + REQUIRE(saomgr.registerObject(std::move(sao_u))); u16 id = sao->getId(); auto saoToCompare = saomgr.getActiveObject(id); - UASSERT(saoToCompare->getId() == id); - UASSERT(saoToCompare == sao); + REQUIRE(saoToCompare->getId() == id); + REQUIRE(saoToCompare == sao); sao_u = std::make_unique(); sao = sao_u.get(); - UASSERT(saomgr.registerObject(std::move(sao_u))); - UASSERT(saomgr.getActiveObject(sao->getId()) == sao); - UASSERT(saomgr.getActiveObject(sao->getId()) != saoToCompare); + REQUIRE(saomgr.registerObject(std::move(sao_u))); + REQUIRE(saomgr.getActiveObject(sao->getId()) == sao); + REQUIRE(saomgr.getActiveObject(sao->getId()) != saoToCompare); saomgr.clear(); } -void TestServerActiveObjectMgr::testRemoveObject() -{ - server::ActiveObjectMgr saomgr; +SECTION("remove object") { + TestServerActiveObjectMgr saomgr; auto sao_u = std::make_unique(); auto sao = sao_u.get(); - UASSERT(saomgr.registerObject(std::move(sao_u))); + REQUIRE(saomgr.registerObject(std::move(sao_u))); u16 id = sao->getId(); - UASSERT(saomgr.getActiveObject(id) != nullptr) + REQUIRE(saomgr.getActiveObject(id) != nullptr); saomgr.removeObject(sao->getId()); - UASSERT(saomgr.getActiveObject(id) == nullptr); + REQUIRE(saomgr.getActiveObject(id) == nullptr); saomgr.clear(); } -void TestServerActiveObjectMgr::testGetObjectsInsideRadius() -{ +SECTION("get objects inside radius") { server::ActiveObjectMgr saomgr; static const v3f sao_pos[] = { v3f(10, 40, 10), @@ -122,15 +229,15 @@ void TestServerActiveObjectMgr::testGetObjectsInsideRadius() std::vector result; saomgr.getObjectsInsideRadius(v3f(), 50, result, nullptr); - UASSERTCMP(int, ==, result.size(), 1); + CHECK(result.size() == 1); result.clear(); saomgr.getObjectsInsideRadius(v3f(), 750, result, nullptr); - UASSERTCMP(int, ==, result.size(), 2); + CHECK(result.size() == 2); result.clear(); saomgr.getObjectsInsideRadius(v3f(), 750000, result, nullptr); - UASSERTCMP(int, ==, result.size(), 5); + CHECK(result.size() == 5); result.clear(); auto include_obj_cb = [](ServerActiveObject *obj) { @@ -138,13 +245,12 @@ void TestServerActiveObjectMgr::testGetObjectsInsideRadius() }; saomgr.getObjectsInsideRadius(v3f(), 750000, result, include_obj_cb); - UASSERTCMP(int, ==, result.size(), 4); + CHECK(result.size() == 4); saomgr.clear(); } -void TestServerActiveObjectMgr::testGetAddedActiveObjectsAroundPos() -{ +SECTION("get added active objects around pos") { server::ActiveObjectMgr saomgr; static const v3f sao_pos[] = { v3f(10, 40, 10), @@ -161,12 +267,64 @@ void TestServerActiveObjectMgr::testGetAddedActiveObjectsAroundPos() std::vector result; std::set cur_objects; saomgr.getAddedActiveObjectsAroundPos(v3f(), "singleplayer", 100, 50, cur_objects, result); - UASSERTCMP(int, ==, result.size(), 1); + CHECK(result.size() == 1); result.clear(); cur_objects.clear(); saomgr.getAddedActiveObjectsAroundPos(v3f(), "singleplayer", 740, 50, cur_objects, result); - UASSERTCMP(int, ==, result.size(), 2); + CHECK(result.size() == 2); saomgr.clear(); } + +SECTION("spatial index") { + TestServerActiveObjectMgr saomgr; + std::mt19937 gen(0xABCDEF); + std::uniform_int_distribution coordinate(-1000, 1000); + const auto random_pos = [&]() { + return v3f(coordinate(gen), coordinate(gen), coordinate(gen)); + }; + + std::uniform_int_distribution percent(0, 99); + const auto modify = [&](u32 p_insert, u32 p_delete, u32 p_update) { + const auto p = percent(gen); + if (p < p_insert) { + saomgr.registerObject(std::make_unique(nullptr, random_pos())); + } else if (p < p_insert + p_delete) { + if (!saomgr.empty()) + saomgr.removeObject(saomgr.randomId(gen)); + } else if (p < p_insert + p_delete + p_update) { + if (!saomgr.empty()) + saomgr.updateObjectPos(saomgr.randomId(gen), random_pos()); + } + }; + + const auto test_queries = [&]() { + std::uniform_real_distribution radius(0, 100); + saomgr.compareObjectsInsideRadius(random_pos(), radius(gen)); + + aabb3f box(random_pos(), random_pos()); + box.repair(); + saomgr.compareObjectsInArea(box); + }; + + // Grow: Insertion twice as likely as deletion + for (u32 i = 0; i < 3000; ++i) { + modify(50, 25, 25); + test_queries(); + } + + // Stagnate: Insertion and deletion equally likely + for (u32 i = 0; i < 3000; ++i) { + modify(25, 25, 50); + test_queries(); + } + + // Shrink: Deletion twice as likely as insertion + while (!saomgr.empty()) { + modify(25, 50, 25); + test_queries(); + } +} + +} diff --git a/src/util/k_d_tree.h b/src/util/k_d_tree.h new file mode 100644 index 000000000..f8e266a36 --- /dev/null +++ b/src/util/k_d_tree.h @@ -0,0 +1,515 @@ +// Copyright (C) 2024 Lars Müller +// SPDX-License-Identifier: LGPL-2.1-or-later + +#pragma once + +#include +#include +#include +#include +#include +#include + +/* +This implements a dynamic forest of static k-d-trees. + +A k-d-tree is a k-dimensional binary search tree. +On the i-th level of the tree, you split by the (i mod k)-th coordinate. + +Building a balanced k-d-tree for n points is done in O(n log n) time: +Points are stored in a matrix, identified by indices. +These indices are presorted by all k axes. +To split, you simply pick the pivot index in the appropriate index array, +and mark all points left to it by index in a bitset. +This lets you then split the indices sorted by the other axes, +while preserving the sorted order. + +This however only gives us a static spatial index. +To make it dynamic, we keep a "forest" of k-d-trees of sizes of successive powers of two. +When we insert a new tree, we check whether there already is a k-d-tree of the same size. +If this is the case, we merge with that tree, giving us a tree of twice the size, +and so on, until we find a free size. + +This means our "forest" corresponds to a bit pattern, +where a set bit means a non-empty tree. +Inserting a point is equivalent to incrementing this bit pattern. + +To handle deletions, we simply mark the appropriate point as deleted using another bitset. +When more than half the points have been deleted, +we shrink the structure by removing all deleted points. +This is equivalent to shifting down the "bit pattern" by one. + +There are plenty variations that could be explored: + +* Keeping a small amount of points in a small pool to make updates faster - + avoid building and querying small k-d-trees. + This might be useful if the overhead for small sizes hurts performance. +* Keeping fewer trees to make queries faster, at the expense of updates. +* More eagerly removing entries marked as deleted (for example, on merge). +* Replacing the array-backed structure with a structure of dynamically allocated nodes. + This would make it possible to "let trees get out of shape". +* Shrinking the structure currently sorts the live points by all axes, + not leveraging the existing presorting of the subsets. + Cleverly done filtering followed by sorted merges should enable linear time. +* A special ray proximity query could be implemented. This is tricky however. +*/ + +namespace k_d_tree +{ + +using Idx = uint16_t; + +// We use size_t for sizes (but not for indices) +// to make sure there are no wraparounds when we approach the limit. +// This hardly affects performance or memory usage; +// the core arrays still only store indices. + +template +class Points +{ +public: + using Point = std::array; + //! Empty + Points() : n(0), coords(nullptr) {} + //! Allocating constructor; leaves coords uninitialized! + Points(size_t n) : n(n), coords(new Component[Dim * n]) {} + //! Copying constructor + Points(size_t n, const std::array &coords) + : Points(n) + { + for (uint8_t d = 0; d < Dim; ++d) + std::copy(coords[d], coords[d] + n, begin(d)); + } + + size_t size() const { return n; } + + void assign(Idx start, const Points &from) + { + for (uint8_t d = 0; d < Dim; ++d) + std::copy(from.begin(d), from.end(d), begin(d) + start); + } + + Point getPoint(Idx i) const + { + Point point; + for (uint8_t d = 0; d < Dim; ++d) + point[d] = begin(d)[i]; + return point; + } + + void setPoint(Idx i, const Point &point) + { + for (uint8_t d = 0; d < Dim; ++d) + begin(d)[i] = point[d]; + } + + Component *begin(uint8_t d) { return coords.get() + d * n; } + Component *end(uint8_t d) { return begin(d) + n; } + const Component *begin(uint8_t d) const { return coords.get() + d * n; } + const Component *end(uint8_t d) const { return begin(d) + n; } + +private: + size_t n; + std::unique_ptr coords; +}; + +template +class SortedIndices +{ +public: + //! empty + SortedIndices() : indices() {} + + //! uninitialized indices + static SortedIndices newUninitialized(size_t n) + { + return SortedIndices(Points(n)); + } + + //! Identity permutation on all axes + SortedIndices(size_t n) + : indices(n) + { + for (uint8_t d = 0; d < Dim; ++d) { + for (Idx i = 0; i < n; ++i) { + indices.begin(d)[i] = i; + } + } + } + + size_t size() const { return indices.size(); } + bool empty() const { return size() == 0; } + + struct SplitResult { + SortedIndices left, right; + Idx pivot; + }; + + //! Splits the sorted indices in the middle along the specified axis, + //! partitioning them into left (<=), the pivot, and right (>=). + SplitResult split(uint8_t axis, std::vector &markers) const + { + const auto begin = indices.begin(axis); + Idx left_n = indices.size() / 2; + const auto mid = begin + left_n; + + // Mark all points to be partitioned left + for (auto it = begin; it != mid; ++it) + markers[*it] = true; + + SortedIndices left(left_n); + std::copy(begin, mid, left.indices.begin(axis)); + SortedIndices right(indices.size() - left_n - 1); + std::copy(mid + 1, indices.end(axis), right.indices.begin(axis)); + + for (uint8_t d = 0; d < Dim; ++d) { + if (d == axis) + continue; + auto left_ptr = left.indices.begin(d); + auto right_ptr = right.indices.begin(d); + for (auto it = indices.begin(d); it != indices.end(d); ++it) { + if (*it != *mid) { // ignore pivot + if (markers[*it]) + *(left_ptr++) = *it; + else + *(right_ptr++) = *it; + } + } + } + + // Unmark points, since we want to reuse the storage for markers + for (auto it = begin; it != mid; ++it) + markers[*it] = false; + + return SplitResult{std::move(left), std::move(right), *mid}; + } + + Idx *begin(uint8_t d) { return indices.begin(d); } + Idx *end(uint8_t d) { return indices.end(d); } + const Idx *begin(uint8_t d) const { return indices.begin(d); } + const Idx *end(uint8_t d) const { return indices.end(d); } +private: + SortedIndices(Points &&indices) : indices(std::move(indices)) {} + Points indices; +}; + +template +class SortedPoints +{ +public: + SortedPoints() : points(), indices() {} + + //! Single point + SortedPoints(const std::array &point) + : points(1), indices(1) + { + points.setPoint(0, point); + } + + //! Sort points + SortedPoints(size_t n, const std::array ptrs) + : points(n, ptrs), indices(n) + { + for (uint8_t d = 0; d < Dim; ++d) { + const auto coord = points.begin(d); + std::sort(indices.begin(d), indices.end(d), [&](auto i, auto j) { + return coord[i] < coord[j]; + }); + } + } + + //! Merge two sets of sorted points + SortedPoints(const SortedPoints &a, const SortedPoints &b) + : points(a.size() + b.size()) + { + const auto n = points.size(); + indices = SortedIndices::newUninitialized(n); + for (uint8_t d = 0; d < Dim; ++d) { + points.assign(0, a.points); + points.assign(a.points.size(), b.points); + const auto coord = points.begin(d); + auto a_ptr = a.indices.begin(d); + auto b_ptr = b.indices.begin(d); + auto dst_ptr = indices.begin(d); + while (a_ptr != a.indices.end(d) && b_ptr != b.indices.end(d)) { + const auto i = *a_ptr; + const auto j = *b_ptr + a.size(); + if (coord[i] <= coord[j]) { + *(dst_ptr++) = i; + ++a_ptr; + } else { + *(dst_ptr++) = j; + ++b_ptr; + } + } + while (a_ptr != a.indices.end(d)) + *(dst_ptr++) = *(a_ptr++); + while (b_ptr != b.indices.end(d)) + *(dst_ptr++) = a.size() + *(b_ptr++); + } + } + + size_t size() const + { + // technically redundant with indices.size(), + // but that is irrelevant + return points.size(); + } + + Points points; + SortedIndices indices; +}; + +template +class KdTree +{ +public: + using Point = std::array; + + //! Empty tree + KdTree() + : items() + , ids(nullptr) + , tree(nullptr) + , deleted() + {} + + //! Build a tree containing just a single point + KdTree(const Point &point, const Id &id) + : items(point) + , ids(std::make_unique(1)) + , tree(std::make_unique(1)) + , deleted(1) + { + tree[0] = 0; + ids[0] = id; + } + + //! Build a tree + KdTree(size_t n, Id const *ids, std::array pts) + : items(n, pts) + , ids(std::make_unique(n)) + , tree(std::make_unique(n)) + , deleted(n) + { + std::copy(ids, ids + n, this->ids.get()); + init(0, 0, items.indices); + } + + //! Merge two trees. Both trees are assumed to have a power of two size. + KdTree(const KdTree &a, const KdTree &b) + : items(a.items, b.items) + { + tree = std::make_unique(cap()); + ids = std::make_unique(cap()); + std::copy(a.ids.get(), a.ids.get() + a.cap(), ids.get()); + std::copy(b.ids.get(), b.ids.get() + b.cap(), ids.get() + a.cap()); + // Note: Initialize `deleted` *before* calling `init`, + // since `init` abuses the `deleted` marks as left/right marks. + deleted = std::vector(cap()); + init(0, 0, items.indices); + std::copy(a.deleted.begin(), a.deleted.end(), deleted.begin()); + std::copy(b.deleted.begin(), b.deleted.end(), deleted.begin() + a.items.size()); + } + + template + void rangeQuery(const Point &min, const Point &max, + const F &cb) const + { + rangeQuery(0, 0, min, max, cb); + } + + void remove(Idx internalIdx) + { + assert(!deleted[internalIdx]); + deleted[internalIdx] = true; + } + + template + void foreach(F cb) const + { + for (Idx i = 0; i < cap(); ++i) { + if (!deleted[i]) { + cb(i, items.points.getPoint(i), ids[i]); + } + } + } + + //! Capacity, not size, since some items may be marked as deleted + size_t cap() const { return items.size(); } + +private: + void init(Idx root, uint8_t axis, const SortedIndices &sorted) + { + // Temporarily abuse "deleted" marks as left/right marks + const auto split = sorted.split(axis, deleted); + tree[root] = split.pivot; + const auto next_axis = (axis + 1) % Dim; + if (!split.left.empty()) + init(2 * root + 1, next_axis, split.left); + if (!split.right.empty()) + init(2 * root + 2, next_axis, split.right); + } + + template + // Note: root is of type size_t to avoid issues with wraparound + void rangeQuery(size_t root, uint8_t split, + const Point &min, const Point &max, + const F &cb) const + { + if (root >= cap()) + return; + const auto ptid = tree[root]; + const auto coord = items.points.begin(split)[ptid]; + const auto leftChild = 2*root + 1; + const auto rightChild = 2*root + 2; + const auto nextSplit = (split + 1) % Dim; + if (min[split] > coord) { + rangeQuery(rightChild, nextSplit, min, max, cb); + } else if (max[split] < coord) { + rangeQuery(leftChild, nextSplit, min, max, cb); + } else { + rangeQuery(rightChild, nextSplit, min, max, cb); + rangeQuery(leftChild, nextSplit, min, max, cb); + if (deleted[ptid]) + return; + const auto point = items.points.getPoint(ptid); + for (uint8_t d = 0; d < Dim; ++d) + if (point[d] < min[d] || point[d] > max[d]) + return; + cb(point, ids[ptid]); + } + } + SortedPoints items; + std::unique_ptr ids; + std::unique_ptr tree; + std::vector deleted; +}; + +template +class DynamicKdTrees +{ + using Tree = KdTree; + +public: + using Point = typename Tree::Point; + + void insert(const std::array &point, Id id) + { + Tree tree(point, id); + for (uint8_t tree_idx = 0;; ++tree_idx) { + if (tree_idx == trees.size()) { + trees.push_back(std::move(tree)); + updateDelEntries(tree_idx); + break; + } + // Can we use a free slot to "plant" the tree? + if (trees[tree_idx].cap() == 0) { + trees[tree_idx] = std::move(tree); + updateDelEntries(tree_idx); + break; + } + tree = Tree(tree, trees[tree_idx]); + trees[tree_idx] = std::move(Tree()); + } + ++n_entries; + } + + void remove(Id id) + { + const auto it = del_entries.find(id); + assert(it != del_entries.end()); + trees.at(it->second.tree_idx).remove(it->second.in_tree); + del_entries.erase(it); + ++deleted; + if (deleted >= (n_entries+1)/2) // "shift out" the last tree + shrink_to_half(); + } + + void update(const Point &newPos, Id id) + { + remove(id); + insert(newPos, id); + } + + template + void rangeQuery(const Point &min, const Point &max, + const F &cb) const + { + for (const auto &tree : trees) + tree.rangeQuery(min, max, cb); + } + + size_t size() const + { + return n_entries - deleted; + } + +private: + + void updateDelEntries(uint8_t tree_idx) + { + trees[tree_idx].foreach([&](Idx in_tree_idx, auto _, Id id) { + del_entries[id] = {tree_idx, in_tree_idx}; + }); + } + + // Shrink to half the size, equivalent to shifting down the "bit pattern". + void shrink_to_half() + { + assert(n_entries >= deleted); + assert(n_entries - deleted == (n_entries >> 1)); + n_entries -= deleted; + deleted = 0; + // Reset map, freeing memory (instead of clearing) + del_entries = std::unordered_map(); + + // Collect all live points and corresponding IDs. + const auto live_ids = std::make_unique(n_entries); + Points live_points(n_entries); + size_t i = 0; + for (const auto &tree : trees) { + tree.foreach([&](Idx _, auto point, Id id) { + assert(i < n_entries); + live_points.setPoint(static_cast(i), point); + live_ids[i] = id; + ++i; + }); + } + assert(i == n_entries); + + // Construct a new forest. + // The "tree pattern" will effectively just be shifted down by one. + auto id_ptr = live_ids.get(); + std::array point_ptrs; + size_t n = 1; + for (uint8_t d = 0; d < Dim; ++d) + point_ptrs[d] = live_points.begin(d); + for (uint8_t tree_idx = 0; tree_idx < trees.size() - 1; ++tree_idx, n *= 2) { + Tree tree; + // If there was a tree at the next position, there should be + // a tree at this position after shifting the pattern. + if (trees[tree_idx+1].cap() > 0) { + tree = std::move(Tree(n, id_ptr, point_ptrs)); + id_ptr += n; + for (uint8_t d = 0; d < Dim; ++d) + point_ptrs[d] += n; + } + trees[tree_idx] = std::move(tree); + updateDelEntries(tree_idx); + } + trees.pop_back(); // "shift out" tree with the most elements + } + // This could even use an array instead of a vector, + // since the number of trees is guaranteed to be logarithmic in the max of Idx + std::vector trees; + struct DelEntry { + uint8_t tree_idx; + Idx in_tree; + }; + std::unordered_map del_entries; + size_t n_entries = 0; + size_t deleted = 0; +}; + +} // end namespace k_d_tree \ No newline at end of file