Numpy: 从点列表中获取最大值的正确方法

6
我有一个三维坐标系(X,Y,Z)中点的列表。此外,每个点都被分配了一个浮点值v,因此单个点可以描述为(x,y,z,v)。该列表表示为形状为(N,4)的numpy数组。对于每个2d位置x,y,我需要获取v的最大值。一种直接但计算量昂贵的方法是:
for index in range(points.shape[0]):
    x = points[index, 0]
    y = points[index, 1]
    v = points[index, 3]

    maxes[x, y] = np.max(maxes[x, y], v)

有没有更多“numpy”方法,能够在性能方面带来一些收益?

为什么 np.max(points[:, 3]) 不足够? - Cedric H.
1
@CedricH。他们想要知道每个x, y对的最大值。基本上,按x, y分组并找到每个组的最大值。 - Giacomo Alzetta
@SVengat 是的,不过我想知道numpy是否能以某种方式将其向量化。 - thatsme
你能发一下你的列表样例吗?它们是元组吗?还是只是数组? - user3483203
@user3483203 它是一个numpy数组,形状为[N,4],如问题描述中所述。我写了“元组”,但实际上并不是指Python中的元组。如果不够清楚,很抱歉。 - thatsme
显示剩余2条评论
4个回答

4

设置

points = np.array([[ 0,  0,  1,  1],
                   [ 0,  0,  2,  2],
                   [ 1,  0,  3,  0],
                   [ 1,  0,  4,  1],
                   [ 0,  1,  5, 10]])

这里的一般想法是使用第一、二和四列进行排序,并反转结果,这样当我们找到唯一值时,具有第四列最大值的值将位于具有类似x和y坐标的其他值上方。然后我们使用np.unique在第一和第二列中查找唯一值,并返回这些结果,这些结果将具有最大的v

使用lexsortnumpy.unique

def max_xy(a):
    res = a[np.lexsort([a[:, 3], a[:, 1], a[:, 0]])[::-1]]
    vals, idx = np.unique(res[:, :2], 1, axis=0)
    maximums = res[idx]
    return maximums[:, [0,1,3]]

array([[ 0,  0,  2],
       [ 0,  1, 10],
       [ 1,  0,  1]])

避免使用unique以提高性能
def max_xy_v2(a):
    res = a[np.lexsort([a[:, 3], a[:, 1], a[:, 0]])[::-1]]
    res = res[np.append([True], np.any(np.diff(res[:, :2],axis=0),1))]
    return res[:, [0,1,3]]

max_xy_v2(points)

array([[ 1,  0,  1],
       [ 0,  1, 10],
       [ 0,  0,  2]])

请注意,虽然两者都会返回正确的结果,但它们不会按原始列表排序,如果您喜欢,可以在结尾处添加另一个lexsort来解决这个问题。


1
第二个解决方案很棒!非常感谢你!:) - thatsme

2

抱歉,这也不是纯“numpy”解决方案,但numpy_indexed包提供了一种非常方便(且快速)的方法来完成此操作。

import numpy_indexed as npi
npi.group_by(points[:, 0:2]).max(points[:,3])

与其他方法的比较

%timeit npi.group_by(points[:, 0:2]).max(points[:,3])
58 µs ± 435 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


%timeit pd.DataFrame(points, columns=['X', 'Y', 'Z', 'V']).groupby(by=['X', 'Y']).apply(lambda r: r['V'].max()).reset_index().values
3.15 ms ± 36.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

def max_xy_ver1(a):
    res = a[np.lexsort([a[:, 0], a[:, 1], a[:, 3]])[::-1]]
    vals, idx = np.unique(res[:, :2], 1, axis=0)
    maximums = res[idx]
    return maximums[:, [0,1,3]]

%timeit max_xy_ver1(points)
63.5 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

def max_xy_ver2(a):
    res = a[np.lexsort([a[:, 3], a[:, 1], a[:, 0]])[::-1]]
    res = res[np.append([True], np.any(np.diff(res[:, :2],axis=0),1))]
    return res[:, [0,1,3]]

%timeit_max_xy_ver2(points) # current winner
31.7 µs ± 524 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

def findmaxes_simple(points):
    maxes = {}
    for index in range(points.shape[0]):
        x = points[index, 0]
        y = points[index, 1]
        v = points[index, 3]
        maxes[(x, y)] = v if (x,y) not in maxes else max(maxes[(x, y)],v)
    return maxes

%timeit findmaxes_simple(points)
82.6 µs ± 632 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

通过Pip安装numpy_indexed

pip install --user numpy_indexed

(如果您使用的是Ubuntu或其他Linux发行版,您可能需要使用pip3来为Python 3安装软件包)

用于测试的数据

在此处可以找到Pastebin链接


能否在我的更新解决方案中添加时间?应该会快得多。 - user3483203
@user3483203 当然没问题,1秒钟就好 :) - Greg Kramida

1

这不是纯粹的 numpy,我利用了 pandas,我相信它会尽力进行向量化:

a = [
    [0, 0, 1, 1],
    [0, 0, 2, 2],
    [1, 0, 3, 0],
    [1, 0, 4, 1],
    [0, 1, 5, 10],
]
pd.DataFrame(a, columns=['X', 'Y', 'Z', 'V']).groupby(by=['X', 'Y']).apply(lambda r: r['V'].max()).reset_index().values

将其翻译成中文为:

返回这个:


array([[ 0,  0,  2],
       [ 0,  1, 10],
       [ 1,  0,  1]])

这是一个向量化版本的答案:pd.DataFrame(a, columns=['X', 'Y', 'Z', 'V']).groupby(by=['X', 'Y']).V.max().reset_index().values - user3483203

0

使用纯numpy:

import numpy as np

points = np.array([(1,2,3,4),
                   (2,3,5,6),
                   (1,2,9,8)])  #an example,

def find_vmax(x, y) :
    xpoints = points[np.where( points[:,0] == x)[0]]
    xypoints = xpoints[np.where( xpoints[:,1] == y)[0]]
    return np.max(xypoints[:, 3])

print(find_vmax(1, 2))

1
这仅提供单个(x,y)组合的结果。您必须使用np.unique获取所有结果,并通过标准for循环迭代所有结果以获得完整解决方案。 - Greg Kramida

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接