未知维度的Numpy索引数组?

3

我需要比较一堆不同维度的numpy数组,例如:

a = np.array([1,2,3])
b = np.array([1,2,3],[4,5,6])
assert(a == b[0])

如果我不知道a和b的形状,那么我该如何做呢?除此之外,这与IT技术有关。
len(shape(a)) == len(shape(b)) - 1 

我也不知道应该从b的哪个维度跳过,我想使用np.index_exp,但似乎并没有帮助我...

def compare_arrays(a,b,skip_row):
    u = np.index_exp[ ... ]
    assert(a[:] == b[u])

编辑 或者换句话说,如果我知道数组的形状和我想要省略的维度,我希望构建切片。如果我知道维数和位置,在哪里放置“:”和在哪里放置“0”,我该如何动态创建np.index_exp。


np.take是你正在寻找的吗? - Eelco Hoogendoorn
谢谢,np.take似乎可以在给定的轴上工作。我可能可以使用它,但是如何使用“:”来定义范围呢? - kakk11
“跳过行”或“跳过线”是什么意思?还是说你的意思是跳过一个维度? - hpaulj
@hpaulj 谢谢,我是指维度,已更正。 - kakk11
1个回答

3

我刚才在查看代码的时候,看到了apply_along_axisapply_over_axis函数的代码,研究它们是如何构建索引对象的。

我们来创建一个四维数组:

In [355]: b=np.ones((2,3,4,3),int)

使用列表 * 复制的方法,创建一个slices列表

In [356]: ind=[slice(None)]*b.ndim

In [357]: b[ind].shape    # same as b[:,:,:,:]
Out[357]: (2, 3, 4, 3)

In [358]: ind[2]=2     # replace one slice with index

In [359]: b[ind].shape   # a slice, indexing on the third dim
Out[359]: (2, 3, 3)

或者使用你的示例。
In [361]: b = np.array([1,2,3],[4,5,6])   # missing []
...
TypeError: data type not understood

In [362]: b = np.array([[1,2,3],[4,5,6]])

In [366]: ind=[slice(None)]*b.ndim    
In [367]: ind[0]=0
In [368]: a==b[ind]
Out[368]: array([ True,  True,  True], dtype=bool)

这种索引基本上与np.take相同,但是这个想法可以扩展到其他情况。
我不太明白你关于使用:的问题。请注意,在构建索引列表时,我使用slice(None)。解释器将所有索引:转换为slice对象:[start:stop:step] => slice(start, stop, step)
通常您不需要使用a[:]==b[0]a==b[0]就足够了。对于列表,alist[:]会创建一个副本,而对于数组,它什么也不做(除非在RHS上使用a[:]=...)。

非常感谢。看起来我不知道“slice()”结构,有了它我确实可以构造任何类型和任何维度的索引,太完美了! - kakk11

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接