PyTorch中的transforms是用来做什么的?

24

我刚接触Pytorch,对CNN并不是很熟悉。 我使用他们提供的Pytorch教程成功地建立了分类器,但我不太理解在加载数据时具体在做什么。

他们对训练数据进行了一些数据增强和归一化处理,但当我尝试修改参数时,代码无法运行。

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

我是否在扩展我的训练数据集?我没有看到数据增强。

为什么我修改transforms.RandomResizedCrop(224)的值后,数据无法加载?

我需要对测试数据集进行转换吗?

我对他们所做的数据转换有些困惑。


你没有说明你遇到了什么错误。我怀疑如果你改变RandomResizedCrop生成图像的大小,当卷积和全连接部分之间的特征被压平时,你的模型会崩溃。 - Manuel Lagunas
2个回答

43

transforms.Compose会把所有提供的转换组合在一起。因此,transforms.Compose中的所有转换都会依次应用于输入数据。

训练集转换

  1. transforms.RandomResizedCrop(224):这将从输入图像中随机提取大小为(224, 224)的补丁。它可以从左上角、右下角或中间的任何位置选择。因此,在这部分进行了数据增强。另外,更改此值可能会对模型中的全连接层产生影响,因此不建议更改此值。
  2. transforms.RandomHorizontalFlip():一旦我们有了大小为(224, 224)的图像,我们可以选择翻转它。这也是数据增强的一部分。
  3. transforms.ToTensor():这只是将输入图像转换为PyTorch张量。
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):这只是输入数据的缩放,这些值(均值和标准差)必须预先计算好您的数据集。也不建议更改这些值。

验证集转换

  1. transforms.Resize(256):首先将输入图像调整为大小为(256, 256)
  2. transforms.CenterCrop(224):裁剪图像的中心部分,形状为(224, 224)

其余部分与训练集相同。

P.S.:您可以在官方文档中了解更多有关这些转换的信息。


请问您能否详细说明为什么我们不将图像大小调整为224x224,而是先调整到256,然后进行中心裁剪? - Jjang
我不太确定为什么在这个例子中要这样做。这可能是由于数据的特殊性造成的。从网络的角度来看,它期望输入大小为(224, 224),无论输入如何转换。 - layog
那么,组合中的所有转换肯定会在一次迭代(或一批次)中应用于每个样本吗?还是说在一次迭代中只应用任何一个(或少数)转换? - Ajinkya Ambatwar
所有在compose中的转换都会应用于所有的输入。 - layog

1

关于数据增强的不确定性,我建议您参考以下答案:

PyTorch中的数据增强

简而言之,假设您只有随机水平翻转变换,在迭代图像数据集时,一些图像将返回为原始图像,一些图像将返回为翻转的图像(翻转的原始图像不会返回)。换句话说,一个迭代中返回的图像数量与数据集的原始大小相同,并且没有进行增强。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接