运行时错误:给定 groups=1,大小为 [64, 3, 3, 3] 的权重,期望输入 [4, 5000, 5000, 3] 具有 3 个通道,但实际得到了 5000 个通道。

4

我的 U-Net 模型输入 5000x5000x3 的图像,但是运行时出现了上述错误。

以下是我的模型代码。

import torch
import torch.nn as nn


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )


class UNeT(nn.Module):
    def __init__(self, n_class):
        super().__init__()
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear',
                                    align_corners=True)
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)
        x = self.dconv_down4(x)
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)
        out = self.conv_last(x)
        return out


我试图执行model(inputs.unsqueeze_(0)),但是我遇到了一个不同的错误。

1个回答

11
在pytorch中,维度的顺序与您期望的不同。您的输入张量具有4x5000x5000x3shape,您将其解释为具有4个大小的批处理,每个像素具有5000x5000像素的图像,每个像素有3个通道。也就是说,您的维度是批处理-高度-宽度-通道
然而,在pytorch中,张量维度的顺序应该是: 批处理-通道-高度-宽度。也就是说,通道维度应该在宽度和高度空间维度之前。
您需要permute(转置)您的输入张量的维度来解决问题。
model(inputs.permute(0, 3, 1, 2))

更多信息请参见nn.Conv2d的文档。


晚来一步,但为什么你不能使用 torch.reshape 而不是 permute?不确定我理解 permute 有什么不同。 - turnip
@turnip,reshape/viewpermute之间有很大的区别。阅读这个答案以了解更多信息。 - Shai

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接