如何在numpy数组(Python)中每行获取前k个最大值?

6

给定下面这种形式的numpy数组:

x = [[4.,3.,2.,1.,8.],[1.2,3.1,0.,9.2,5.5],[0.2,7.0,4.4,0.2,1.3]]

有没有一种方法可以在Python中保留每行的前三个值并将其他值设为零(无需使用显式循环)。 在上面的示例中,结果应该是:

x = [[4.,3.,0.,0.,8.],[0.,3.1,0.,9.2,5.5],[0.0,7.0,4.4,0.0,1.3]]

一个示例的代码

import numpy as np
arr = np.array([1.2,3.1,0.,9.2,5.5,3.2])
indexes=arr.argsort()[-3:][::-1]
a = list(range(6))
A=set(indexes); B=set(a)
zero_ind=(B.difference(A)) 
arr[list(zero_ind)]=0

输出:
array([0. , 0. , 0. , 9.2, 5.5, 3.2])

以下是我针对一个一维numpy数组的样本代码(有很多行)进行的翻译。如果要遍历numpy数组的每一行并重复执行此计算,这将是非常昂贵的。是否有更简单的方法?

1
什么问题?你的代码在哪里? - AMC
以下内容是否有帮助?https://dev59.com/H2cs5IYBdhLWcg3wLQ44 - Kevin Liu
4个回答

4

这是一段完全基于向量的代码,没有使用任何第三方库,包括 numpy。它使用了 numpy 的 argpartition 函数高效地查找第 k 个值。其他用例请参见 此答案

def truncate_top_k(x, k, inplace=False):
    m, n = x.shape
    # get (unsorted) indices of top-k values
    topk_indices = numpy.argpartition(x, -k, axis=1)[:, -k:]
    # get k-th value
    rows, _ = numpy.indices((m, k))
    kth_vals = x[rows, topk_indices].min(axis=1)
    # get boolean mask of values smaller than k-th
    is_smaller_than_kth = x < kth_vals[:, None]
    # replace mask by 0
    if not inplace:
        return numpy.where(is_smaller_than_kth, 0, x)
    x[is_smaller_than_kth] = 0
    return x    

1
使用 np.apply_along_axis 函数,对给定轴上的 1-D 切片应用一个函数。
import numpy as np

def top_k_values(array):
    indexes = array.argsort()[-3:][::-1]
    A = set(indexes)
    B = set(list(range(array.shape[0])))
    array[list(B.difference(A))]=0
    return array

arr = np.array([[4.,3.,2.,1.,8.],[1.2,3.1,0.,9.2,5.5],[0.2,7.0,4.4,0.2,1.3]])
result = np.apply_along_axis(top_k_values, 1, arr)
print(result)

输出

[[4.  3.  0.  0.  8. ]
 [0.  3.1 0.  9.2 5.5]
 [0.  7.  4.4 0.  1.3]]

1
def top_k(arr, k, axis = 0):
    top_k_idx =  = np.take_along_axis(np.argpartition(arr, -k, axis = axis), 
                                      np.arange(-k,-1), 
                                      axis = axis)  # indices of top k values in axis
    out = np.zeros.like(arr)                        # create zero array
    np.put_along_axis(out, top_k_idx,               # put idx values of arr in out
                      np.take_along_axis(arr, top_k_idx, axis = axis), 
                      axis = axis)
    return out

这应该适用于任意的 axisk,但不支持原地更改。如果你想要原地更改,那么就要简单一些:

def top_k(arr, k, axis = 0):
    remove_idx =  = np.take_along_axis(np.argpartition(arr, -k, axis = axis), 
                                           np.arange(arr.shape[axis] - k), 
                                           axis = axis)    # indices to remove
    np.put_along_axis(out, remove_idx, 0, axis = axis)     # put 0 in indices

0
这里有一个替代方案,它使用列表推导式来遍历您的数组并应用 keep_top_3 函数。
import numpy as np
import heapq

def keep_top_3(arr): 
    smallest = heapq.nlargest(3, arr)[-1]  # find the top 3 and use the smallest as cut off
    arr[arr < smallest] = 0 # replace anything lower than the cut off with 0
    return arr 

x = [[4.,3.,2.,1.,8.],[1.2,3.1,0.,9.2,5.5],[0.2,7.0,4.4,0.2,1.3]]
result = [keep_top_3(np.array(arr)) for arr  in x]

希望这可以帮到您:)


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接