PyTorch中的unsqueeze函数是做什么用的?

4
lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), 
                         set_2[:, :2].unsqueeze(0))   #(n1, n2, 2)

这段代码片段中针对两个张量使用了不同的操作:unsqueeze(1)unsqueeze(0)。它们之间有什么区别呢?

3
这个问题的回答是否对您有帮助?“unsqueeze”在Pytorch中是什么意思? - akshayk07
PyTorch支持numpy风格的广播语义。这就解释了为什么当两个参数沿不同维度未压缩时,您会得到lower_bounds的观察形状。 - jodag
参数是添加维度的方向。有关更多信息,请参阅文档。 - muman
1个回答

3
“unsqueeze”函数将一个n维张量转换为n+1维张量,通过添加一个深度为零的额外维度。然而,由于新维度应该沿哪个轴线(即它应该被“挤压”的方向)是不确定的,因此需要通过dim参数来指定。
因此,得到的未挤压张量具有相同的信息,但用于访问它们的索引是不同的。
这是对于一个有效的二维矩阵,使用“squeeze”/“unsqueeze”函数进行操作时的可视化表示,其中从一个2d张量转换为3d张量,因此有3种选择新维度的位置:

enter image description here


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接