更快的numpy花式索引和缩减?

11
我正在尝试使用和加速花式索引来“连接”两个数组,并对结果的一个轴求和。
类似于这样:
$ ipython
In [1]: import numpy as np
In [2]: ne, ds = 12, 6
In [3]: i = np.random.randn(ne, ds).astype('float32')
In [4]: t = np.random.randint(0, ds, size=(1e5, ne)).astype('uint8')

In [5]: %timeit i[np.arange(ne), t].sum(-1)
10 loops, best of 3: 44 ms per loop

有没有简单的方法来加速In [5]里的语句?我应该使用OpenMP和类似scipy.weaveCythonprange吗?


另一个相关的问题是如何使用 pandas 来完成同样的事情? - npinto
Numpy以C速度执行,因此您可能无法使用weave大幅加速它。 - reptilicus
1个回答

8

numpy.take比花式索引快得多,原因是它将数组视为平坦的。

In [1]: a = np.random.randn(12,6).astype(np.float32)

In [2]: c = np.random.randint(0,6,size=(1e5,12)).astype(np.uint8)

In [3]: r = np.arange(12)

In [4]: %timeit a[r,c].sum(-1)
10 loops, best of 3: 46.7 ms per loop

In [5]: rr, cc = np.broadcast_arrays(r,c)

In [6]: flat_index = rr*a.shape[1] + cc

In [7]: %timeit a.take(flat_index).sum(-1)
100 loops, best of 3: 5.5 ms per loop

In [8]: (a.take(flat_index).sum(-1) == a[r,c].sum(-1)).all()
Out[8]: True

我认为除此之外,要想获得更大的速度提升,唯一的途径就是使用类似于PyCUDA这样的工具编写GPU自定义内核。


1
默认情况下,它只将数组视为平面的,但您仍然可以使用“axis”关键字。例如,np.take(np.arange(10).reshape((-1,2)), [0], axis=0) 将选择第一行。 - jorgeca
@jorgeca:没错,但是我认为除非你索引平坦数组,否则你不能通过同时指定行和列来提取单个元素,就像使用高级索引一样。 - user545424

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接