用子矩阵替换numpy矩阵元素

3

假设我有一个方形索引矩阵,例如:

idxs = np.array([[1, 1],
                 [0, 1]])

以及一个方阵数组,大小与彼此相同(不一定与idxs的大小相同):

mats = array([[[ 0. ,  0. ],
               [ 0. ,  0.5]],

              [[ 1. ,  0.3],
               [ 1. ,  1. ]]])

我想用mats中相应的矩阵替换idxs中的每个索引,以获得:

array([[ 1. ,  0.3,  1. ,  0.3],
       [ 1. ,  1. ,  1. ,  1. ],
       [ 0. ,  0. ,  1. ,  0.3],
       [ 0. ,  0.5,  1. ,  1. ]])

mats[idxs]会给我一个嵌套版本的结果:

array([[[[ 1. ,  0.3],
         [ 1. ,  1. ]],

        [[ 1. ,  0.3],
         [ 1. ,  1. ]]],


       [[[ 0. ,  0. ],
         [ 0. ,  0.5]],

        [[ 1. ,  0.3],
         [ 1. ,  1. ]]]])

于是我尝试使用reshape,但是没有成功!mats[idxs].reshape(4,4)返回:

array([[ 1. ,  0.3,  1. ,  1. ],
       [ 1. ,  0.3,  1. ,  1. ],
       [ 0. ,  0. ,  0. ,  0.5],
       [ 1. ,  0.3,  1. ,  1. ]])

如果有帮助的话,我发现skimage.util.view_as_blocks正好是我需要的相反操作(它可以将我的期望结果转换成嵌套的mats[idxs]格式)。
有没有(希望非常)快速的方法?对于应用程序,我的mats仍然只有几个小矩阵,但我的idxs将是一个达到2^15阶的方阵,在这种情况下,我将替换超过一百万个索引来创建一个2^16阶的新矩阵。
非常感谢您的帮助!

糟糕!是的,应该这样做 - 谢谢! - GeoffChurch
1个回答

3
我们使用这些索引来访问输入数组的第一个轴。为了得到2D输出,我们只需要重新排列轴并在之后进行重塑。因此,一种方法是使用np.transpose/np.swapaxesnp.reshape,如下所示 -
mats[idxs].swapaxes(1,2).reshape(-1,mats.shape[-1]*idxs.shape[-1])

示例运行 -

In [83]: mats
Out[83]: 
array([[[1, 1],
        [7, 1]],

       [[6, 6],
        [5, 8]],

       [[7, 1],
        [6, 0]],

       [[2, 7],
        [0, 4]]])

In [84]: idxs
Out[84]: 
array([[2, 3],
       [0, 3],
       [1, 2]])

In [85]: mats[idxs].swapaxes(1,2).reshape(-1,mats.shape[-1]*idxs.shape[-1])
Out[85]: 
array([[7, 1, 2, 7],
       [6, 0, 0, 4],
       [1, 1, 2, 7],
       [7, 1, 0, 4],
       [6, 6, 7, 1],
       [5, 8, 6, 0]])

使用np.take来提高性能,针对重复索引

对于重复索引,为了提高性能,我们最好使用np.take通过沿着axis=0进行索引。让我们列出这两个方法,并使用具有许多重复索引的idxs计时。

函数定义 -

def simply_indexing_based(mats, idxs):
    ncols = mats.shape[-1]*idxs.shape[-1]
    return mats[idxs].swapaxes(1,2).reshape(-1,ncols)

def take_based(mats, idxs):np.take(mats,idxs,axis=0)
    ncols = mats.shape[-1]*idxs.shape[-1]
    return np.take(mats,idxs,axis=0).swapaxes(1,2).reshape(-1,ncols)

运行时测试 -

In [156]: mats = np.random.randint(0,9,(10,2,2))

In [157]: idxs = np.random.randint(0,10,(1000,1000))
                 # This ensures many repeated indices

In [158]: out1 = simply_indexing_based(mats, idxs)

In [159]: out2 = take_based(mats, idxs)

In [160]: np.allclose(out1, out2)
Out[160]: True

In [161]: %timeit simply_indexing_based(mats, idxs)
10 loops, best of 3: 41.2 ms per loop

In [162]: %timeit take_based(mats, idxs)
10 loops, best of 3: 27.3 ms per loop

因此,我们看到总体上有了1.5倍以上的改进。
为了感受使用np.take的改进,让我们只计时索引部分 -
In [168]: %timeit mats[idxs]
10 loops, best of 3: 22.8 ms per loop

In [169]: %timeit np.take(mats,idxs,axis=0)
100 loops, best of 3: 8.88 ms per loop

对于这些数据大小,它是2.5倍以上。不错!

非常感谢!这很完美地发挥了作用 - 我将牢记swapaxes以备将来参考。 - GeoffChurch
@JubJub 加入了一个使用 np.take 的改进。快去看看吧! - Divakar
哦,我的天啊,非常感谢!这对我帮助很大 :D - GeoffChurch

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接