使用numpy的tensordot进行张量乘法

15
我有一个张量U,由n个维度为(d,k)的矩阵组成,还有一个维度为(k,n)的矩阵V。
我想将它们相乘,使结果返回一个维度为(d,n)的矩阵,其中第j列是U矩阵中第j个矩阵和V的第j列矩阵相乘的结果。

enter image description here

一种可能的方法是:

for j in range(n):
    res[:,j] = U[:,:,j] * V[:,j]

我想知道是否有更快的方法,可以使用numpy库。特别是我在考虑np.tensordot()函数。
这个小片段允许我将单个矩阵乘以标量,但对向量的明显推广并没有返回我所希望的结果。
a = np.array(range(1, 17))
a.shape = (4,4)
b = np.array((1,2,3,4,5,6,7))
r1 = np.tensordot(b,a, axes=0)

任何建议?

你用什么软件来绘制图像? - hlin117
2
@hlin117 - 我使用了Keynote。 - Matteo
1个回答

12

您可以通过几种方式来实现这一点。首先想到的是 np.einsum

# some fake data
gen = np.random.RandomState(0)
ni, nj, nk = 10, 20, 100
U = gen.randn(ni, nj, nk)
V = gen.randn(nj, nk)

res1 = np.zeros((ni, nk))
for k in range(nk):
    res1[:,k] = U[:,:,k].dot(V[:,k])

res2 = np.einsum('ijk,jk->ik', U, V)

print(np.allclose(res1, res2))
# True

np.einsum使用爱因斯坦符号来表示张量收缩。在上面的表达式'ijk,jk -> ik'中,ijk是下标,对应于UV的不同维度。每个逗号分隔的组对应于传递给np.einsum的操作数之一(在本例中,U具有维度ijk,而V具有维度jk)。'->ik'部分指定输出数组的维度。任何具有未出现在输出字符串中的下标的维度都将被求和。

np.einsum非常有用,可以执行复杂的张量收缩,但需要一段时间才能完全理解其工作原理。您应该查看文档中的示例(如上所述)。


一些其他选项:
  1. Element-wise multiplication with broadcasting, followed by summation:

    res3 = (U * V[None, ...]).sum(1)
    
  2. inner1d with a load of transposing:

    from numpy.core.umath_tests import inner1d
    
    res4 = inner1d(U.transpose(0, 2, 1), V.T)
    
一些基准测试:
In [1]: ni, nj, nk = 100, 200, 1000

In [2]: %%timeit U = gen.randn(ni, nj, nk); V = gen.randn(nj, nk)
   ....: np.einsum('ijk,jk->ik', U, V)
   ....: 
10 loops, best of 3: 23.4 ms per loop

In [3]: %%timeit U = gen.randn(ni, nj, nk); V = gen.randn(nj, nk)
(U * V[None, ...]).sum(1)
   ....: 
10 loops, best of 3: 59.7 ms per loop

In [4]: %%timeit U = gen.randn(ni, nj, nk); V = gen.randn(nj, nk)
inner1d(U.transpose(0, 2, 1), V.T)
   ....: 
10 loops, best of 3: 45.9 ms per loop

谢谢你的回答!你能否解释一下这个函数是如何工作的呢?例如,如果 U 不是 (ni,nj,nk) 而是 (nk,ni,nj),那么函数调用会发生什么变化? - Matteo
非常好的答案!非常感谢! - Matteo
看到楼主的问题,我注意到他想要将 (i,j,k) 与 (k,i) 相乘,在被接受的答案中是 ijk, jk -> ik,但应该是 ijk, ik -> ij。 - monolith
1
@wedran OP希望在U的第二轴和V的第一轴上进行缩减,而在我的示例中是下标为j。按照他最初的符号表示法,应该是dkn,kn->dn(等同于ijk,jk->ik)。 - ali_m
当前einsum链接:https://numpy.org/doc/stable/reference/generated/numpy.einsum.html - sotmot

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接