如何在批处理维度上广播numpy索引?

5
例如,np.array([[1,2],[3,4]])[np.triu_indices(2)] 的形状为(3,),是上三角元素展平后的列表。但是,如果我有一个2x2矩阵的批量:
foo = np.repeat(np.array([[[1,2],[3,4]]]), 30, axis=0)

如果我想获取每个矩阵的上三角索引,尝试的朴素方法是:

foo[:,np.triu_indices(2)]

然而,这个对象实际上是形状为(30,2,3,2)的(而不是如果我们按批次提取上三角条目时可能期望的(30,3))。
我们如何在批次维度上广播元组索引?

“of each matrix” 究竟是什么意思?你只有一个甚至不是方形的矩阵!如果你想要每个2x2子数组(轴1和2)的上三角形,你可以这样做:x, y = np.triu_indices(2); foo[:,x,y] - Mazdak
1个回答

4
获取元组,并使用它们来索引到最后两个维度。
r,c = np.triu_indices(2)
out = foo[:,r,c]

或者,一个适用于3D2D数组的一行代码解决方案是使用省略号

foo[(Ellipsis,)+np.triu_indices(2)]

它同样适用于2D数组 -

out = foo[r,c] # foo as 2D input array

遮罩方法

3D 数组情况

我们也可以使用一个遮罩进行基于遮罩的方式 -

foo[:,~np.tri(2,k=-1, dtype=bool)]

二维数组案例

foo[~np.tri(2,k=-1, dtype=bool)]

似乎我们不能直接同时解包和索引它们。否则,对于维度大于2的情况,这将只是一个单一的步骤。 - kmario23
1
@kmario23 这里有一个带有 Ellipsis 的。 - Divakar
2
Jaaa,谢谢!为了让它更易懂,我们也可以写成:foo[(..., *np.triu_indices(2))] - kmario23
1
没错,那也可以。请注意这是针对Python 3.x的。谢谢。 - Divakar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接