NumPy ndarray的动态轴索引

12

我想获取一个三维数组中给定方向的二维切片,其中direction(或要提取切片的轴)由另一个变量给出。

假设idx是三维数组中二维切片的索引,direction是获取该二维切片的轴,则最初的方法是:

if direction == 0:
    return A[idx, :, :]
elif direction == 1:
    return A[:, idx, :]
else:
    return A[:, :, idx]

我相信一定有一种方法可以在不使用条件语句的情况下完成这项任务,或者至少不是在原始的Python中。NumPy是否有一个快捷方式来完成这个呢?

到目前为止我找到的更好的解决方案(用于动态地执行它),依赖于转置运算符:

# for 3 dimensions [0,1,2] and direction == 1 --> [1, 0, 2]
tr = [direction] + range(A.ndim)
del tr[direction+1]

return np.transpose(A, tr)[idx]

不过我在想是否有更好/更容易/更快的函数可以实现这个功能,因为对于三维来说,转置代码看起来几乎比三个if / elif 还要丑陋。它针对ND进行了更好的概括,N越大,与之相比代码越美观,但对于三维来说基本相同。

3个回答

10

转置是一种非常快捷的操作。有些numpy函数可以使用它将操作轴(或轴)移动到已知位置,通常是形状列表的前面或后面。tensordot就是其中之一。

其他函数会构造一个索引元组。它们可能从一个列表或数组开始,以便进行更容易的操作,然后将其转换为元组以便应用。例如:

I = [slice(None)]*A.ndim
I[axis] = idx
A[tuple(I)]

np.apply_along_axis可以做到类似的功能。查看这些函数的代码是很有启发性的。

我想象一下,numpy函数的编写者们最担心的是它是否能够稳定地工作,其次是速度,最后才是外观是否漂亮。你可以在一个函数中隐藏各种丑陋的代码!


tensordot以此结束

at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)
res = dot(at, bt)
return res.reshape(olda + oldb)

之前的代码计算了newaxes_..newshape...

apply_along_axis构造了一个(0...,:,0...)的索引元组。

i = zeros(nd, 'O')
i[axis] = slice(None, None)
i.put(indlist, ind)
....arr[tuple(i.tolist())]

非常有趣的文章,谢谢!如果我没弄错的话,transpose 只是返回一个更改了 strides 的数组,对吧?除非你调用 .copy(),它不会重新排列内存中的数据。根据几个 %timeit 测试,在 A = np.random.randn(1000,1000)B = A.T 的情况下,访问 A[:, idx] 中的随机列比访问相同数据的 B[idx] 慢。然而,如果我们完全索引 B[idx, :],访问 B 中的数据会变得更加昂贵。这是一种大约为 50 ns(x1.25 速度提升)的情况,所以并不重要,但我很好奇。你知道在 numpy 中这是如何工作的吗? - Imanol Luengo
1
这些都是视图。它不会迭代数据或复制数据。因此,时间差异与解析表达式和创建新的数组对象(相同的数据指针)有关。使用 B[100]A.T[100,:] 时,flags__array_interface__ 看起来是相同的。 - hpaulj

-1

要动态索引一个维度,您可以使用swapaxes,如下所示:

a = np.arange(7 * 8 * 9).reshape((7, 8, 9))

axis = 1
idx = 2

np.swapaxes(a, 0, axis)[idx]

运行时比较

自然方法(非动态):

%timeit a[:, idx, :]
300 ns ± 1.58 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

swapaxes:

%timeit np.swapaxes(a, 0, axis)[idx]
752 ns ± 4.54 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

使用列表推导式的索引:

%timeit a[[idx if i==axis else slice(None) for i in range(a.ndim)]]

swapaxes 不会返回空间上一致的结果。例如,对于一个 [100, 200, 300] 的数组,如果我要索引最后一个轴(例如 data[:, :, 5]),我期望得到一个 [100, 200] 的数组。然而,如果我们执行 data.swapaxes(0, 2)[5],我们得到的是一个 [200, 100] 的数组,因为轴 0 现在在末尾。它对于 [0, 1] 轴是正常工作的,但仅限于这些轴。 - Imanol Luengo
没错,我忽略了那个。选择的是相同的数据,但顺序可能会改变。 - ndou

-1

这是Python代码。你可以像这样简单地使用eval()

def get_by_axis(a, idx, axis):
    indexing_list = a.ndim*[':']
    indexing_list[axis] = str(idx) 
    expression = f"a[{', '.join(indexing_list)}]"
    return eval(expression)

显然,在这种情况下您不接受来自不可信任的用户的输入。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接