PyTorch DataLoader shuffle

6

我进行了一项实验,但结果并非我预期的。

对于第一部分,我正在使用

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=False, num_workers=0)

在训练模型之前,我将trainloader.dataset.targets保存到变量a中,将trainloader.dataset.data保存到变量b中。然后,我使用trainloader训练模型。
训练完成后,我将trainloader.dataset.targets保存到变量c中,将trainloader.dataset.data保存到变量d中。最后,我检查a == cb == d,它们都返回True,这是预期的,因为DataLoader的shuffle参数为False

对于第二部分,我使用

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=True, num_workers=0)

我在训练模型之前将trainloader.dataset.targets保存到变量e中,将trainloader.dataset.data保存到变量f中。然后,我使用trainloader来训练模型。训练完成后,我将trainloader.dataset.targets保存到变量g中,将trainloader.dataset.data保存到变量h中。由于shuffle=True,我期望e == gf == h都为False,但它们再次给出了True。我从DataLoader类的定义中漏掉了什么?
2个回答

5

我相信存储在trainloader.dataset.data或.target中的数据不会被随机打乱,只有当DataLoader被作为生成器或迭代器调用时,数据才会被随机打乱。

你可以通过多次执行next(iter(trainloader))来进行检验,一次没有随机打乱,一次有随机打乱,这两种情况应该会给出不同的结果。

import torch
import torchvision

transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        ])
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
                                           transform = transform)
dataLoader = torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size = 128,
                                         shuffle = False,
                                         num_workers = 10)
target = dataLoader.dataset.targets


MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
                                           transform = transform)

dataLoader_shuffled= torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size = 128,
                                         shuffle = True,
                                         num_workers = 10)

target_shuffled = dataLoader_shuffled.dataset.targets

print(target == target_shuffled)

_, target = next(iter(dataLoader));
_, target_shuffled = next(iter(dataLoader_shuffled))

print(target == target_shuffled)

这将会给出:
tensor([True, True, True,  ..., True, True, True])
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False,  True, False,
        False, False, False, False, False, False, False, False])

然而,存储在数据和目标中的数据和标签是固定的列表,由于您试图直接访问它们,它们不会被洗牌。

0

在使用 Dataset 类加载数据时,我遇到了类似的问题。我停止使用 Dataset 类加载数据,改为使用以下代码,这个方法对我很有效。

X = torch.from_numpy(X)
y = torch.from_numpy(y)

train_data = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)

X和y是从CSV文件中读取的numpy数组。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接