如何在PyTorch中为子集使用不同的数据增强方法

11

如何在PyTorch中为不同的Subset使用不同的数据增强(变换)?

例如:

train, test = torch.utils.data.random_split(dataset, [80000, 2000])

train和test将使用与dataset相同的变换。如何为这些子集使用自定义变换?

4个回答

14

我的当前解决方案并不是非常优雅,但是有效:

from copy import copy

train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
train_dataset.dataset = copy(full_dataset)

test_dataset.dataset.transform = transforms.Compose([
    transforms.Resize(img_resolution),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset.dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(img_resolution[0]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

基本上,我正在为其中一个数据集拆分定义一个新数据集(它是原始数据集的副本),然后为每个拆分定义自定义转换。

注意:由于我使用的是ImageFolder数据集,它使用.tranform属性执行转换,因此train_dataset.dataset.transform有效。

如果有人知道更好的解决方案,请与我们分享!


2
是的,PyTorch数据集API有点基础。内置数据集没有相同的属性,一些转换仅适用于PIL图像,一些仅适用于数组,Subset不会委托给包装的数据集...我希望这种情况将来会改变,但目前我认为没有更好的方法来解决它。 - oarfish

6
这是我使用的方法(来自这里):
import torch
from torch.utils.data import Dataset, TensorDataset, random_split
from torchvision import transforms

class DatasetFromSubset(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)

这里有一个例子:

init_dataset = TensorDataset(
    torch.randn(100, 3, 24, 24),
    torch.randint(0, 10, (100,))
)

lengths = [int(len(init_dataset)*0.8), int(len(init_dataset)*0.2)]
train_subset, test_subset = random_split(init_dataset, lengths)

train_dataset = DatasetFromSubset(
    train_set, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)
test_dataset = DatasetFromSubset(
    test_set, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)

4

我已经放弃了,复制了自己的子集(几乎与pytorch相同)。我将变换保存在子集中(而不是父级)。

class Subset(Dataset):
    r"""
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __getitem__(self, idx):
        im, labels = self.dataset[self.indices[idx]]
        return self.transform(im), labels

    def __len__(self):
        return len(self.indices)

你还需要编写自己的拆分函数。

0

你可以为每个子集使用自定义的collate_fn。 我在使用自定义数据集进行目标检测时使用了它,这样每个样本都是一个包含图像和元数据的字典:

def collate_fn_transform(transform):
        def collate_fn(batch):
            for sample in batch:
                transformed = transform(image=sample['image'], bboxes=sample['boxes'],
                                keypoints=sample['keypoints'], labels=sample['labels'])
                sample['image'] = transformed['image']
                sample['boxes'] = torch.tensor(transformed['bboxes'], dtype=torch.float32)
                sample['keypoints'] = torch.tensor(transformed['keypoints'], dtype=torch.float32).unsqueeze(0)
        return batch
    return collate_fn

indices = torch.randperm(len(dataset))
train_set = torch.utils.data.Subset(dataset, indices=indices[:train_size])
train_transform = A.Compose([...])
        
val_set = torch.utils.data.Subset(dataset, indices=indices[train_size:])
val_transform = A.Compose([...])
loaders = {
        'train': torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                             collate_fn=collate_fn_transform(train_transform),
                                             num_workers=4, pin_memory=True),
        'val': torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False,
                                           collate_fn=collate_fn_transform(val_transform))
    }



网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接