高效地对数组进行双重迭代

3
我有以下代码,其中points是一个由3列列表组成的多行列表,coorRadius是我想要找到局部坐标最大值的半径,localCoordinateMaxima是一个数组,我在其中存储这些最大值的i值:
for i,x in enumerate(points):
        check = 1
        for j,y in enumerate(points):
            if linalg.norm(x-y) <= coorRadius and x[2] < y[2]:
                check = 0

        if check == 1:
            localCoordinateMaxima.append(i)

    print localCoordinateMaxima

很不幸,当我有几千个点时,这需要很长时间,我正在寻找加快速度的方法。我尝试使用if all()条件来完成它,但是我没有成功,我甚至不确定它是否更有效。你们能否提出一种使它更快的方法呢?

最好的!


points 是一个 numpy 数组吗?这些数据点是否在网格上? - Roland Smith
points是一个列表的列表,其中坐标以浮点格式表示,因此它们不完全在网格上。 - John
我因XCOM太累了,无法写出一个恰当的答案,但是http://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.spatial.cKDTree.html。 - user2357112
链接不幸地无法加载。 - John
是的,它似乎现在已经崩溃了。当SciPy的网站无法使用时,您可以在Wikipedia上阅读有关k-d树的内容。 - user2357112
显示剩余2条评论
3个回答

2

这里是您的代码经过一些精简后的版本:

for i, x in enumerate(points):
    x2 = x[2]
    for y in points:
        if linalg.norm(x-y) <= coorRadius and x2 < y[2]:
            break
    else:
        localCoordinateMaxima.append(i)
    print localCoordinateMaxima

变更:

  • x[2]查找拆分出来。
  • j 变量未使用。
  • 添加break以提前退出。
  • 使用for-else结构而非标志变量。

谢谢,我是Python的新手,x2 = x[2]会让它更快吗,还是只是为了整洁? - John
它可以加快速度。x[2]查找需要时间,因此最好将循环不变代码提取出来。 - Raymond Hettinger

2

使用KDTree进行邻居搜索是最佳选择。

from scipy.spatial import cKDTree

tree = cKDTree(points)
pairs = tree.query_pairs(coorRadius)

现在pairs是一个包含两个元素的元组集合(i, j),其中i < jpoints[i]points[j]相互之间的距离小于coorRadius。您现在可以简单地对这些元组进行迭代,这将比您目前正在迭代的len(points)**2要少得多。

is_maximum = [True] * len(points)
for i, j in pairs:
    if points[i][2] < points[j][2]:
        is_maximum[i] = False
    elif points[j][2] < points[i][2]:
        is_maximum[j] = False
localCoordinateMaxima, = np.nonzero(is_maximum)

这可以通过向量化进一步加速:
pairs = np.array(list(pairs))
pairs = np.vstack((pairs, pairs[:, ::-1]))
pairs = pairs[np.argsort(pairs[:, 0])]
is_z_smaller = points[pairs[:, 0], 2] < points[pairs[:, 1], 2]
bins, = np.nonzero(pairs[:-1, 0] != pairs[1:, 0])
bins = np.concatenate(([0], bins+1))
is_maximum = np.logical_and.reduceat(is_z_smaller, bins)
localCoordinateMaxima, = np.nonzero(is_maximum)

上面的代码假设每个点至少有一个相邻点在coorRadius范围内。如果不是这种情况,你需要稍微复杂化一下:
pairs = np.array(list(pairs))
pairs = np.vstack((pairs, pairs[:, ::-1]))
pairs = pairs[np.argsort(pairs[:, 0])]
is_z_smaller = points[pairs[:, 0], 2] < points[pairs[:, 1], 2]
bins, = np.nonzero(pairs[:-1, 0] != pairs[1:, 0])
has_neighbors = pairs[np.concatenate(([True], bins)), 0]
bins = np.concatenate(([0], bins+1))
is_maximum = np.ones((len(points),), bool)
is_maximum[has_neighbors] &= np.logical_and.reduceat(is_z_smaller, bins)
localCoordinateMaxima, = np.nonzero(is_maximum)

1
使用numpy并不太难。如果您愿意,您可以使用一个(长)表达式来完成这项任务:
import numpy as np

points = np.array(points)
localCoordinateMaxima = np.where(np.all((np.linalg.norm(points-points[None,:], axis=-1) >
                                         coorRadius) |
                                        (points[:,2] >= points[:,None,2]),
                                        axis=-1))

你当前代码实现的算法本质上是where(not(any(w <= x and y < z)))。如果你通过德摩根定律将not分配到其中的逻辑操作中,你可以翻转不等式并避免一层嵌套,得到where(all(w > x or y >= z)))w是应用于点之间差异的规范矩阵。 x是一个常数。 yz都是数组,具有点的第三个坐标,形状使它们可以广播到与w相同的形状。

虽然不太易读,但这是两个答案中更为高效的一个。 - Dunes
我认为“高效”不是正确的形容词。这比 Raymond 的回答中提前跳出循环的方法效率要低,但对于较小的数据集来说可能更快。 - Jaime
谢谢,看起来很聪明,但是这并不能得出与我的代码相同的答案。 - John
我的数据集通常有几千个数据点,不确定您所说的“更小”是什么意思。 - John

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接