数百万个三维点:如何找到离给定点最近的十个点?

71

一个3D点由(x,y,z)定义。 任意两个点(X,Y,Z)和(x,y,z)之间的距离是d= Sqrt [(X-x)^2 +(Y-y)^2 +(Z-z)^2]。 现在有一个包含一百万个条目的文件,每个条目都是空间中的一些点,没有特定顺序。 给定任何点(a,b,c),找到离它最近的10个点。 您将如何存储这一百万个点,并从该数据结构中检索这10个点。

答案: 我们可以使用KD树来存储这一百万个点,因为KD树是一种高维数据结构,可以用于快速查找最近邻居。对于每个节点,树按照特征进行拆分,并将点分配给相应的子节点,直到叶子节点包含单个点。要查找最近的10个点,请遍历树以查找最接近(a,b,c)的叶子节点。然后,向上回溯树,检查每个祖先节点是否有更接近(a,b,c)的点。如果是,则将其子节点中的其他点添加到候选列表中。最后,按与(a,b,c)的距离排序并选择前10个点。

1
你是在被告知点(a,b,c)之前还是之后创建并填充数据结构?例如,如果您先创建数据结构,然后用户输入(a,b,c)并希望立即得到答案,则David的答案无效。 - Tyler
3
好的观点(无双关语!)当然,如果事先不知道(a,b,c),那么更多是优化现有点列表以便根据3D位置进行搜索,而不是实际执行搜索的问题。 - David Z
6
需要澄清的是,是否需要考虑准备数据结构和将百万点存储在该数据结构中的成本,或者仅考虑检索性能。如果这个成本并不重要,那么无论你检索这些点的次数有多少,kd树都会胜出。但如果这个成本很重要,那么你还应该指定你期望运行搜索的次数(对于少量的搜索来说暴力搜索会胜出,对于更大的kd搜索会胜出)。 - Unreason
12个回答

97

百万个点是一个较小的数字。最直接的方法在这里起作用(基于KDTree的代码对于仅查询一个点而言速度较慢)。

暴力算法(时间约为1秒)

#!/usr/bin/env python
import numpy

NDIM = 3 # number of dimensions

# read points into array
a = numpy.fromfile('million_3D_points.txt', sep=' ')
a.shape = a.size / NDIM, NDIM

point = numpy.random.uniform(0, 100, NDIM) # choose random point
print 'point:', point
d = ((a-point)**2).sum(axis=1)  # compute distances
ndx = d.argsort() # indirect sort 

# print 10 nearest points to the chosen one
import pprint
pprint.pprint(zip(a[ndx[:10]], d[ndx[:10]]))

运行它:

$ time python nearest.py 
point: [ 69.06310224   2.23409409  50.41979143]
[(array([ 69.,   2.,  50.]), 0.23500677815852947),
 (array([ 69.,   2.,  51.]), 0.39542392750839772),
 (array([ 69.,   3.,  50.]), 0.76681859086988302),
 (array([ 69.,   3.,  50.]), 0.76681859086988302),
 (array([ 69.,   3.,  51.]), 0.9272357402197513),
 (array([ 70.,   2.,  50.]), 1.1088022980015722),
 (array([ 70.,   2.,  51.]), 1.2692194473514404),
 (array([ 70.,   2.,  51.]), 1.2692194473514404),
 (array([ 70.,   3.,  51.]), 1.801031260062794),
 (array([ 69.,   1.,  51.]), 1.8636121147970444)]

real    0m1.122s
user    0m1.010s
sys 0m0.120s

以下是生成100万个3D点的脚本:

这里是生成100万个3D点的脚本:

#!/usr/bin/env python
import random
for _ in xrange(10**6):
    print ' '.join(str(random.randrange(100)) for _ in range(3))

输出:

$ head million_3D_points.txt

18 56 26
19 35 74
47 43 71
82 63 28
43 82 0
34 40 16
75 85 69
88 58 3
0 63 90
81 78 98

您可以使用该代码测试更复杂的数据结构和算法(例如,它们是否比上述最简单的方法实际消耗更少的内存或更快)。值得注意的是,目前它是唯一包含可工作代码的答案。

基于KDTree的解决方案 (时间约为1.4秒)

#!/usr/bin/env python
import numpy

NDIM = 3 # number of dimensions

# read points into array
a = numpy.fromfile('million_3D_points.txt', sep=' ')
a.shape = a.size / NDIM, NDIM

point =  [ 69.06310224,   2.23409409,  50.41979143] # use the same point as above
print 'point:', point


from scipy.spatial import KDTree

# find 10 nearest points
tree = KDTree(a, leafsize=a.shape[0]+1)
distances, ndx = tree.query([point], k=10)

# print 10 nearest points to the chosen one
print a[ndx]

运行它:
$ time python nearest_kdtree.py  

point: [69.063102240000006, 2.2340940900000001, 50.419791429999997]
[[[ 69.   2.  50.]
  [ 69.   2.  51.]
  [ 69.   3.  50.]
  [ 69.   3.  50.]
  [ 69.   3.  51.]
  [ 70.   2.  50.]
  [ 70.   2.  51.]
  [ 70.   2.  51.]
  [ 70.   3.  51.]
  [ 69.   1.  51.]]]

real    0m1.359s
user    0m1.280s
sys 0m0.080s

C++中的部分排序(时间约为1.1秒)

// $ g++ nearest.cc && (time ./a.out < million_3D_points.txt )
#include <algorithm>
#include <iostream>
#include <vector>

#include <boost/lambda/lambda.hpp>  // _1
#include <boost/lambda/bind.hpp>    // bind()
#include <boost/tuple/tuple_io.hpp>

namespace {
  typedef double coord_t;
  typedef boost::tuple<coord_t,coord_t,coord_t> point_t;

  coord_t distance_sq(const point_t& a, const point_t& b) { // or boost::geometry::distance
    coord_t x = a.get<0>() - b.get<0>();
    coord_t y = a.get<1>() - b.get<1>();
    coord_t z = a.get<2>() - b.get<2>();
    return x*x + y*y + z*z;
  }
}

int main() {
  using namespace std;
  using namespace boost::lambda; // _1, _2, bind()

  // read array from stdin
  vector<point_t> points;
  cin.exceptions(ios::badbit); // throw exception on bad input
  while(cin) {
    coord_t x,y,z;
    cin >> x >> y >> z;    
    points.push_back(boost::make_tuple(x,y,z));
  }

  // use point value from previous examples
  point_t point(69.06310224, 2.23409409, 50.41979143);
  cout << "point: " << point << endl;  // 1.14s

  // find 10 nearest points using partial_sort() 
  // Complexity: O(N)*log(m) comparisons (O(N)*log(N) worst case for the implementation)
  const size_t m = 10;
  partial_sort(points.begin(), points.begin() + m, points.end(), 
               bind(less<coord_t>(), // compare by distance to the point
                    bind(distance_sq, _1, point), 
                    bind(distance_sq, _2, point)));
  for_each(points.begin(), points.begin() + m, cout << _1 << "\n"); // 1.16s
}

运行它:

g++ -O3 nearest.cc && (time ./a.out < million_3D_points.txt )
point: (69.0631 2.23409 50.4198)
(69 2 50)
(69 2 51)
(69 3 50)
(69 3 50)
(69 3 51)
(70 2 50)
(70 2 51)
(70 2 51)
(70 3 51)
(69 1 51)

real    0m1.152s
user    0m1.140s
sys 0m0.010s

C++中的优先队列(时间约为1.2秒)

#include <algorithm>           // make_heap
#include <functional>          // binary_function<>
#include <iostream>

#include <boost/range.hpp>     // boost::begin(), boost::end()
#include <boost/tr1/tuple.hpp> // get<>, tuple<>, cout <<

namespace {
  typedef double coord_t;
  typedef std::tr1::tuple<coord_t,coord_t,coord_t> point_t;

  // calculate distance (squared) between points `a` & `b`
  coord_t distance_sq(const point_t& a, const point_t& b) { 
    // boost::geometry::distance() squared
    using std::tr1::get;
    coord_t x = get<0>(a) - get<0>(b);
    coord_t y = get<1>(a) - get<1>(b);
    coord_t z = get<2>(a) - get<2>(b);
    return x*x + y*y + z*z;
  }

  // read from input stream `in` to the point `point_out`
  std::istream& getpoint(std::istream& in, point_t& point_out) {    
    using std::tr1::get;
    return (in >> get<0>(point_out) >> get<1>(point_out) >> get<2>(point_out));
  }

  // Adaptable binary predicate that defines whether the first
  // argument is nearer than the second one to given reference point
  template<class T>
  class less_distance : public std::binary_function<T, T, bool> {
    const T& point;
  public:
    less_distance(const T& reference_point) : point(reference_point) {}

    bool operator () (const T& a, const T& b) const {
      return distance_sq(a, point) < distance_sq(b, point);
    } 
  };
}

int main() {
  using namespace std;

  // use point value from previous examples
  point_t point(69.06310224, 2.23409409, 50.41979143);
  cout << "point: " << point << endl;

  const size_t nneighbours = 10; // number of nearest neighbours to find
  point_t points[nneighbours+1];

  // populate `points`
  for (size_t i = 0; getpoint(cin, points[i]) && i < nneighbours; ++i)
    ;

  less_distance<point_t> less_distance_point(point);
  make_heap  (boost::begin(points), boost::end(points), less_distance_point);

  // Complexity: O(N*log(m))
  while(getpoint(cin, points[nneighbours])) {
    // add points[-1] to the heap; O(log(m))
    push_heap(boost::begin(points), boost::end(points), less_distance_point); 
    // remove (move to last position) the most distant from the
    // `point` point; O(log(m))
    pop_heap (boost::begin(points), boost::end(points), less_distance_point);
  }

  // print results
  push_heap  (boost::begin(points), boost::end(points), less_distance_point);
  //   O(m*log(m))
  sort_heap  (boost::begin(points), boost::end(points), less_distance_point);
  for (size_t i = 0; i < nneighbours; ++i) {
    cout << points[i] << ' ' << distance_sq(points[i], point) << '\n';  
  }
}

运行它:

$ g++ -O3 nearest.cc && (time ./a.out < million_3D_points.txt )

point: (69.0631 2.23409 50.4198)
(69 2 50) 0.235007
(69 2 51) 0.395424
(69 3 50) 0.766819
(69 3 50) 0.766819
(69 3 51) 0.927236
(70 2 50) 1.1088
(70 2 51) 1.26922
(70 2 51) 1.26922
(70 3 51) 1.80103
(69 1 51) 1.86361

real    0m1.174s
user    0m1.180s
sys 0m0.000s

基于线性搜索的方法(时间约为1.15秒)

// $ g++ -O3 nearest.cc && (time ./a.out < million_3D_points.txt )
#include <algorithm>           // sort
#include <functional>          // binary_function<>
#include <iostream>

#include <boost/foreach.hpp>
#include <boost/range.hpp>     // begin(), end()
#include <boost/tr1/tuple.hpp> // get<>, tuple<>, cout <<

#define foreach BOOST_FOREACH

namespace {
  typedef double coord_t;
  typedef std::tr1::tuple<coord_t,coord_t,coord_t> point_t;

  // calculate distance (squared) between points `a` & `b`
  coord_t distance_sq(const point_t& a, const point_t& b);

  // read from input stream `in` to the point `point_out`
  std::istream& getpoint(std::istream& in, point_t& point_out);    

  // Adaptable binary predicate that defines whether the first
  // argument is nearer than the second one to given reference point
  class less_distance : public std::binary_function<point_t, point_t, bool> {
    const point_t& point;
  public:
    explicit less_distance(const point_t& reference_point) 
        : point(reference_point) {}
    bool operator () (const point_t& a, const point_t& b) const {
      return distance_sq(a, point) < distance_sq(b, point);
    } 
  };
}

int main() {
  using namespace std;

  // use point value from previous examples
  point_t point(69.06310224, 2.23409409, 50.41979143);
  cout << "point: " << point << endl;
  less_distance nearer(point);

  const size_t nneighbours = 10; // number of nearest neighbours to find
  point_t points[nneighbours];

  // populate `points`
  foreach (point_t& p, points)
    if (! getpoint(cin, p))
      break;

  // Complexity: O(N*m)
  point_t current_point;
  while(cin) {
    getpoint(cin, current_point); //NOTE: `cin` fails after the last
                                  //point, so one can't lift it up to
                                  //the while condition

    // move to the last position the most distant from the
    // `point` point; O(m)
    foreach (point_t& p, points)
      if (nearer(current_point, p)) 
        // found point that is nearer to the `point` 

        //NOTE: could use insert (on sorted sequence) & break instead
        //of swap but in that case it might be better to use
        //heap-based algorithm altogether
        std::swap(current_point, p);
  }

  // print results;  O(m*log(m))
  sort(boost::begin(points), boost::end(points), nearer);
  foreach (point_t p, points)
    cout << p << ' ' << distance_sq(p, point) << '\n';  
}

namespace {
  coord_t distance_sq(const point_t& a, const point_t& b) { 
    // boost::geometry::distance() squared
    using std::tr1::get;
    coord_t x = get<0>(a) - get<0>(b);
    coord_t y = get<1>(a) - get<1>(b);
    coord_t z = get<2>(a) - get<2>(b);
    return x*x + y*y + z*z;
  }

  std::istream& getpoint(std::istream& in, point_t& point_out) {    
    using std::tr1::get;
    return (in >> get<0>(point_out) >> get<1>(point_out) >> get<2>(point_out));
  }
}

测量结果显示,大部分时间都花在从文件中读取数组上,实际计算所需的时间要少一个数量级。


6
很好的写作。为了抵消文件读取,我用循环来执行您的Python实现,每次执行100次搜索(每次查找不同的点并仅构建一次kd树)。然而暴力法仍然胜出。这让我感到困惑。但是之后我检查了您的叶子大小,并发现您在这里犯了一个错误——你将叶子大小设置为1000001,这样性能表现不佳。设置叶子大小为10后,kd获胜了(对于100个点,需要35秒至70秒,其中大部分35秒用于构建树和100个检索,每个检索10个点需要1秒)。 - Unreason
4
总的来说,如果你能预先计算kd树,那么它将比暴力搜索快上数倍(更不用说对于真正大型的数据集,如果你有一棵树,你就不必在内存中读取所有数据)。 - Unreason
1
@goran:如果我将叶子大小设置为10,则查询一个点需要大约10秒钟(而不是1秒钟)。如果任务是查询多个(> 10)点,则kd-tree应该胜出,我同意这一点。 - jfs
4
从scipy.spatial中导入cKDTree,它是用Cython编写的,比纯Python实现的KDTree(在我旧的Mac PPC上进行16维查找)快50倍以上。 - denis
希望您对每个解决方案都添加一些说明。对于像我这样不熟悉Python和C++的人来说,这个答案是没有用的! - Hengameh
显示剩余6条评论

20
如果一百万个条目已经在文件中,那么就不需要将它们全部加载到内存中的数据结构中。只需保持一个数组,其中包含迄今为止找到的前十个点,并扫描这一百万个点,在进行扫描时更新您的前十列表即可。
这是点数的 O(n) 时间复杂度。

1
这个方案可以正常运行,但是数组不是最有效的数据存储方式,因为你需要在每一步检查它,或者保持它排序,这可能会很麻烦。David关于min-heap的回答可以为你处理这些问题,但本质上是相同的解决方案。当用户只想要10个点时,这些问题可以忽略不计,但是当用户突然想要最近的100个点时,你就会遇到麻烦。 - Karl
@Karl,在编译器和CPU优化最愚蠢的循环来击败最聪明的算法时,往往让人惊讶。永远不要低估当循环可以在芯片上的RAM中运行时所能获得的加速效果。 - Ian Mercer
1
文件中还没有一百万条记录,你可以选择如何将它们存储在文件中。这种存储方式意味着你也可以预先计算任何伴随的索引结构。Kd-树能胜出是因为它完全不需要读取整个文件,时间复杂度小于O(n)。 - Unreason
1
我已经发布了你的答案的实现 https://dev59.com/cXE95IYBdhLWcg3wCJbk#2486341 (尽管我使用堆而不是线性搜索,这对于任务来说完全没有必要) - jfs
1.11和1.15之间的差别相当小。如果我将其编码为堆,那么以后如果我想做更多的事情,就可以了。当然,我不知道原帖作者想用它来做什么。 - FogleBird
显示剩余2条评论

14

您可以将这些点存储在一棵k维树(kd-tree)中。kd-树是针对最近邻搜索进行优化的(查找距离给定点最近的n个点)。


1
我认为这里需要使用八叉树。 - Gabe
11
构建K-d树所需的复杂度将比执行线性搜索查找最接近的10个点所需的复杂度高。K-d树的真正威力在于对点集执行多次查询时显现出来。 - Nixuz
@gabe:kd-tree 的一大优势是,虽然构建时会有一些内存开销,但一旦构建完成,一个左平衡的 kd-tree 占用的内存空间不会比原始的无结构点列表多。相比之下,八叉树肯定至少需要一些内存开销。 - Boojum
1
kd-tree在实际应用中可能比暴力搜索方法更慢。https://dev59.com/cXE95IYBdhLWcg3wCJbk#2486341 - jfs
2
这是我在面试中会给出的答案。面试官使用不太精确的语言并不罕见,仔细阅读问题背后的含义,这似乎是他们最想要的答案。实际上,如果我是面试官,有人给出了“我会将点按任意顺序存储,并进行线性扫描以找到10个点”的答案,并基于我的不精确措辞来证明这个答案,我会感到相当失望。 - Jason Orendorff
3
@ Jason Orendorff:在技术面试中,我肯定会讨论使用kd树来解决这样的问题;然而,我也会解释为什么对于给定的特定问题,更简单的线性搜索方法不仅在渐近意义下更快,而且实际运行速度也更快。这将展示出对算法复杂度的深刻理解、数据结构知识以及考虑问题的不同解决方案的能力。 - Nixuz

10

我认为这是一个棘手的问题,测试你是否不会过度做事。

考虑人们已经提供的最简单算法:保留十个最佳候选点的表格,并逐个查看所有点。如果找到比任何一个最佳点更接近的点,请替换它。复杂度是多少?嗯,我们必须查看来自文件的每个点一次,计算它的距离(实际上是距离的平方),并与第10个最接近的点进行比较。如果更好,将其插入到10个最佳点表中的适当位置。

那么复杂度是什么?我们查看每个点一次,所以需要进行n次距离计算和n次比较。如果该点更好,我们需要将其插入正确的位置,这需要一些额外的比较,但由于最佳候选点的表格具有恒定大小10,因此它是一个常数因子。

最终得到的算法在点数为n时以线性时间运行,O(n)。

但是现在考虑这样一个算法的下限是什么?如果输入数据没有顺序,我们必须查看每个点,以确定其是否是最接近的点之一。因此,据我所见,下限是Ω(n),因此上述算法是最优的。


1
非常好的观点!由于必须逐个读取文件以构建任何数据结构,因此您的最低可能是O(n),就像您所说的那样。只有在问题要求反复查找最近的10个点时,其他才有意义!而且我认为您解释得很好。 - Zan Lynx

6
无需计算距离,只需要计算距离的平方即可满足您的需求。我认为这样会更快。换句话说,您可以跳过 `sqrt` 部分。

4
这不是一个作业问题,是吗?;-)
我的想法是:遍历所有点,并将它们放入最小堆或有限优先队列中,按距离目标的距离进行键控。

1
当然可以,但不清楚目标是什么。 :) - Unreason

4
这个问题本质上测试你对空间划分算法的知识和/或直觉。我认为将数据存储在八叉树中是最好的选择。它通常用于处理这种问题的3D引擎(存储数百万个顶点,射线跟踪,查找碰撞等)。在最坏情况下,查找时间将达到log(n)的数量级(我相信)。

2

对于任意两点P1(x1,y1,z1)和P2(x2,y2,z2),如果两点之间的距离是d,则以下所有内容都必须为真:

|x1 - x2| <= d 
|y1 - y2| <= d
|z1 - z2| <= d

在遍历整个集合时,保留10个最接近的点,并且还要保留到第10个最近点的距离。在查看每个点之前,使用这三个条件可以节省大量复杂度,然后再计算距离。


2

直接的算法:

将点存储为元组列表,并扫描这些点,计算距离并保持“最接近”列表。

更具创意:

将点分组成区域(例如由“0,0,0”到“50,50,50”或“0,0,0”到“-20,-20,-20”描述的立方体),因此您可以从目标点进行“索引”。检查目标所在的立方体,仅搜索该立方体中的点。如果该立方体中的点少于10个,则检查“相邻”的立方体,以此类推。

经过进一步思考,这不是一个很好的算法:如果您的目标点比10个点更靠近立方体的墙壁,则必须搜索相邻的立方体。

我会采用kd-tree方法来找到最接近的节点,然后删除(或标记)该最接近的节点,并重新搜索新的最接近的节点。反复进行此操作。


1

基本上是前面两个答案的结合。由于点在文件中,因此不需要将它们保留在内存中。我不会使用数组或最小堆,而会使用最大堆,因为您只想检查距离小于第10个最近点的点。对于数组,您需要将每个新计算的距离与保留的所有10个距离进行比较。对于最小堆,您必须对每个新计算的距离执行3次比较。对于最大堆,当新计算的距离大于根节点时,您只需要执行1次比较。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接