RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]


I get this error when I try to train my network.

The class we use to store the images of the Caltech 101 dataset was provided by our teacher.

from torchvision.datasets import VisionDataset

from PIL import Image

import os


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class Caltech(VisionDataset):
    def __init__(self, root, split='train', transform=None, target_transform=None):
        super(Caltech, self).__init__(root, transform=transform, target_transform=target_transform)

        self.split = split # This defines the split you are going to use
                           # (split files are called 'train.txt' and 'test.txt')

        '''
        - Here you should implement the logic for reading the splits files and accessing elements
        - If the RAM size allows it, it is faster to store all data in memory
        - PyTorch Dataset classes use indexes to read elements
        - You should provide a way for the __getitem__ method to access the image-label pair
          through the index
        - Labels should start from 0, so for Caltech you will have labels 0...100 (excluding the background class)
        '''
        # Open the split file in read-only mode and read all lines
        # (the with-statement closes the file afterwards)
        with open(self.split, "r") as file:
            lines = file.readlines()

        # Filter out the lines which start with 'BACKGROUND_Google' as asked in the homework
        self.elements = [i for i in lines if not i.startswith('BACKGROUND_Google')]

        # Delete BACKGROUND_Google class from dataset labels
        self.classes = sorted(os.listdir(self.root))
        self.classes.remove("BACKGROUND_Google")


    def __getitem__(self, index):
        ''' 
        __getitem__ should access an element through its index
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        '''

        img = Image.open(os.path.join(self.root, self.elements[index].rstrip()))

        target = self.classes.index(self.elements[index].rstrip().split('/')[0])

        image, label = img, target # Provide a way to access image and label via index
                           # Image should be a PIL Image
                           # label can be int

        # Applies preprocessing when accessing the image
        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self):
        '''
        The __len__ method returns the length of the dataset
        It is mandatory, as this is used by several other components
        '''
        # Provides a way to get the length (number of elements) of the dataset
        length =  len(self.elements)
        return length

This code handles the preprocessing stage:

from torchvision import transforms

# Define transforms for the training phase
train_transform = transforms.Compose([transforms.Resize(256),      # Resizes the short side of the PIL image to 256
                                      transforms.CenterCrop(224),  # Crops a central square patch of the image
                                                                   # 224 because torchvision's AlexNet needs a 224x224 input!
                                                                   # Remember this when applying different transformations, otherwise you get an error
                                      transforms.ToTensor(), # Turns the PIL Image into a torch.Tensor
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalizes the tensor with mean and standard deviation
])
# Define transforms for the evaluation phase
eval_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

Finally, this is how the datasets and data loaders are prepared:
import os
from torch.utils.data import DataLoader

# Clone the GitHub repository with the data
if not os.path.isdir('./Homework2-Caltech101'):
  !git clone https://github.com/MachineLearning2020/Homework2-Caltech101.git

# Commands to execute when there is an error saying no file or directory related to ./Homework2-Caltech101/
# !rm -r ./Homework2-Caltech101/
# !git clone https://github.com/MachineLearning2020/Homework2-Caltech101.git

DATA_DIR = 'Homework2-Caltech101/101_ObjectCategories'
SPLIT_TRAIN = 'Homework2-Caltech101/train.txt'
SPLIT_TEST = 'Homework2-Caltech101/test.txt'


# 1 - Data preparation
myTrainDS = Caltech(DATA_DIR, split = SPLIT_TRAIN, transform=train_transform)
myTestDS = Caltech(DATA_DIR, split = SPLIT_TEST, transform=eval_transform)

print('My Train DS: {}'.format(len(myTrainDS)))
print('My Test DS: {}'.format(len(myTestDS)))

# 2 - DataLoaders
myTrain_dataloader = DataLoader(myTrainDS, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
myTest_dataloader = DataLoader(myTestDS, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
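Since the split files decide which images are loaded, one way to check whether any of them are not RGB is to count the PIL mode of each listed file. The helper below, count_image_modes, is hypothetical (it is not part of the homework code) and assumes the split format the Caltech class above relies on: one relative path per line, with BACKGROUND_Google entries skipped.

```python
import os
from collections import Counter

from PIL import Image


def count_image_modes(root, split_file):
    """Hypothetical helper: count the PIL modes of the images listed in a split file."""
    modes = Counter()
    with open(split_file, "r") as f:
        for line in f:
            rel_path = line.strip()
            # Skip blank lines and the background class, like the Caltech class does
            if not rel_path or rel_path.startswith("BACKGROUND_Google"):
                continue
            with Image.open(os.path.join(root, rel_path)) as img:
                modes[img.mode] += 1  # e.g. 'RGB' for color, 'L' for grayscale
    return modes


# Usage with the paths defined above:
# print(count_image_modes(DATA_DIR, SPLIT_TRAIN))
```

Any count under a key other than 'RGB' points to images that will not produce three channels after ToTensor.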

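OK, so the two .txt files contain the lists of images we want to use in the training and test sets, so we have to take the images from there, and I believe that part is done correctly. The problem is that when I get to the training phase (see the code below), I get the error from the title. I already tried adding the following line to the transforms: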

[...]
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),

right after the CenterCrop, but then I got an error saying "Image has no attribute repeat", and I'm a bit confused.

This is the line of the training code that triggers the error:

# Iterate over the dataset
  for images, labels in myTrain_dataloader:

Here is the full error message, in case it's needed:
RuntimeError                              Traceback (most recent call last)

<ipython-input-197-0e4710a9855d> in <module>()
     47 
     48   # Iterate over the dataset
---> 49   for images, labels in myTrain_dataloader:
     50 
     51     # Bring data over the device of choice


/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    817             else:
    818                 del self._task_info[idx]
--> 819                 return self._process_data(data)
    820 
    821     next = __next__  # Python 2 compatibility

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
    844         self._try_put_index()
    845         if isinstance(data, ExceptionWrapper):
--> 846             data.reraise()
    847         return data
    848 

/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
    383             # (https://bugs.python.org/issue2651), so we work around it.
    384             msg = KeyErrorMessage(msg)
--> 385         raise self.exc_type(msg)

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-180-0b00b175e18c>", line 72, in __getitem__
    image = self.transform(image)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 70, in __call__
    img = t(img)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 175, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 217, in normalize
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
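The last frame of the traceback can be reproduced with plain tensors: a Normalize with three means subtracts a [3, 1, 1] tensor in place from a [1, 224, 224] image, and an in-place operation cannot enlarge its output to the broadcast shape. A minimal sketch, using a random tensor in place of a real grayscale image:

```python
import torch

# A single-channel tensor, as ToTensor produces for a grayscale image
tensor = torch.rand(1, 224, 224)

# Normalize((0.5, 0.5, 0.5), ...) works with a 3-element mean
mean = torch.tensor([0.5, 0.5, 0.5])

message = ""
try:
    # Same in-place operation as torchvision's normalize in the traceback
    tensor.sub_(mean[:, None, None])
except RuntimeError as err:
    message = str(err)
print(message)

# The out-of-place version broadcasts instead, producing three channels;
# the in-place version fails because it cannot grow `tensor` itself.
print((tensor - mean[:, None, None]).shape)  # torch.Size([3, 224, 224])
```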

I'm using AlexNet, and this is the code I was given:

import torch.nn as nn
from torchvision.models import alexnet

NUM_CLASSES = 101  # 101 Caltech categories (BACKGROUND_Google excluded)

net = alexnet() # Loading AlexNet model

# AlexNet has 1000 output neurons, corresponding to the 1000 ImageNet classes
# We need 101 outputs for Caltech-101
net.classifier[6] = nn.Linear(4096, NUM_CLASSES) # nn.Linear in PyTorch is a fully connected layer
                                                 # The convolutional layer is nn.Conv2d

# We just replaced the last layer of AlexNet with a new fully connected layer with 101 outputs
# It is mandatory to study the torchvision.models.alexnet source code
1 Answer

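The first dimension of the image tensor is the channel (color) dimension, so the error means you are passing a grayscale image (1 channel) where an RGB image (3 channels) is expected, since Normalize was given three means and three standard deviations. You defined a pil_loader function that returns an RGB image, but you never use it.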
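So you have two options:
  1. Use grayscale images instead of RGB ones, which is computationally cheaper. Solution: in both the training and the evaluation transforms, change transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) to transforms.Normalize((0.5,), (0.5,)). The trailing commas make each argument a one-element tuple; (0.5) on its own is just the float 0.5.

  2. Make sure your images are RGB. I don't know how your images are stored, but my guess is that at least some images in the dataset you downloaded are grayscale. You can use the pil_loader function you already defined: in your __getitem__ function, change img = Image.open(os.path.join(self.root, self.elements[index].rstrip())) to img = pil_loader(os.path.join(self.root, self.elements[index].rstrip()))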

Let me know how it goes!

Solution two worked perfectly! Thanks a lot! But what is the difference between pil_loader and Image.open? - Giorgio Maritano
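pil_loader converts the image to color (it calls .convert('RGB')), while Image.open keeps the mode stored in the file, so grayscale images are read as grayscale. - Raimundo Manterola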
Option 1 worked perfectly for me, because I had already converted my RGB images to grayscale. - codexaxor
Great answer. Solution 1 fixed it.
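The pil_loader vs. Image.open difference raised in the comments can be checked directly: Image.open returns the image in the mode stored in the file, whereas pil_loader calls .convert('RGB'). A minimal sketch with a synthetic grayscale JPEG written to a temporary directory (an assumption standing in for a Caltech 101 file):

```python
import os
import tempfile

from PIL import Image


def pil_loader(path):
    # Same as the loader in the question: open the file, then convert to RGB
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "gray.jpg")
    Image.new("L", (32, 32), color=128).save(path)  # synthetic grayscale JPEG

    with Image.open(path) as img:
        opened_mode = img.mode           # Image.open keeps the file's own mode
    loaded_mode = pil_loader(path).mode  # pil_loader always returns RGB

print(opened_mode)  # L
print(loaded_mode)  # RGB
```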

Content provided by Stack Overflow.