Numpy - 多维数组中的一维索引

3

我有一个形状为 (6, 2, 4) 的 numpy 数组:

x = np.array([[[0, 3, 2, 0],
               [1, 3, 1, 1]],

              [[3, 2, 3, 3],
               [0, 3, 2, 0]],

              [[1, 0, 3, 1],
               [3, 2, 3, 3]],

              [[0, 3, 2, 0],
               [1, 3, 2, 2]],

              [[3, 0, 3, 1],
               [1, 0, 1, 1]],

              [[1, 3, 1, 1],
               [3, 1, 3, 3]]])

我有一个名为choices的数组,内容如下:

choices = np.array([[1, 1, 1, 1],
                    [0, 1, 1, 0],
                    [1, 1, 1, 1],
                    [1, 0, 0, 0],
                    [1, 0, 1, 1],
                    [0, 0, 0, 1]])

如何最有效地使用 choices 数组只索引大小为 2 的中间维度,并获得一个形状为 (6, 4) 的新 numpy 数组?

结果应该是这样的:

[[1 3 1 1]
 [3 3 2 3]
 [3 2 3 3]
 [1 3 2 0]
 [1 0 1 1]
 [1 3 1 3]]

我试过使用 x[:, choices, :],但这并没有返回我想要的结果。我还尝试了 x.take(choices, axis=1),但没有成功。
2个回答

7
使用np.take_along_axis沿第二个轴进行索引 -
In [16]: np.take_along_axis(x,choices[:,None],axis=1)[:,0]
Out[16]: 
array([[1, 3, 1, 1],
       [3, 3, 2, 3],
       [3, 2, 3, 3],
       [1, 3, 2, 0],
       [1, 0, 1, 1],
       [1, 3, 1, 3]])

或者通过显式的integer-array索引 -

In [22]: m,n = choices.shape

In [23]: x[np.arange(m)[:,None],choices,np.arange(n)]
Out[23]: 
array([[1, 3, 1, 1],
       [3, 3, 2, 3],
       [3, 2, 3, 3],
       [1, 3, 2, 0],
       [1, 0, 1, 1],
       [1, 3, 1, 3]])

谢谢!理解choices[:, None]等于choices[:, None, :]需要一些时间,这将导致一个形状为[6, 1, 4]的数组,可以与x广播! - wstcegg

-1

最近我遇到了这个问题,发现@divakar的答案很有用,但仍然想要一个通用的函数(与维数无关等),在这里:

def take_indices_along_axis(array, choices, choice_axis):
    """
    array is N dim
    choices are integer of N-1 dim
       with valuesbetween 0 and array.shape[choice_axis] - 1
    choice_axis is the axis along which you want to take indices
    """
    nb_dims = len(array.shape)
    list_indices = []
    for this_axis, this_axis_size in enumerate(array.shape):
        if this_axis == choice_axis:
            # means this is the axis along which we want to choose
            list_indices.append(choices)
            continue
        # else, we want arange(this_axis), but reshaped to match the purpose
        this_indices = np.arange(this_axis_size)
        reshape_target = [1 for _ in range(nb_dims)]
        reshape_target[this_axis] = this_axis_size # replace the corresponding axis with the right range
        del reshape_target[choice_axis] # remove the choice_axis
        list_indices.append(
            this_indices.reshape(tuple(reshape_target))
        )
    tuple_indices = tuple(list_indices)
    return array[tuple_indices]

# test it !
array = np.random.random(size=(10, 10, 10, 10))
choices = np.random.randint(10, size=(10, 10, 10))
assert take_indices_along_axis(array, choices, choice_axis=0)[5, 5, 5] == array[choices[5, 5, 5], 5, 5, 5]
assert take_indices_along_axis(array, choices, choice_axis=2)[5, 5, 5] == array[5, 5, choices[5, 5, 5], 5]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接