PyTorch:逐行点积

8

假设我有两个张量:

a = torch.randn(10, 1000, 1, 4)
b = torch.randn(10, 1000, 6, 4)

第三个索引是向量的索引。

我想要针对向量a,计算每个向量在b中的点积。

具体来说,我的意思是这样的:

dots = torch.Tensor(10, 1000, 6, 1)
for b in range(10):
     for c in range(1000):
           for v in range(6):
            dots[b,c,v] = torch.dot(b[b,c,v], a[b,c,0]) 

我应如何使用torch函数实现这一点?
1个回答

13
a = torch.randn(10, 1000, 1, 4)
b = torch.randn(10, 1000, 6, 4)

c = torch.sum(a * b, dim=-1)

print(c.shape)

torch.Size([10,1000,6])

c = c.unsqueeze(-1)
print(c.shape)

大小为torch.Size([10, 1000, 6, 1])的张量


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接