理解numpy.where函数

4

我正在阅读 numpy.where(condition[, x, y])文档,但我无法理解其中的小例子:

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
Out: (array([2, 2, 2]), array([0, 1, 2]))

有人能解释一下结果是怎么来的吗?

2个回答

6
第一个数组(array([2, 2, 2]))是行的索引,第二个数组(array([0, 1, 2]))是那些大于5的值所在的列。
您可以使用zip来获取这些值的确切索引:
>>> zip(*np.where( x > 5 ))
[(2, 0), (2, 1), (2, 2)]

或者使用 np.dstack
>>> np.dstack(np.where( x > 5 ))
array([[[2, 0],
        [2, 1],
        [2, 2]]])

  • 表示什么?
- pseudomonas
1
@Raghuram 这是一个原地解包操作符,它可以解包可迭代对象。在这种情况下,它将从 np.where 中传递的项和索引元组传递给 zip 函数。 - Mazdak
非常感谢。今天学到了新东西! :) - pseudomonas

2

它会将坐标打印到您的条件中

import numpy as np

x = np.arange(9.).reshape(3, 3)
print x
print np.where( x > 5 )

print x命令的输出结果如下:

[[ 0.  1.  2.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]]

np.where(x > 5)会打印所有大于5的元素的索引位置。

(array([2, 2, 2]), array([0, 1, 2]))

其中 2,0 等于 6,2,1 等于 7,2,2 等于 8。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接