如何在使用argmax时检测numpy数组中的平局

3
如果我有像下面这样的数组,如何检测在使用np.argmax()时至少有3个或更多值相同时的平局?
examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

np.argmax(examp, axis=1)

这会输出:

array([0, 0, 3, 1, 1]

以第一行为例,有一个“3路并列”。 3个值都是4np.argmax返回具有最大值的第一个索引。但是,如何检测到正在进行“3路并列”,并让它使用自定义函数来决定打破平局(在至少出现“3路并列”的情况下)?
因此,第一行:发现有4个“3路并列”。自定义函数运行,以便它可以决定打破平局。
第二行:“4路并列”同样的事情发生。
第三行:只有“2路并列”,这少于至少“3路并列”的条件。可以默认使用np.argmax。

你正在为examp中的每一行取最大索引。这里没有"四方平局"。但是如果你执行examp.argmax(0),将会出现这样的平局。 - undefined
@Kevin 抱歉,我似乎误解了np.argmax()的意思。我已经重新写了问题。(而且你已经回答了,抱歉)。 - undefined
我觉得原则还是一样的,所以我的回答可能仍然有用 :) - undefined
@Kevin 我最终改用了 examp[np.r_[:5],indices] 并返回了 array([4, 5, 4, 6, 4]),这些是每行的最大值。我猜你的 examp == 是用来检查原始数组是否与这个一维数组中的任何值匹配的。但是,无论我是使用 axis=0 还是 axis=1,它都是按列进行检查,而不是逐行进行。为什么在这里改变轴没有任何影响呢? - undefined
你可能已经弄明白了,但是当你执行examp == examp[np.r_[:5],indices]时,你是在对examp[np.r_[:5],indices]的每一行进行广播,与examp中的每一行进行比较。虽然改变轴会对结果产生影响,但是我不确定你所说的具体含义。 - undefined
是的,但还是谢谢你! - undefined
2个回答

1
寻找第n大的一种方法是使用np.partition(或np.argpartition)。在这种情况下,您可以执行以下操作:
>>> n = 3  # Size of tie
>>> i = examp.argpartition([-n, -1], axis=-1)

倒数第三列和最后一列的值保证处于正确的排序顺序(因此次后一列也是如此,但仅限于这种情况)。如果这两个值相等,则存在三方平局:

>>> r = np.arange(examp.shape[0])
>>> examp[r, i[:, -n]] == examp[r, i[:, -1]]
array([ True,  True, False, False, False])

您也可以使用np.diff来计算掩码:
>>> np.diff(examp[r[:, None], i[:, [-n, -1]]], axis=1) == 0
array([[ True],
       [ True],
       [False],
       [False],
       [False]])

您可以使用np.take_along_axis代替第一个索引r来获得类似的结果:
>>> np.diff(np.take_along_axis(examp, i[:, -n::n-1], 1), axis=1) == 0
array([[ True],
       [ True],
       [False],
       [False],
       [False]])

在所有这些情况下,argmax的值只是[:, -1],因为这是数组中最大值的索引。
既然您已经在使用numpy,我强烈建议您也将自定义的打结函数向量化。我在这里提供了输出作为掩码,以便您可以尽可能高效地完成这项工作。

0

你说得没错,np.argmax 只会找到第一个最大值。虽然你可以计算有多少个这样的 argmax 存在,并基于那个数来进行逻辑判断。

indices = examp.argmax(0)
counts = (examp == examp[indices, np.r_[:3]]).sum(0)
# the same as
counts = np.count_nonzero(examp == examp[indices, np.r_[:3]], axis=0)

将返回

indices = array([0, 3, 2])
counts = array([4, 1, 2])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接