在pytorch中连接两个形状不同的torch张量

3
我有两个火炬张量。一个形状为[64, 4, 300],另一个形状为[64, 300]。如何将这两个张量连接起来以获得形状为[64, 5, 300]的结果张量?我知道可以使用tensor.cat函数来实现此操作,但是为了使用该函数,我需要重新调整第二个张量的形状以匹配张量的维数。我听说不应该对张量进行重塑,因为这可能会破坏张量中的数据。那么我该如何进行这种连接操作?
我已经尝试过重新调整形状,但以下部分让我更加怀疑这种重塑。
a = torch.rand(64,300)

a1 = a.reshape(64,1,300)

list(a1[0]) == list(a)
Out[32]: False
1个回答

7
你需要使用torch.cat函数沿着第一维度拼接,并且还需要在第一个维度进行unsqueeze操作,具体如下:
import torch

first = torch.randn(64, 4, 300)
second = torch.randn(64, 300)

torch.cat((first, second.unsqueeze(dim=1)), dim=1)
# Shape: [64, 5, 300]

它不会搞乱您的数据,只是添加表面的 1 维度(如果正确执行,reshape也不会搞乱)。


谢谢您的回答。我在这里不应该使用reshape,对吗?另外,unsqueeze是什么意思?为什么可以安全使用? - Russ Brown
2
你也可以使用reshape(64, 1, 300),但这样会显得冗长。unsqueeze64300之间添加了一个1维,使形状可广播。这很安全,因为它只改变了内部数据表示(它被保留在连续的C++数组中),而且还考虑到了步幅等因素。数据只有在无法以这种非侵入性的方式“重塑”时才会被复制(如果您不希望发生复制,应该使用view)。但你必须知道它是如何工作的,这样你就不会意外地混淆尺寸。 - Szymon Maszke

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接