通过 np.argpartition 索引更快地对 3D NumPy 数组进行索引。

4

我有一个大的3D NumPy数组:

x = np.random.rand(1_000_000_000).reshape(500, 1000, 2000)

对于这500个二维数组,我需要在每个二维数组的每一列中只保留最大的800个元素。为避免高昂的排序成本,我决定使用np.argpartition

k = 800
idx = np.argpartition(x, -k, axis=1)[:, -k:]
result = x[np.arange(x.shape[0])[:, None, None], idx, np.arange(x.shape[2])]

np.argpartition虽然相当快,但使用idxx中检索回来非常缓慢。有更快(且内存效率更高)的方法执行此索引吗?

请注意,结果不需要按升序排序。它们只需要是前800个即可。


所以 result 是 (500,800,2000),是一份副本,而不是一个视图。 - hpaulj
是的,然后我需要计算result中每个二维数组的某些内容(例如沿axis=1的stddev),但生成result非常缓慢。有没有其他替代方案?我能避免复制吗? - slaw
我尝试了你的方法和排序方法 out1 = np.sort(x, axis=1)[:,-k:],但两个结果并不相同。 - Quang Hoang
@QuangHoang 因为我的结果没有排序(例如,在上面的评论中计算stddev时,我不需要结果被排序),所以结果可能看起来不同。 我只想要每个2D数组中每列的前800个最大值,但顺序并不重要。 - slaw
1个回答

0

为了适应我的内存而将大小减小10,以下是各个步骤的时间:

创建:

In [65]: timeit x = np.random.rand(1_000_000_00).reshape(500, 1000, 200)
1.89 s ± 82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [66]: x = np.random.rand(1_000_000_00).reshape(500, 1000, 200)
In [67]: k=800

排序:

In [68]: idx = np.argpartition(x, -k, axis=1)[:, -k:]
In [69]: timeit idx = np.argpartition(x, -k, axis=1)[:, -k:]

2.52 s ± 292 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 

索引:
In [70]: result = x[np.arange(x.shape[0])[:, None, None], idx, np.arange(x.shape[2])]
In [71]: timeit result = x[np.arange(x.shape[0])[:, None, None], idx, np.arange(x.shape[2])]
The slowest run took 4.11 times longer than the fastest. This could mean that an intermediate result is being cached.
2.6 s ± 1.87 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

这三个步骤需要大约相同的时间。我没有看到最后一个索引有什么异常。这是0.8 GB。

简单的复制,不进行索引,需要近1秒钟。

In [75]: timeit x.copy()
980 ms ± 231 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

和完整的高级索引副本:

In [77]: timeit x[np.arange(x.shape[0])[:, None, None], np.arange(x.shape[1])[:,
    ...: None], np.arange(x.shape[2])]
1.47 s ± 37.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

再次尝试使用idx

In [78]: timeit result = x[np.arange(x.shape[0])[:, None, None], idx, np.arange(x.shape[2])]
1.71 s ± 42.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

请记住,当操作开始几乎使用所有内存和/或开始需要交换和特殊内存请求到操作系统时,计时可能会真正变糟。
编辑
您不需要两步过程。 只需使用 `partition` :
out = np.partition(x,-k,axis = 1) [:,-k:]
这与 `result` 相同,并且需要与 `idx` 步骤相同的时间。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接