如何在NumPy中收集特定索引的元素?

6
我想收集指定轴上指定索引的元素,就像以下一样。
x = [[1,2,3], [4,5,6]]
index = [[2,1], [0, 1]]
x[:, index] = [[3, 2], [4, 5]]

这实质上是pytorch中的gather操作,但是你知道,在numpy中无法以这种方式实现。我想知道在numpy中是否有类似的“gather”操作?


这回答您的问题吗? [如何在numpy中进行scatter和gather操作?] (https://dev59.com/llYO5IYBdhLWcg3wRvkh) - iacob
4个回答

9

numpy.take_along_axis 是我需要的内容,可以根据索引取出元素。它可以像 PyTorch 中的 gather 方法一样使用。

以下是手册中的示例:

>>> a = np.array([[10, 30, 20], [60, 40, 50]])
>>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> ai
array([[1],
       [0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[30],
       [60]])


6

我之前写了一段代码,用于在Numpy中复制PyTorch的gather功能。在这种情况下,self代表你的x

def gather(self, dim, index):
    """
    Gathers values along an axis specified by ``dim``.

    For a 3-D tensor the output is specified by:
        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    Parameters
    ----------
    dim:
        The axis along which to index
    index:
        A tensor of indices of elements to gather

    Returns
    -------
    Output Tensor
    """
    idx_xsection_shape = index.shape[:dim] + \
        index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and self should be the same size")
    if index.dtype != np.dtype('int_'):
        raise TypeError("The values of index must be integers")
    data_swaped = np.swapaxes(self, 0, dim)
    index_swaped = np.swapaxes(index, 0, dim)
    gathered = np.choose(index_swaped, data_swaped)
    return np.swapaxes(gathered, 0, dim)

以下是测试用例:

# Test 1
    t = np.array([[65, 17], [14, 25], [76, 22]])
    idx = np.array([[0], [1], [0]])
    dim = 1
    result = gather(t, dim=dim, index=idx)
    expected = np.array([[65], [25], [76]])
    print(np.array_equal(result, expected))

# Test 2
    t = np.array([[47, 74, 44], [56, 9, 37]])
    idx = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0]])
    dim = 0
    result = gather(t, dim=dim, index=idx)
    expected = np.array([[47, 74, 37], [56, 9, 44.], [47, 9, 44]])
    print(np.array_equal(result, expected))

3
使用numpy.take()函数,它具有大部分PyTorch的gather函数功能。

2
>>> x = np.array([[1,2,3], [4,5,6]])
>>> index = np.array([[2,1], [0, 1]])
>>> x_axis_index=np.tile(np.arange(len(x)), (index.shape[1],1)).transpose() 
>>> print x_axis_index
[[0 0]
 [1 1]]
>>> print x[x_axis_index,index]
[[3 2]
 [4 5]]

注意:也可以使用 np.arange(len(x)),不确定是否更可取。 - Andy Hayden
注意:range(x.shape [0])和range(len(x))会给出一个列表,而np.arange(len(x))和np.arange(x.shape [0])会给出一个数组。数组和列表具有相同的元素。 - Sam17
我想我的问题/陈述更多关于性能方面,在一个非常大的数组中,我的猜测是使用np.range进行索引会更快(len与shape肯定无关)。 - Andy Hayden
{btsdaf} - ZEWEI CHU

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接