我的numpy数组切片代码(通过花式索引)非常慢。它目前是程序的瓶颈。
a.shape
(3218, 6)
ts = time.time(); a[rows][:, cols]; te = time.time(); print('%.8f' % (te-ts));
0.00200009
如何使用正确的numpy函数获取矩阵a的行子集'rows'和列子集'col'组成的数组?(实际上,我需要这个结果的转置)
让我试着总结一下Jaime和TheodrosZelleke的优秀回答并加入一些评论。
高级(花式)索引总是返回副本,而不是视图。
a[rows][:,cols]
意味着两个花式索引操作,因此会创建并且丢弃中间的副本a[rows]
。虽然使用方便和可读性好,但效率不高。另外要注意的是,[:,cols]
通常从一个C-cont.源生成一个Fortran连续副本。
a[rows.reshape(-1,1),cols]
是单个高级索引表达式,基于rows.reshape(-1,1)
和cols
被broadcast到预期结果的形状上。
一个常见经验是在扁平化数组中进行索引可以比花式索引更有效率,因此另一种方法是
indx = rows.reshape(-1,1)*a.shape[1] + cols
a.take(indx)
或者
a.take(indx.flat).reshape(rows.size,cols.size)
效率将取决于内存访问模式以及起始数组是C连续还是Fortran连续,因此需要进行实验。
仅在真正需要时使用高级索引:基本切片 a[rstart:rstop:rstep, cstart:cstop:cstep]
返回一个视图(虽然不连续),应该更快!
令我惊讶的是,这种计算第一个线性1D索引的冗长表达式比问题中提出的连续数组索引快了50%以上:
(a.ravel()[(
cols + (rows * a.shape[1]).reshape((-1,1))
).ravel()]).reshape(rows.size, cols.size)
更新:原帖作者已更新初始数组的形状描述。根据更新后的尺寸,加速比现在已经超过99%:
In [93]: a = np.random.randn(3218, 1415)
In [94]: rows = np.random.randint(a.shape[0], size=2000)
In [95]: cols = np.random.randint(a.shape[1], size=6)
In [96]: timeit a[rows][:, cols]
10 loops, best of 3: 186 ms per loop
In [97]: timeit (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
1000 loops, best of 3: 1.56 ms per loop
初步答案: 以下是文字转录:
In [79]: a = np.random.randn(3218, 6)
In [80]: a.shape
Out[80]: (3218, 6)
In [81]: rows = np.random.randint(a.shape[0], size=2000)
In [82]: cols = np.array([1,3,4,5])
时间方法1:
In [83]: timeit a[rows][:, cols]
1000 loops, best of 3: 1.26 ms per loop
时间方法2:
In [84]: timeit (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
1000 loops, best of 3: 568 us per loop
检查结果是否真的相同:
In [85]: result1 = a[rows][:, cols]
In [86]: result2 = (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
In [87]: np.sum(result1 - result2)
Out[87]: 0.0
如果您使用高级索引和广播,可以加快切片速度:
from __future__ import division
import numpy as np
def slice_1(a, rs, cs) :
return a[rs][:, cs]
def slice_2(a, rs, cs) :
return a[rs[:, None], cs]
>>> rows, cols = 3218, 6
>>> rs = np.unique(np.random.randint(0, rows, size=(rows//2,)))
>>> cs = np.unique(np.random.randint(0, cols, size=(cols//2,)))
>>> a = np.random.rand(rows, cols)
>>> import timeit
>>> print timeit.timeit('slice_1(a, rs, cs)',
'from __main__ import slice_1, a, rs, cs',
number=1000)
0.24083110865
>>> print timeit.timeit('slice_2(a, rs, cs)',
'from __main__ import slice_2, a, rs, cs',
number=1000)
0.206566124519
如果你从百分比的角度来看,做某事快15%总是不错的,但对于你的数组大小,在我的系统中,这个切片操作只需要40微秒,很难相信一个花费240微秒的操作会成为瓶颈。
使用np.ix_
,您可以获得与ravel/reshape类似的速度,但代码更加清晰:
a = np.random.randn(3218, 1415)
rows = np.random.randint(a.shape[0], size=2000)
cols = np.random.randint(a.shape[1], size=6)
a = np.random.randn(3218, 1415)
rows = np.random.randint(a.shape[0], size=2000)
cols = np.random.randint(a.shape[1], size=6)
%timeit (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
#101 µs ± 2.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit ix_ = np.ix_(rows, cols); a[ix_]
#135 µs ± 7.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
ix_ = np.ix_(rows, cols)
result1 = a[ix_]
result2 = (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
np.sum(result1 - result2)
0.0
time.time
不是衡量时间的好方法。通常情况下,最好使用timeit
代替。 - mgilson@mgilson
样式的提及,它将向用户发送通知(每个评论一个)。 - agf