基于索引对NumPy数组进行掩码

7
如何根据实际索引值屏蔽数组?
也就是说,如果我有一个10x10x30的矩阵,并且当第一和第二个索引相等时,我希望掩盖数组。
例如,[1, 1 , :] 将被掩盖,因为1和1相等,但[1, 2, :]不会被掩盖,因为它们不相等。
我只在第三维中询问这个问题,因为它类似于我的当前问题并可能使事情复杂化。但我的主要问题是,如何根据索引的值来屏蔽数组?

6
离题了,你的头像很好地匹配了你的问题主题 :P - askewchan
3
是的,我只是试图在我的代码上画Z字。谢谢你帮我做到了这一点。 - aleph4
2个回答

7

通常,要访问索引的值,可以使用 np.meshgrid:

i, j, k = np.meshgrid(*map(np.arange, m.shape), indexing='ij')
m.mask = (i == j)

这种方法的优点在于它适用于ijk上的任意布尔函数。相较于使用“identity”特殊情况,速度稍慢一些。
In [56]: %%timeit
   ....: i, j, k = np.meshgrid(*map(np.arange, m.shape), indexing='ij')
   ....: i == j
10000 loops, best of 3: 96.8 µs per loop

正如@Jaime所指出的,meshgrid支持sparse选项,它并不会做太多重复的工作,但在某些情况下需要更加小心,因为它们无法广播。这将节省内存并略微提高速度。例如,

In [77]: x = np.arange(5)

In [78]: np.meshgrid(x, x)
Out[78]: 
[array([[0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4]]),
 array([[0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2],
       [3, 3, 3, 3, 3],
       [4, 4, 4, 4, 4]])]

In [79]: np.meshgrid(x, x, sparse=True)
Out[79]: 
[array([[0, 1, 2, 3, 4]]),
 array([[0],
       [1],
       [2],
       [3],
       [4]])]

所以,你可以像他说的那样使用“sparse”版本,但必须强制进行广播,如下所示:
i, j, k = np.meshgrid(*map(np.arange, m.shape), indexing='ij', sparse=True)
m.mask = np.repeat(i==j, k.size, axis=2)

同时加速:

In [84]: %%timeit
   ....: i, j, k = np.meshgrid(*map(np.arange, m.shape), indexing='ij', sparse=True)
   ....: np.repeat(i==j, k.size, axis=2)
10000 loops, best of 3: 73.9 µs per loop

1
为什么不使用np.arange而不是range,这样可以节省时间。 - Daniel
3
使用 map 很好,但是您创建了三个完整的 (10, 10, 30) 掩码...如果在调用np.meshgrid时设置 sparse=True,则可以大大减少内存占用,但需要在掩码的创建中包括所有维度,例如 (i == j) & (k >= 0) 的运行速度是您解决方案的两倍。 - Jaime
谢谢@Jaime,我一直在尝试用更简洁的版本来完成它,但是无法做到! - askewchan
1
很遗憾布尔遮罩索引不支持广播。可能有一些好的理由,但它将使某些操作变得非常容易。 - Jaime
1
sparse=True 在这种情况下似乎会带来很多麻烦,至少需要8微秒的时间 :)。 - Bi Rico
@BiRico 是的,只是为了完整性而已。老实说,如果时间紧迫,使用 np.identity 的答案也是最好的。 - askewchan

0
在您想要掩盖对角线的特殊情况下,您可以使用np.identity()函数,该函数返回沿对角线的值为1的矩阵。由于您有第三个维度,因此我们必须将该第三个维度添加到单位矩阵中:
m.mask = np.identity(10)[...,None]*np.ones((1,1,30))

构造该数组可能有更好的方法,但基本上是将30个np.identity(10)数组堆叠起来。例如,以下代码是等效的:

np.dstack((np.identity(10),)*30)

但速度较慢:

In [30]: timeit np.identity(10)[...,None]*np.ones((1,1,30))
10000 loops, best of 3: 40.7 µs per loop

In [31]: timeit np.dstack((np.identity(10),)*30)
1000 loops, best of 3: 219 µs per loop

还有@Ophion的建议

In [33]: timeit np.tile(np.identity(10)[...,None], 30)
10000 loops, best of 3: 63.2 µs per loop

In [71]: timeit np.repeat(np.identity(10)[...,None], 30)
10000 loops, best of 3: 45.3 µs per loop

np.tile(np.identity(10)[...,None], 30) - Daniel
1
很有趣。tile 基本上是 repeat 的一个包装器。因此,如果您运行 np.repeat(np.identity(10)[..., None], 30, axis=-1),则可以跳过一些额外的 if 语句,但对于任何在 1 维中的数组大小,np.identity(10)[..., None]*np.ones((1,1,30)) 是普遍更快的。知道这点很好。 - Daniel
@Ophion 酷,我从未考虑过在长度为1的维度上,tilerepeat 是等效的。 - askewchan
1
你需要指定一个轴,因为np.repeat与大多数函数的默认值不同,其默认轴是axis=-1,而不是axis=None。如果你没有显式传递轴关键字,你将得到一个一维数组,并且必须手动重塑它。 - Daniel

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接