torch.nn.functional.interpolate(): 参数设置

3

我正在使用torch.nn.functional.interpolate()来调整图像大小。

首先,我使用transforms.ToTensor()将图像转换为张量,张量的大小为(3, 252, 252),其中(252, 252)是导入图像的大小。我要做的是使用interpolate()函数创建一个大小为(3, 504, 504)的张量。

我设置参数scale_factor=2,但它返回了一个(3, 252, 504)的张量。然后我将其设置为scale_factor=(1,2,2),收到了维度冲突的错误信息:

size shape must match input shape. Input is 1D, size is 3

那么我应该如何设置参数才能接收到(3, 504, 504)的张量呢?

1个回答

6
如果你使用了 scale_factor,那么你需要提供一批图像而不是单个图像。因此,你需要通过使用 unsqueeze(0) 来添加一个批次,然后将其传递给以下的 interpolate 函数:
import torch
import torch.nn.functional as F

img = torch.randn(3, 252, 252)  # torch.Size([3, 252, 252])
img = img.unsqueeze(0)  # torch.Size([1, 3, 252, 252])

out = F.interpolate(img, scale_factor=(2, 2), mode='nearest')
print(out.size()) # torch.Size([1, 3, 504, 504])

1
非常感谢你的帮助,我已经解决了! - MaiTruongSon

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接