如何加速带有掩码的NumPy点积运算?

3
我有两个numpy数组,m1和m2,其中m1的大小为(nx1),m2的大小为(1xn),我想执行乘法m1.dot(m2),得到一个大小为(nxn)的矩阵m。
我想通过仅使用m1和m2中最高的k个元素,并将所有其他元素设为0(所有元素都是正数),来计算近似的m_approx。
我试图加快乘法速度,因为对于我来说,大小n很大(约10k)。我想选择一个小的k值,比如100,从而真正加快乘法。我尝试过使用numpy稀疏矩阵,它确实使点积运算变得更快,但将m1和m2转换为稀疏向量非常慢。我该怎么做呢?我觉得掩码可能是实现这一目标的一种方法,但不确定如何操作。
1个回答

3
这可以通过使用np.argpartition来获取最大的k个元素的索引,然后使用np.ix_选择并设置来自m1m2的所选元素的点积。因此,我们基本上需要分两个阶段来实现这个功能,接下来将讨论这两个阶段。

首先,获取对应于m1m2中最大的k个元素的索引,如下所示 -

m1_idx = np.argpartition(-m1,k,axis=0)[:k].ravel()
m2_idx = np.argpartition(-m2,k)[:,:k].ravel()

最后,设置输出数组。使用np.ix_m1m2索引分别沿行和列广播以选择要设置的输出数组中的元素。接下来,计算m1m2中最高k个元素之间的点积,可以使用m1_idxm2_idx进行索引从m1m2中获取这些元素。
out = np.zeros((n,n))
out[np.ix_(m1_idx,m2_idx)] = np.dot(m1[m1_idx],m2[:,m2_idx])

让我们通过对另一个实现运行它来验证该实现,该实现显式设置较低的 n-k 元素为 0m1m2 中,然后执行点积。以下是执行检查的示例运行 -

1)输入:

In [170]: m1
Out[170]: 
array([[ 0.26980423],
       [ 0.30698416],
       [ 0.60391089],
       [ 0.73246763],
       [ 0.35276247]])

In [171]: m2
Out[171]: array([[ 0.30523552, 0.87411242, 0.01071218, 0.81835438, 0.21693231]])

In [172]: k = 2

2) 运行建议的实现:

In [173]: # Proposed solution code
     ...: m1_idx = np.argpartition(-m1,k,axis=0)[:k].ravel()
     ...: m2_idx = np.argpartition(-m2,k)[:,:k].ravel()
     ...: out = np.zeros((n,n))
     ...: out[np.ix_(m1_idx,m2_idx)] = np.dot(m1[m1_idx],m2[:,m2_idx])
     ...: 

3)使用替代实现来获取输出:

In [174]: # Explicit setting of lower n-k elements to zeros for m1 and m2
     ...: m1[np.argpartition(-m1,k,axis=0)[k:]] = 0
     ...: m2[:,np.argpartition(-m2,k)[:,k:].ravel()] = 0
     ...: 

In [175]: m1  # Verify m1 and m2 have lower n-k elements set to 0s
Out[175]: 
array([[ 0.        ],
       [ 0.        ],
       [ 0.60391089],
       [ 0.73246763],
       [ 0.        ]])

In [176]: m2
Out[176]: array([[ 0.       , 0.87411242, 0.        , 0.81835438, 0.        ]])

In [177]: m1.dot(m2)  # Use m1.dot(m2) to directly get output. This is expensive.
Out[177]: 
array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.52788601,  0.        ,  0.49421312,  0.        ],
       [ 0.        ,  0.64025905,  0.        ,  0.59941809,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])

4) 验证我们提出的实现:

In [178]: out   # Print output from proposed solution obtained earlier
Out[178]: 
array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.52788601,  0.        ,  0.49421312,  0.        ],
       [ 0.        ,  0.64025905,  0.        ,  0.59941809,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])

正是我所需要的...之前不知道有np.ix_! - A.D
@Adi 很高兴能帮忙! :) - Divakar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接