Numpy切片与边界检查

9
numpy是否提供了在对数组进行切片时进行边界检查的方法?例如,如果我执行以下操作:
arr = np.ones([2,2])
sliced_arr = arr[0:5,:]

这个切片是可以的,它会返回整个 arr 数组,即使我请求不存在的索引。有没有其他在 numpy 中切片的方法,如果我尝试切片超出数组边界,就会抛出错误?


令人惊讶的是,numpy本身不执行边界检查,但实现起来并不太难。类似np.any(np.array([i,j]) > arr.shape)这样的代码可能就可以解决问题。 - Thomas Kühn
@ThomasKühn 这需要你“充实”切片,这并不是绝对简单的事情,需要一些时间和内存... - jdehesa
这可能与这个有关吗? - Thomas Kühn
1
Python 列表中允许越界切片优先级较高。实际上,这似乎并不成问题。如果我们想要剩余部分,我们使用 arr[0:, :]。如果需要一个错误,则可以将其包装在 if arr.shape[0]<5 测试中。 - hpaulj
2个回答

4
这篇文章比预期的要长一些,但是你可以编写自己的包装器来检查获取操作,确保切片不超过限制(NumPy已经检查了不是切片的索引参数)。我认为我在这里涵盖了所有情况(省略号、np.newaxis、负步长...),尽管可能还有一些边缘情况失败。
import numpy as np

# Wrapping function
def bounds_checked_slice(arr):
    return SliceBoundsChecker(arr)

# Wrapper that checks that indexing slices are within bounds of the array
class SliceBoundsChecker:

    def __init__(self, arr):
        self._arr = np.asarray(arr)

    def __getitem__(self, args):
        # Slice bounds checking
        self._check_slice_bounds(args)
        return self._arr.__getitem__(args)

    def __setitem__(self, args, value):
        # Slice bounds checking
        self._check_slice_bounds(args)
        return self._arr.__setitem__(args, value)

    # Check slices in the arguments are within bounds
    def _check_slice_bounds(self, args):
        if not isinstance(args, tuple):
            args = (args,)
        # Iterate through indexing arguments
        arr_dim = 0
        i_arg = 0
        for i_arg, arg in enumerate(args):
            if isinstance(arg, slice):
                self._check_slice(arg, arr_dim)
                arr_dim += 1
            elif arg is Ellipsis:
                break
            elif arg is np.newaxis:
                pass
            else:
                arr_dim += 1
        # Go backwards from end after ellipsis if necessary
        arr_dim = -1
        for arg in args[:i_arg:-1]:
            if isinstance(arg, slice):
                self._check_slice(arg, arr_dim)
                arr_dim -= 1
            elif arg is Ellipsis:
                raise IndexError("an index can only have a single ellipsis ('...')")
            elif arg is np.newaxis:
                pass
            else:
                arr_dim -= 1

    # Check a single slice
    def _check_slice(self, slice, axis):
        size = self._arr.shape[axis]
        start = slice.start
        stop = slice.stop
        step = slice.step if slice.step is not None else 1
        if step == 0:
            raise ValueError("slice step cannot be zero")
        bad_slice = False
        if start is not None:
            start = start if start >= 0 else start + size
            bad_slice |= start < 0 or start >= size
        else:
            start = 0 if step > 0 else size - 1
        if stop is not None:
            stop = stop if stop >= 0 else stop + size
            bad_slice |= (stop < 0 or stop > size) if step > 0 else (stop < 0 or stop >= size)
        else:
            stop = size if step > 0 else -1
        if bad_slice:
            raise IndexError("slice {}:{}:{} is out of bounds for axis {} with size {}".format(
                slice.start if slice.start is not None else '',
                slice.stop if slice.stop is not None else '',
                slice.step if slice.step is not None else '',
                axis % self._arr.ndim, size))

一个小例子:
import numpy as np

a = np.arange(24).reshape(4, 6)
print(bounds_checked_slice(a)[:2, 1:5])
# [[ 1  2  3  4]
#  [ 7  8  9 10]]
bounds_checked_slice(a)[:2, 4:10]
# IndexError: slice 4:10: is out of bounds for axis 1 with size 6

如果你愿意,甚至可以将其作为 ndarray的子类,这样你就可以默认获得这种行为,而不必每次都包装数组。
另外,请注意可能会有一些变化,你可能认为“越界”的定义有所不同。上面的代码认为即使超出一个索引也是越界,这意味着你不能使用arr[len(arr):]获取空切片。如果你想要略微不同的行为,原则上你可以编辑代码。

3
如果你使用了range而不是通常的切片符号,你可以获得预期的行为。例如,对于有效的切片:
arr[range(2),:]

array([[1., 1.],
       [1., 1.]])

如果我们尝试使用例如以下方式进行切片:

arr[range(5),:]

会抛出以下错误:
``` IndexError: 索引 2 超出了大小为 2 的范围 ```
我猜测这个错误的原因是使用常规切片符号进行切片在`numpy`数组和列表中是一种基本属性,因此当我们尝试使用错误的索引进行切片时,它不会抛出索引超出范围的错误,而是自动调整为最近的有效索引。而对于使用一个不可变对象 `range` 进行切片时,显然没有考虑到这一点。

4
尽管如此,对于大数组而言,这种方法的计算成本较切片更高,并且它会生成一个新的数组而不是视图。 - jdehesa
是的,然而鉴于我所知道的情况,直接切片可能会因为尝试使用错误的索引而不抛出错误,因此需要另一种替代方法。这是我能想到的最简单的方法。 - yatu
是的,我只发现np.take允许选择一个mode来处理越界值,但它不适用于切片,所以它与您发布的内容完全相同。 - jdehesa
此外,当对多个索引进行切片时,这种方法就不再适用了。例如,x = np.arange(25).reshape(5, 5); x[2:4, 2:4]x[range(2, 4), range(2, 4)] 是不同的。 - Prasad Raghavendra

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接