一百万个点并不算多,最直接的方法在这里就足够了(如果只查询一个点,基于KDTree的代码反而更慢)。
暴力算法(时间约为1秒)
# Brute-force nearest-neighbour search: compute the squared distance from a
# query point to every stored point, sort, and show the 10 closest.
import pprint

import numpy

NDIM = 3  # coordinates per point

# The input is whitespace-separated text, read as one flat array of values.
a = numpy.fromfile('million_3D_points.txt', sep=' ')
# Floor division keeps the row count an integer even under true division
# (Python 3, or `from __future__ import division`), where `/` yields a float
# that cannot be used as a shape.
a.shape = a.size // NDIM, NDIM

# Random query point with each coordinate drawn uniformly from [0, 100).
point = numpy.random.uniform(0, 100, NDIM)
# Single-argument print(...) produces the same text on Python 2 and 3.
print('point: %s' % (point,))

# Squared Euclidean distance from `point` to every row of `a`.
d = ((a - point) ** 2).sum(axis=1)
# Indices that order the points from nearest to farthest.
ndx = d.argsort()

# list(...) because zip() is a lazy iterator on Python 3; on Python 2 it is
# already a list and the copy is harmless.
pprint.pprint(list(zip(a[ndx[:10]], d[ndx[:10]])))
运行它:
$ time python nearest.py
point: [ 69.06310224 2.23409409 50.41979143]
[(array([ 69., 2., 50.]), 0.23500677815852947),
(array([ 69., 2., 51.]), 0.39542392750839772),
(array([ 69., 3., 50.]), 0.76681859086988302),
(array([ 69., 3., 50.]), 0.76681859086988302),
(array([ 69., 3., 51.]), 0.9272357402197513),
(array([ 70., 2., 50.]), 1.1088022980015722),
(array([ 70., 2., 51.]), 1.2692194473514404),
(array([ 70., 2., 51.]), 1.2692194473514404),
(array([ 70., 3., 51.]), 1.801031260062794),
(array([ 69., 1., 51.]), 1.8636121147970444)]
real 0m1.122s
user 0m1.010s
sys 0m0.120s
以下是生成100万个3D点的脚本:
# Print one million random 3D points, one per line, as three space-separated
# integers in 0..99. Presumably stdout is redirected into
# million_3D_points.txt -- confirm against how the data file was produced.
import random

NPOINTS = 10**6  # number of lines to emit
NDIM = 3         # coordinates per line

# range() and single-argument print(...) work on both Python 2 and 3
# (on Python 2, range() builds a temporary list of NPOINTS ints).
for _ in range(NPOINTS):
    print(' '.join(str(random.randrange(100)) for _ in range(NDIM)))
输出:
$ head million_3D_points.txt
18 56 26
19 35 74
47 43 71
82 63 28
43 82 0
34 40 16
75 85 69
88 58 3
0 63 90
81 78 98
您可以使用该代码测试更复杂的数据结构和算法(例如,它们是否比上述最简单的方法实际消耗更少的内存或更快)。值得注意的是,目前它是唯一包含可工作代码的答案。
基于KDTree的解决方案 (时间约为1.4秒)
# Nearest-neighbour search through scipy's KDTree: load the points, build
# the tree, and print the 10 points closest to a fixed query point.
import numpy
from scipy.spatial import KDTree

NDIM = 3  # coordinates per point

# The input is whitespace-separated text, read as one flat array of values.
a = numpy.fromfile('million_3D_points.txt', sep=' ')
# Floor division keeps the row count an integer even under true division
# (Python 3), where `/` yields a float that cannot be used as a shape.
a.shape = a.size // NDIM, NDIM

point = [69.06310224, 2.23409409, 50.41979143]
# Single-argument print(...) produces the same text on Python 2 and 3.
print('point: %s' % (point,))

# leafsize is deliberately larger than the number of points.
# NOTE(review): per the SciPy docs leafsize is the point count at which the
# tree switches to brute force, so this presumably keeps everything in a
# single leaf -- confirm that this is the intended trade-off.
tree = KDTree(a, leafsize=a.shape[0] + 1)

# One query point, ten neighbours; ndx holds the row indices into `a`.
distances, ndx = tree.query([point], k=10)
print(a[ndx])
运行它:
$ time python nearest_kdtree.py
point: [69.063102240000006, 2.2340940900000001, 50.419791429999997]
[[[ 69. 2. 50.]
[ 69. 2. 51.]
[ 69. 3. 50.]
[ 69. 3. 50.]
[ 69. 3. 51.]
[ 70. 2. 50.]
[ 70. 2. 51.]
[ 70. 2. 51.]
[ 70. 3. 51.]
[ 69. 1. 51.]]]
real 0m1.359s
user 0m1.280s
sys 0m0.080s
C++中的部分排序(时间约为1.1秒)
#include <algorithm>
#include <iostream>
#include <vector>
#include <boost/lambda/lambda.hpp>
#include <boost/lambda/bind.hpp>
#include <boost/tuple/tuple_io.hpp>
namespace {

// Scalar type used for a single coordinate.
typedef double coord_t;

// A 3D point stored as a tuple of three coordinates.
typedef boost::tuple<coord_t,coord_t,coord_t> point_t;

// Returns the squared Euclidean distance between `a` and `b`
// (no square root is taken).
coord_t distance_sq(const point_t& a, const point_t& b) {
  const coord_t dx = a.get<0>() - b.get<0>();
  const coord_t dy = a.get<1>() - b.get<1>();
  const coord_t dz = a.get<2>() - b.get<2>();
  return dx*dx + dy*dy + dz*dz;
}

}  // namespace
int main() {
using namespace std;
using namespace boost::lambda;
vector<point_t> points;
cin.exceptions(ios::badbit);
while(cin) {
coord_t x,y,z;
cin >> x >> y >> z;
points.push_back(boost::make_tuple(x,y,z));
}
point_t point(69.06310224, 2.23409409, 50.41979143);
cout << "point: " << point << endl;
const size_t m = 10;
partial_sort(points.begin(), points.begin() + m, points.end(),
bind(less<coord_t>(),
bind(distance_sq, _1, point),
bind(distance_sq, _2, point)));
for_each(points.begin(), points.begin() + m, cout << _1 << "\n");
}
运行它:
g++ -O3 nearest.cc && (time ./a.out < million_3D_points.txt )
point: (69.0631 2.23409 50.4198)
(69 2 50)
(69 2 51)
(69 3 50)
(69 3 50)
(69 3 51)
(70 2 50)
(70 2 51)
(70 2 51)
(70 3 51)
(69 1 51)
real 0m1.152s
user 0m1.140s
sys 0m0.010s
C++中的优先队列(时间约为1.2秒)
#include <algorithm>
#include <functional>
#include <iostream>
#include <boost/range.hpp>
#include <boost/tr1/tuple.hpp>
namespace {

// Scalar type used for a single coordinate.
typedef double coord_t;

// A 3D point stored as a tuple of three coordinates.
typedef std::tr1::tuple<coord_t,coord_t,coord_t> point_t;

// Returns the squared Euclidean distance between `a` and `b`
// (no square root is taken).
coord_t distance_sq(const point_t& a, const point_t& b) {
  using std::tr1::get;
  const coord_t dx = get<0>(a) - get<0>(b);
  const coord_t dy = get<1>(a) - get<1>(b);
  const coord_t dz = get<2>(a) - get<2>(b);
  return dx*dx + dy*dy + dz*dz;
}

// Extracts three coordinates from `in` into `point_out`, in tuple order.
// Returns `in`, so the result can be tested to see whether the read worked.
std::istream& getpoint(std::istream& in, point_t& point_out) {
  using std::tr1::get;
  in >> get<0>(point_out);
  in >> get<1>(point_out);
  in >> get<2>(point_out);
  return in;
}

// Strict-weak-ordering functor: a < b when `a` is closer than `b` to the
// reference point given at construction. Only a reference to that point is
// stored, so the point must outlive the functor.
template<class T>
class less_distance : public std::binary_function<T, T, bool> {
 public:
  less_distance(const T& reference_point) : point(reference_point) {}

  bool operator () (const T& lhs, const T& rhs) const {
    return distance_sq(lhs, point) < distance_sq(rhs, point);
  }

 private:
  const T& point;  // reference point; not owned
};

}  // namespace
int main() {
using namespace std;
point_t point(69.06310224, 2.23409409, 50.41979143);
cout << "point: " << point << endl;
const size_t nneighbours = 10;
point_t points[nneighbours+1];
for (size_t i = 0; getpoint(cin, points[i]) && i < nneighbours; ++i)
;
less_distance<point_t> less_distance_point(point);
make_heap (boost::begin(points), boost::end(points), less_distance_point);
while(getpoint(cin, points[nneighbours])) {
push_heap(boost::begin(points), boost::end(points), less_distance_point);
pop_heap (boost::begin(points), boost::end(points), less_distance_point);
}
push_heap (boost::begin(points), boost::end(points), less_distance_point);
sort_heap (boost::begin(points), boost::end(points), less_distance_point);
for (size_t i = 0; i < nneighbours; ++i) {
cout << points[i] << ' ' << distance_sq(points[i], point) << '\n';
}
}
运行它:
$ g++ -O3 nearest.cc && (time ./a.out < million_3D_points.txt )
point: (69.0631 2.23409 50.4198)
(69 2 50) 0.235007
(69 2 51) 0.395424
(69 3 50) 0.766819
(69 3 50) 0.766819
(69 3 51) 0.927236
(70 2 50) 1.1088
(70 2 51) 1.26922
(70 2 51) 1.26922
(70 3 51) 1.80103
(69 1 51) 1.86361
real 0m1.174s
user 0m1.180s
sys 0m0.000s
基于线性搜索的方法(时间约为1.15秒)
#include <algorithm>
#include <functional>
#include <iostream>
#include <boost/foreach.hpp>
#include <boost/range.hpp>
#include <boost/tr1/tuple.hpp>
#define foreach BOOST_FOREACH
namespace {

// Scalar type used for a single coordinate.
typedef double coord_t;

// A 3D point stored as a tuple of three coordinates.
typedef std::tr1::tuple<coord_t,coord_t,coord_t> point_t;

// Forward declarations; both are defined in the anonymous namespace that
// follows main().
coord_t distance_sq(const point_t& a, const point_t& b);
std::istream& getpoint(std::istream& in, point_t& point_out);

// Strict-weak-ordering functor: a < b when `a` is closer than `b` to the
// reference point given at construction. Only a reference to that point is
// stored, so the point must outlive the functor.
class less_distance : public std::binary_function<point_t, point_t, bool> {
 public:
  explicit less_distance(const point_t& reference_point)
    : point(reference_point) {}

  bool operator () (const point_t& lhs, const point_t& rhs) const {
    return distance_sq(lhs, point) < distance_sq(rhs, point);
  }

 private:
  const point_t& point;  // reference point; not owned
};

}  // namespace
int main() {
using namespace std;
point_t point(69.06310224, 2.23409409, 50.41979143);
cout << "point: " << point << endl;
less_distance nearer(point);
const size_t nneighbours = 10;
point_t points[nneighbours];
foreach (point_t& p, points)
if (! getpoint(cin, p))
break;
point_t current_point;
while(cin) {
getpoint(cin, current_point);
foreach (point_t& p, points)
if (nearer(current_point, p))
std::swap(current_point, p);
}
sort(boost::begin(points), boost::end(points), nearer);
foreach (point_t p, points)
cout << p << ' ' << distance_sq(p, point) << '\n';
}
namespace {

// Returns the squared Euclidean distance between `a` and `b`
// (no square root is taken).
coord_t distance_sq(const point_t& a, const point_t& b) {
  using std::tr1::get;
  const coord_t dx = get<0>(a) - get<0>(b);
  const coord_t dy = get<1>(a) - get<1>(b);
  const coord_t dz = get<2>(a) - get<2>(b);
  return dx*dx + dy*dy + dz*dz;
}

// Extracts three coordinates from `in` into `point_out`, in tuple order.
// Returns `in`, so the result can be tested to see whether the read worked.
std::istream& getpoint(std::istream& in, point_t& point_out) {
  using std::tr1::get;
  in >> get<0>(point_out);
  in >> get<1>(point_out);
  in >> get<2>(point_out);
  return in;
}

}  // namespace
测量结果显示,大部分时间都花在从文件中读取数组上,实际计算所需的时间要少一个数量级。