如何在NumPy中找到给定行的值的列索引?

3

我希望在numpy中执行一些相对简单的操作:

  • 如果行中有一个1,则返回包含该1的列的索引+1。
  • 如果行中有零个或多个1,则返回0。

然而,最终我得到了一个相当复杂的代码:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

one_count = (predictions == 1).sum(1)
valid_rows_idx = np.where(one_count==1)

result = np.zeros(predictions.shape[0])
for idx in valid_rows_idx:
    result[idx] = np.where(predictions[idx,:]==1)[1] + 1

如果我打印result,程序会打印[ 1. 0. 4. 0.],这是期望的结果。
我想知道是否有一种更简单的方法使用numpy编写最后一行。
1个回答

2

我不确定是否更好,但你可以尝试使用argmax。此外,你不需要使用for循环和np.where来获取有效的索引:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

idx = (predictions == 1).sum(1) == 1
result = np.zeros(predictions.shape[0])
result[idx] = (predictions[idx]==1).argmax(axis=1) + 1

In [55]: result
Out[55]: array([ 1.,  0.,  4.,  0.])

或者你可以使用 np.whereargmax 一行代码完成所有操作:

predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])

In [72]: np.where((predictions==1).sum(1)==1, (predictions==1).argmax(axis=1)+1, 0)
Out[72]: array([1, 0, 4, 0])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接