我希望在numpy中执行一些相对简单的操作:
- 如果行中有一个1,则返回包含该1的列的索引+1。
- 如果行中有零个或多个1,则返回0。
然而,最终我得到了一个相当复杂的代码:
predictions = np.array([[1,-1,-1,-1],[-1,1,1,-1],[-1,-1,-1,1],[-1,-1,-1,-1]])
one_count = (predictions == 1).sum(1)
valid_rows_idx = np.where(one_count==1)
result = np.zeros(predictions.shape[0])
for idx in valid_rows_idx:
result[idx] = np.where(predictions[idx,:]==1)[1] + 1
如果我打印
result
,程序会打印[ 1. 0. 4. 0.]
,这是期望的结果。我想知道是否有一种更简单的方法使用numpy编写最后一行。