NumPy数组的条件操作

11

我是一个 NumPy 新手,我在运行一些对 NumPy 数组的条件语句时遇到了问题。假设我有三个类似于以下内容的 NumPy 数组:

a:

[[0, 4, 4, 2],
 [1, 3, 0, 2],
 [3, 2, 4, 4]]

b:

[[6, 9, 8, 6],
 [7, 7, 9, 6],
 [8, 6, 5, 7]]

和,c:

[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0]]

我有一个涉及变量a和b的条件语句,如果满足a和b的条件,我想用b的值来计算c:

c[(a > 3) & (b > 8)]+=b*2

我收到一个错误提示:
Traceback (most recent call last):
  File "<interactive input>", line 1, in <module>
ValueError: non-broadcastable output operand with shape (1,) doesn't match the broadcast shape (3,4)

你有什么想法可以实现这个吗?

我想让c的输出如下所示:

[[0, 18, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0]]

谢谢!感谢大家的努力。我选择了@Psidom的答案,因为np.where对我来说最有意义,并且运行时间最短(我的实际脚本将运行这些条件几百万次)。 - bobby12345
3个回答

23
你可以使用numpy.where
np.where((a > 3) & (b > 8), c + b*2, c)
#array([[ 0, 18,  0,  0],
#       [ 0,  0,  0,  0],
#       [ 0,  0,  0,  0]])

或者用算术方法:

c + b*2 * ((a > 3) & (b > 8))
#array([[ 0, 18,  0,  0],
#       [ 0,  0,  0,  0],
#       [ 0,  0,  0,  0]])

10
问题在于你掩盖了接收部分,但却没有掩盖发送部分。结果是:
c[(a > 3) & (b > 8)]+=b*2
# ^ 1x1 matrix        ^3x4 matrix

这些维度不同。根据你的示例,如果想要执行逐元素相加,只需在右侧部分添加切片即可:

c[(a > 3) & (b > 8)]+=b<b>[(a > 3) & (b > 8)]</b>*2

或者让它更高效:

<b>mask = (a > 3) & (b > 8)</b>
c[<b>mask</b>] += b[<b>mask</b>]*2

3
稍微改变numpy的表达式即可得到所需的结果:
c += ((a > 3) & (b > 8)) * b*2

首先,我使用布尔值创建一个掩码矩阵,其条件为((a > 3) & (b > 8)),然后将该矩阵与b*2相乘,从而生成一个3x4矩阵,该矩阵可以轻松地添加到c中。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接