在NumPy数组中,有没有一种优雅的方法来检查是否可以请求索引?

4

我正在寻找一种优雅的方法来检查给定的索引是否在numpy数组中(例如用于网格上的BFS算法)。

以下代码可以实现我的要求:

import numpy as np

def isValid(np_shape: tuple, index: tuple):
    if min(index) < 0:
        return False
    for ind,sh in zip(index,np_shape):
        if ind >= sh:
            return False
    return True

arr = np.zeros((3,5))
print(isValid(arr.shape,(0,0))) # True
print(isValid(arr.shape,(2,4))) # True
print(isValid(arr.shape,(4,4))) # False

但我更喜欢内建的或比编写包含Python for循环的自己的函数更优雅的方法(天哪)。


也许可以这样写:def isValid(np_shape: tuple, index: tuple): return (0, 0) <= index <= np_shape - DarrylG
谢谢,但不起作用。我已经尝试过了:(0,0) <= (4,-1) 由于某种原因返回True.. - schajan
但是,这种形式的 isValid 返回 False(与发布的问题相同)。 - DarrylG
1
形状和索引都是元组,而且很短。不用担心循环。无循环“规则”适用于numpy数组,特别是长数组。 - hpaulj
2个回答

3

你可以尝试以下方法:

def isValid(np_shape: tuple, index: tuple):
    index = np.array(index)
    return (index >= 0).all() and (index < arr.shape).all()

arr = np.zeros((3,5))
print(isValid(arr.shape,(0,0))) # True
print(isValid(arr.shape,(2,4))) # True
print(isValid(arr.shape,(4,4))) # False

2

我对这些答案进行了基准测试,并得出结论,实际上在我的代码中提供的显式for循环表现最佳。

Dmitri的解决方案之所以错误有几个原因(tuple1 < tuple2仅比较第一个值;像np.all(ni < sh for ind,sh in zip(index,np_shape))这样的想法失败,因为all的输入返回一个生成器而不是列表等)。

@mozway的解决方案是正确的,但是所有类型转换使它变慢了很多。此外,它始终需要考虑所有数字进行转换,而显式循环可以更早地停止,我想。

这是我的基准测试(方法0是@mozway的解决方案,方法1是我的解决方案):

enter image description here


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接