沿动态指定的轴切片numpy数组

59

我希望能够动态地沿着特定的轴切片一个Numpy数组。给出以下代码:

axis = 2
start = 5
end = 10

我希望达到与此相同的结果:

# m is some matrix
m[:,:,5:10]

使用类似这样的内容:

slc = tuple(:,) * len(m.shape)
slc[axis] = slice(start,end)
m[slc]

但是:值不能放在元组中,因此我无法想出如何构建片段。


m 长什么样? - MrAlias
3
没问题。问题在于如何动态构建一个切片。 - Sean Mackesey
链接问题:https://dev59.com/3ZXfa4cB1Zd3GeqPnP4v - hintze
7个回答

66

因为没有明确说明(而我也在寻找),所以需要说明一下:

等同于:

a = my_array[:, :, :, 8]
b = my_array[:, :, :, 2:7]

是:

a = my_array.take(indices=8, axis=3)
b = my_array.take(indices=range(2, 7), axis=3)

3
这应该是答案。 - Sibbs Gambling
8
使用np.take会创建一个新数组,并从原始数组中复制数据。这可能不是您想要的(对于大型数组,额外的内存使用可能很大)。 - Leland Hepworth
这个答案支持负索引吗,例如 slice(1, -1) - user357269
是的,如果您无论如何都要复制数据,这是一个优雅的解决方案。是的,它支持负索引。然而,当您只需要一个视图时,内存和速度成本可能会很高。例如,从一个1000x1000的RGB图像中取出一个通道(在我的机器上)需要6.4毫秒,而使用切片只需要198纳秒。这意味着在这种特定情况下,np.take的速度比切片慢了3200万倍(使用%timeit测量)。 - undefined

42

我认为一种方法是使用slice(None)

>>> m = np.arange(2*3*5).reshape((2,3,5))
>>> axis, start, end = 2, 1, 3
>>> target = m[:, :, 1:3]
>>> target
array([[[ 1,  2],
        [ 6,  7],
        [11, 12]],

       [[16, 17],
        [21, 22],
        [26, 27]]])
>>> slc = [slice(None)] * len(m.shape)
>>> slc[axis] = slice(start, end)
>>> np.allclose(m[slc], target)
True

我有一种模糊的感觉,好像以前用过一个函数来做这件事情,但是现在找不到它了...


谢谢,问题解决了。slice(None) 显然等同于 : - Sean Mackesey
3
虽然在Numpy中,使用列表进行索引(例如m[slc])现已被弃用并会抛出一个FutureWarning,但是有一种很好的解决方案。建议采用FutureWarning中提到的修复方法,将列表转换为元组,即m[tuple(slc)] - Erlend Magnus Viggen
很好。当使用netCDF4从netCDF数据集中提取时,此答案也适用,其中numpy.take不可用。 - Klimaat
5
使用 m.ndim 替代 len(m.shape) - nth

18

虽然我来晚了,但是我有一个备选的切片函数,比其他答案中的表现略好:

def array_slice(a, axis, start, end, step=1):
    return a[(slice(None),) * (axis % a.ndim) + (slice(start, end, step),)]

这里有一段测试每个答案的代码。每个版本都标有发布答案的用户的名称:
import numpy as np
from timeit import timeit

def answer_dms(a, axis, start, end, step=1):
    slc = [slice(None)] * len(a.shape)
    slc[axis] = slice(start, end, step)
    return a[slc]

def answer_smiglo(a, axis, start, end, step=1):
    return a.take(indices=range(start, end, step), axis=axis)

def answer_eelkespaak(a, axis, start, end, step=1):
    sl = [slice(None)] * m.ndim
    sl[axis] = slice(start, end, step)
    return a[tuple(sl)]

def answer_clemisch(a, axis, start, end, step=1):
    a = np.moveaxis(a, axis, 0)
    a = a[start:end:step]
    return np.moveaxis(a, 0, axis)

def answer_leland(a, axis, start, end, step=1):
    return a[(slice(None),) * (axis % a.ndim) + (slice(start, end, step),)]

if __name__ == '__main__':
    m = np.arange(2*3*5).reshape((2,3,5))
    axis, start, end = 2, 1, 3
    target = m[:, :, 1:3]
    for answer in (answer_dms, answer_smiglo, answer_eelkespaak,
                   answer_clemisch, answer_leland):
        print(answer.__name__)
        m_copy = m.copy()
        m_slice = answer(m_copy, axis, start, end)
        c = np.allclose(target, m_slice)
        print('correct: %s' %c)
        t = timeit('answer(m, axis, start, end)',
                   setup='from __main__ import answer, m, axis, start, end')
        print('time:    %s' %t)
        try:
            m_slice[0,0,0] = 42
        except:
            print('method:  view_only')
        finally:
            if np.allclose(m, m_copy):
                print('method:  copy')
            else:
                print('method:  in_place')
        print('')

以下是结果:
answer_dms

Warning (from warnings module):
  File "C:\Users\leland.hepworth\test_dynamic_slicing.py", line 7
    return a[slc]
FutureWarning: Using a non-tuple sequence for multidimensional indexing is 
deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be 
interpreted as an array index, `arr[np.array(seq)]`, which will result either in an 
error or a different result.
correct: True
time:    2.2048302
method:  in_place

answer_smiglo
correct: True
time:    5.9013344
method:  copy

answer_eelkespaak
correct: True
time:    1.1219435999999998
method:  in_place

answer_clemisch
correct: True
time:    13.707583699999999
method:  in_place

answer_leland
correct: True
time:    0.9781496999999995
method:  in_place
  • DSM的回答中在评论中提出了一些改进建议。
  • EelkeSpaak的回答采用了这些改进方法,避免了警告并且速度更快。
  • Śmigło的回答涉及到np.take,结果较差,虽然它不是只读视图,但它确实创建了一个副本。
  • clemisch的回答涉及到np.moveaxis,完成时间最长,但令人惊讶的是,它引用了先前数组的内存位置。
  • 我的答案消除了中间切片列表的需要。当切片轴朝向开头时,它还使用较短的切片索引。这样可以得到最快的结果,并且随着轴越接近0,还有额外的改进。

我还为每个版本添加了一个step参数,以防您需要它。


1
感谢您提供这个有趣的比较!我现在才看到。我很惊讶moveaxis如此之慢,因为我认为它与slice对象列表做的事情是一样的。这绝对是一个好的知识点! - clemisch
你的函数不支持 axis=-1,所以选择 answer_eelkespaak 是最佳选项。 - dgrigonis

17

虽然来晚了一些,但在Numpy中实现这个功能的默认方法是numpy.take。 但是,总是会复制数据(因为它支持花式索引,它总是假设这是可能的)。 为了避免这种情况(在许多情况下,您将需要数据的视图,而不是副本),可以回退到其他答案中已经提到的slice(None)选项,可能还可以将其包装在一个好的函数中:

def simple_slice(arr, inds, axis):
    # this does the same as np.take() except only supports simple slicing, not
    # advanced indexing, and thus is much faster
    sl = [slice(None)] * arr.ndim
    sl[axis] = inds
    return arr[tuple(sl)]

2
如果您能明确说明inds参数所期望的数据类型,那将会很有帮助。 - Spencer Mathews

8

有一种优雅的方法可以访问数组 x 的任意轴 n:使用 numpy.moveaxis¹ 将感兴趣的轴移动到最前面。

x_move = np.moveaxis(x, n, 0)  # move n-th axis to front
x_move[start:end]              # access n-th axis

问题是你很可能需要对使用x_move [start:end]的其他数组应用moveaxis以保持轴顺序一致。 数组x_move 仅是视图,因此对其前轴所做的每个更改都对应于在第n轴中更改x(即可读/写x_move)。

1)您还可以使用swapaxes而不必担心n 0 的顺序,与moveaxis(x,n,0)相反。 我更喜欢moveaxis而不是swapaxes,因为它仅更改涉及n的顺序。


2

这实在是太晚了!但我得到了 Leland 的答案,并将其扩展为适用于多个轴和切片参数。以下是该函数的详细版本。

from numpy import *

def slicer(a, axis=None, slices=None):
    if not hasattr(axis, '__iter__'):
        axis = [axis]
    if not hasattr(slices, '__iter__') or len(slices) != len(axis):
        slices = [slices]
    slices = [ sl if isinstance(sl,slice) else slice(*sl) for sl in slices ]
    mask = []
    fixed_axis = array(axis) % a.ndim
    case = dict(zip(fixed_axis, slices))
    for dim, size in enumerate(a.shape):
        mask.append( case[dim] if dim in fixed_axis else slice(None) )
    return a[tuple(mask)]

它适用于可变数量的轴,并以切片元组作为输入

>>> a = array( range(10**4) ).reshape(10,10,10,10)
>>> slicer( a, -2, (1,3) ).shape
(10, 10, 2, 10)
>>> slicer( a, axis=(-1,-2,0), slices=((3,), s_[:5], slice(3,None)) ).shape
(7, 10, 5, 3)

一个稍微紧凑一些的版本
def slicer2(a, axis=None, slices=None):
    ensure_iter = lambda l: l if hasattr(l, '__iter__') else [l]
    axis = array(ensure_iter(axis)) % a.ndim
    if len(ensure_iter(slices)) != len(axis):
        slices = [slices]
    slice_selector = dict(zip(axis, [ sl if isinstance(sl,slice) else slice(*sl) for sl in ensure_iter(slices) ]))
    element = lambda dim_: slice_selector[dim_] if dim_ in slice_selector.keys() else slice(None)
    return a[( element(dim) for dim in range(a.ndim) )]

1

我没有看到任何提到Ellipsis对象的评论,所以想要提供另一种略微不同的解决方案,适用于正向和负向轴。这个解决方案有一个函数,可以生成沿着所需维度的切片。

我包含这个主要原因是确保学习动态切片的人也学会使用省略号。

def _slice_along_axis(slice_inds,axis=-1):
    '''
    Returns a slice such that the 1d slice provided by slice_inds, slices along the dimension provided.
    '''
    from_end=False
    if axis<0: # choosing axis at the end
        from_end = True
        axis = -1-axis
    explicit_inds_slice = axis*(slice(None),) 
    if from_end:
        return (Ellipsis,slice_inds) + explicit_inds_slice
    else:
        return  explicit_inds_slice + (slice_inds,)

为了使用它,可以像平常一样调用切片。 例子:
a = my_array[:, :, :, 8]
b = my_array[:, :, :, 2:7]
c = my_array[...,3] # (equivalent to my_array.take(axis=-1,indices=3)

等同于

a = my_array[_slice_along_axis(8,axis=3)]
b = my_array[_slice_along_axis(slice(2,7),axis=3)]
c = my_array[_slice_along_axis(3,axis=-1)]

这种方法的一个优点是,切片可以生成一次,然后在其他数组中使用,即使其他数组中的维数不同(负轴的情况)。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接