Numpy数组的N维数组和(N-k)维数组的“小于/大于”比较

3

给定两个数组a=np.array([[1, 3], [3, 4]])b=np.array([2, 2])

目标:通过操作如a>b获取数组np.array([False, True])。即行比较(如果每对元素都满足>运算符,则为True,否则为False),而不是逐个元素比较(即我不想得到np.array([[False, True], [True, True]]))。

对于3-D和(可选的)N维数组也是类似的。 例如:

a1 = np.array([[[1, 2, 1], [2, 3, 2]], [[3, 4, 3], [4, 3, 4]]])
b1 = np.array([1, 1, 1])

如何实现像 a1 > b1 这样的操作返回 np.array([[False, True], [True, True]])

有什么方法吗?


这些都是整数数组吗? - Divakar
1
我无法看到基于元素类型的比较限制。因此,所有可通过numpy进行比较的类型都是允许的。 - kupgov
1个回答

6
解决方案找到了:另外使用numpy.all函数。

我的示例用法:

a=np.array([[1, 3], [3, 4]])
b=np.array([2, 2])
numpy.all(a > b, axis=1)

结果:

array([False,  True], dtype=bool)

并且

a1 = np.array([[[1, 2, 1], [2, 3, 2]], [[3, 4, 3], [4, 3, 4]]])
b1 = np.array([1, 1, 1])
numpy.all(a1 > b1, axis=2)

结果:

array([[False,  True],
       [ True,  True]], dtype=bool)

numpy.all 还允许传递多个轴(作为整数的元组),因此可以用于任何维度。

此外,numpy 还允许使用 numpy 数组的 ndarray.all 方法。然后,示例可以重写为(a>b).all(axis=1)(a1>b1).all(axis=2),分别如下。


3
只要沿着最后一个轴减少,(A>B).all(axis=-1) 就是适用于任何 Ndarray 的通用方式。 - Divakar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接