PyTorch的transforms.RandomRotation()在Google Colab上不起作用。

4

通常我在自己的计算机上从事字母和数字识别方面的工作,并且希望将我的项目迁移到Colab,但不幸的是出现了错误(您可以看到下面的错误信息)。 经过一些调试后,我找到了导致错误的代码行。

transforms.RandomRotation(degrees=(90, -90))

我写了一个简单的抽象代码来展示这个错误。在我的电脑环境下,这段代码可以正常工作,但是在Colab上无法工作。问题可能是由于我电脑上的PyTorch版本为1.3.1,而Colab使用的版本为1.4.0。

import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt   
    transformOpt = transforms.Compose([
            transforms.RandomRotation(degrees=(90, -90)),
            transforms.ToTensor()
        ])

    train_set = datasets.MNIST(
        root='', train=True, transform=transformOpt, download=True)
    test_set = datasets.MNIST(
        root='', train=False, transform=transformOpt, download=True)


    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=100,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=100,
        shuffle=False)

    images, labels = next(iter(train_loader))
    plt.imshow(images[0].view(28, 28), cmap="gray")
    plt.show()

当我在Google Colab上运行上述示例代码时,遇到了完整的错误信息。
TypeError                                 Traceback (most recent call last)

<ipython-input-1-8409db422154> in <module>()
     24     shuffle=False)
     25 
---> 26 images, labels = next(iter(train_loader))
     27 plt.imshow(images[0].view(28, 28), cmap="gray")
     28 plt.show()

10 frames

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py in __getitem__(self, index)
     95 
     96         if self.transform is not None:
---> 97             img = self.transform(img)
     98 
     99         if self.target_transform is not None:

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
     68     def __call__(self, img):
     69         for t in self.transforms:
---> 70             img = t(img)
     71         return img
     72 

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)    1001         angle = self.get_params(self.degrees)    1002 
-> 1003         return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)    1004     1005     def
__repr__(self):

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in rotate(img, angle, resample, expand, center, fill)
    727         fill = tuple([fill] * 3)
    728 
--> 729     return img.rotate(angle, resample, expand, center, fillcolor=fill)
    730 
    731 

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in rotate(self, angle, resample, expand, center, translate, fillcolor)    2003         w, h = nw, nh    2004 
-> 2005         return self.transform((w, h), AFFINE, matrix, resample, fillcolor=fillcolor)    2006     2007     def save(self,    fp, format=None, **params):

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in transform(self, size, method, data, resample, fill, fillcolor)    2297             raise ValueError("missing method data")    2298 
-> 2299         im = new(self.mode, size, fillcolor)    2300         if method == MESH:    2301             # list of quads

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in new(mode, size, color)    2503         im.palette = ImagePalette.ImagePalette()    2504         color = im.palette.getcolor(color)
-> 2505     return im._new(core.fill(mode, size, color))    2506     2507 

TypeError: function takes exactly 1 argument (3 given)

你能粘贴完整的错误跟踪吗? - kHarshit
@kHarshit 谢谢,我添加了完整的错误跟踪。顺便说一下,您可以将此示例代码粘贴到您自己的colab中并查看错误。我发现错误来自我提到的那行,但我不知道如何解决它。 - mert kaan
1个回答

6
您说得对。torchvision 0.5在RandomRotation()函数的fill参数上有一个错误,可能是由于Pillow版本不兼容所致。这个问题已经被解决了(PR#1760),并将在下一个版本中解决。
暂时地,在RandomRotation变换中添加fill=(0,)来解决这个问题。
transforms.RandomRotation(degrees=(90, -90), fill=(0,))

嗨,kHarshit,你使用的是哪个版本的PIL?我用的是“5.0.0”,但传递 fill=(0,) 并不能解决我的问题...谢谢。 - Luis Candanedo
1
@LuisCandanedo,您现在可以升级到torchvision v0.6.0,或者查看GitHub问题页面。我不确定使用的PIL版本是什么。 - kHarshit
1
我只想说,这个 bug 在 2020 年仍然存在,但至少你在这里列出的解决方法仍然有效。 - rocksNwaves

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接