我已经构建了一个 d
维 KD-树。我想在这个树上执行区域搜索。维基百科 中提到了 KD-树中的范围搜索,但没有详细讲述其实现/算法。请问有人能帮我吗?即使不是对于任意的 d
,也至少希望得到对于 d = 2
和 d = 3
的帮助。谢谢!
我已经构建了一个 d
维 KD-树。我想在这个树上执行区域搜索。维基百科 中提到了 KD-树中的范围搜索,但没有详细讲述其实现/算法。请问有人能帮我吗?即使不是对于任意的 d
,也至少希望得到对于 d = 2
和 d = 3
的帮助。谢谢!
kd-tree有多种变体,我使用的变体具有以下规格:
maxCapacity
个点。注:还有版本是每个节点(无论是内部节点还是叶节点)都仅存储一个点。下面的算法也适用于这些版本。主要区别在于buildTree
。
我大约两年前写了一个算法,感谢@9mat指向的资源。
假设任务是查找位于给定超矩形中(d维)的点数。这项任务也可以列出所有点或列出满足某些其他条件的范围内的所有点,但对我的代码进行直接更改即可。
定义基本节点类如下:
template <typename T> class kdNode{
public: kdNode(){}
virtual long rangeQuery(const T* q_min, const T* q_max) const{ return 0; }
};
那么,一个内部节点(非叶节点)可能长这样:
class internalNode:public kdNode<T>{
const kdNode<T> *left = nullptr, *right = nullptr; // left and right sub trees
int axis; // the axis on which split of points is being done
T value; // the value based on which points are being split
public: internalNode(){}
void buildTree(...){
// builds the tree recursively
}
// returns the number of points in this sub tree that lie inside the hyper rectangle formed by q_min and q_max
int rangeQuery(const T* q_min, const T* q_max) const{
// num of points that satisfy range query conditions
int rangeCount = 0;
// check for left node
if(q_min[axis] <= value) {
rangeCount += left->rangeQuery(q_min, q_max);
}
// check for right node
if(q_max[axis] >= value) {
rangeCount += right->rangeQuery(q_min, q_max);
}
return rangeCount;
}
};
class leaf:public kdNode<T>{
// maxCapacity is a hyper - param, the max num of points you allow a node to hold
array<T, d> points[maxCapacity];
int keyCount = 0; // this is the actual num of points in this leaf (keyCount <= maxCapacity)
public: leaf(){}
public: void addPoint(const T* p){
// add a point p to the leaf node
}
// check if points[index] lies inside the hyper rectangle formed by q_min and q_max
inline bool containsPoint(const int index, const T* q_min, const T* q_max) const{
for (int i=0; i<d; i++) {
if (points[index][i] > q_max[i] || points[index][i] < q_min[i]) {
return false;
}
}
return true;
}
// returns number of points in this leaf node that lie inside the hyper rectangle formed by q_min and q_max
int rangeQuery(const T* q_min, const T* q_max) const{
// num of points that satisfy range query conditions
int rangeCount = 0;
for(int i=0; i < this->keyCount; i++) {
if(containsPoint(i, q_min, q_max)) {
rangeCount++;
}
}
return rangeCount;
}
};
axis
排序,因此您可以使用q_min
和q_max
执行二分查找以查找l
和r
值,然后从l
到r
进行线性搜索,而不是从0
到keyCount-1
(当然,在最坏的情况下这没有帮助,但在实践中,特别是如果您具有相当高的容量值,这可能会有所帮助)。get_range
函数在末尾具有可变参数,并且可以像以下方式调用,
x1,y1,x2,y2
或
x1,y1,z1,x2,y2,z2
等。 首先给出范围的低值,然后是高值。
(您可以使用任意多个维度)。static public <T> void get_range(K_D_Tree<T> tree, List<T> result, float... range) {
if (tree.root == null) return;
float[] node_region = new float[tree.DIMENSIONS * 2];
for (int i = 0; i < tree.DIMENSIONS; i++) {
node_region[i] = -Float.MAX_VALUE;
node_region[i+tree.DIMENSIONS] = Float.MAX_VALUE;
}
_get_range(tree, result, tree.root, node_region, 0, range);
}
node_region
表示节点所在的区域,我们从尽可能大的区域开始。因为我们不知道这是否是我们正在处理的区域。
这里是递归 _get_range
的实现:
static public <T> void _get_range(K_D_Tree<T> tree, List<T> result, K_D_Tree_Node<T> node, float[] node_region, int dimension, float[] target_region) {
if (dimension == tree.DIMENSIONS) dimension = 0;
if (_contains_region(tree, node_region, target_region)) {
_add_whole_branch(node, result);
}
else {
float value = _value(tree, dimension, node);
if (node.left != null) {
float[] node_region_left = new float[tree.DIMENSIONS*2];
System.arraycopy(node_region, 0, node_region_left, 0, node_region.length);
node_region_left[dimension + tree.DIMENSIONS] = value;
if (_intersects_region(tree, node_region_left, target_region)){
_get_range(tree, result, node.left, node_region_left, dimension+1, target_region);
}
}
if (node.right != null) {
float[] node_region_right = new float[tree.DIMENSIONS*2];
System.arraycopy(node_region, 0, node_region_right, 0, node_region.length);
node_region_right[dimension] = value;
if (_intersects_region(tree, node_region_right, target_region)){
_get_range(tree, result, node.right, node_region_right, dimension+1, target_region);
}
}
if (_region_contains_node(tree, target_region, node)) {
result.add(node.point);
}
}
}
if (_contains_region(tree, node_region, target_region)) {
_add_whole_branch(node, result);
}
使用KD-Tree进行范围搜索时,节点的区域有三种选项:
一旦确定一个区域完全包含,则可以添加整个分支而无需进行任何维度检查。
为了更清楚地说明,这里是_add_whole_branch
:
static public <T> void _add_whole_branch(K_D_Tree_Node<T> node, List<T> result) {
result.add(node.point);
if (node.left != null) _add_whole_branch(node.left, result);
if (node.right != null) _add_whole_branch(node.right, result);
}
在这张图片中,所有大的白点都是使用_add_whole_branch
添加的,而仅对于红点需要检查所有维度。
优化
1)
与其从根节点开始_get_range
函数,不如找到分裂节点。这是第一个其点在查询范围内的节点。要找到分裂节点,您仍需要从根节点开始,但计算会更便宜(因为您只需向左或向右移动)。
2)
现在我创建了float[] node_region_left
和float[] node_region_right
,由于这发生在递归函数中,可能会导致相当多的数组。但是,您可以重复使用左侧的数组来处理右侧的数组。出于清晰起见,我在此示例中没有这样做。
我还可以想象将区域大小存储在节点中,但这需要更多的内存,并可能导致许多缓存未命中。
d=1
和d=2
情况的伪代码。 - 9matd
值,循环遍历节点/轴。 - Ankit Kumar