Pytorch中DataLoader的洗牌顺序

3
我很困惑pytorch中DataLoader的洗牌顺序。 假设我有一个数据集:
datasets = [0,1,2,3,4]

在场景一中,代码如下:
torch.manual_seed(1)

G = torch.Generator()
G.manual_seed(1)

ran_sampler = RandomSampler(data_source=datasets,generator=G)
dataloader = DataLoader(dataset=datasets,sampler=ran_sampler)

洗牌结果为0,4,2,3,1


在方案二中,代码如下:

torch.manual_seed(1)

G = torch.Generator()
G.manual_seed(1)

ran_sampler = RandomSampler(data_source=datasets)
dataloader = DataLoader(dataset=datasets, sampler=ran_sampler, generator=G)

洗牌的结果为1,3,4,0,2


在第三种情况下,代码如下:

torch.manual_seed(1)

G = torch.Generator()
G.manual_seed(1)

ran_sampler = RandomSampler(data_source=datasets, generator=G)
dataloader = DataLoader(dataset=datasets, sampler=ran_sampler, generator=G)

这里的洗牌结果是 4,1,3,0,2

有人能解释一下这是怎么回事吗?

1个回答

3

根据您的代码,我对场景II进行了一些修改和检查:

datasets = [0,1,2,3,4]

torch.manual_seed(1)
G = torch.Generator()
G = G.manual_seed(1)

ran_sampler = RandomSampler(data_source=datasets, generator=G)
dataloader = DataLoader(dataset=datasets, sampler=ran_sampler)
print(id(dataloader.generator)==id(dataloader.sampler.generator))
xs = []
for x in dataloader:
    xs.append(x.item())
print(xs)

torch.manual_seed(1)
G = torch.Generator()
G.manual_seed(1)

# this is different from OP's scenario II because in that case the ran_sampler is not initialized with the right generator.
dataloader = DataLoader(dataset=datasets, shuffle=True, generator=G)
print(id(dataloader.generator)==id(dataloader.sampler.generator))
xs = []
for x in dataloader:
    xs.append(x.item())
print(xs)

torch.manual_seed(1)
G = torch.Generator()
G.manual_seed(1)


ran_sampler = RandomSampler(data_source=datasets, generator=G)
dataloader = DataLoader(dataset=datasets, sampler=ran_sampler, generator=G)
print(id(dataloader.generator)==id(dataloader.sampler.generator))
xs = []
for x in dataloader:
    xs.append(x.item())
print(xs)

输出结果为:
False
[0, 4, 2, 3, 1]
True
[4, 1, 3, 0, 2]
True
[4, 1, 3, 0, 2]

以上三个看似相同的设置导致不同结果的原因是,实际上在DataLoader内部使用了两个不同的生成器,其中第一个情况下的一个生成器为None
为了清晰起见,让我们分析一下源代码。似乎generator不仅决定了DataLoader内部的_index_sampler的随机数生成,还影响了_BaseDataLoaderIter的初始化。请参考源代码。
        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
                else:
                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]

并且

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

并且

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

并且

class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        ...
        self._index_sampler = loader._index_sampler
  • 场景二&场景三

这两种设置是等价的。我们向 DataLoader 传递一个生成器,不指定 samplerDataLoader 自动创建一个 RandomSampler 对象,并将相同的生成器分配给 self.generator

  • 场景一

我们向 DataLoader 传递正确的生成器到采样器中,但没有在 DataLoader.__init__(...) 中显式指定关键字参数 generatorDataLoader 使用给定的采样器初始化采样器,但对于 self.generator 和由 self._get_iterator() 返回的 _BaseDataLoaderIter 对象使用默认生成器 None


非常感谢您的帮助。不过,我有一个问题。为什么您说在我的情况II中ran_sampler没有使用正确的生成器进行初始化?您是什么意思? - liaoming999
在您的情况II中,ran_sampler = RandomSampler(data_source=datasets)。您没有在ran_sampler初始化中指定生成器关键字参数,因此它是使用默认生成器进行初始化的。 - TQCH
很抱歉,但是您所说的默认生成器是什么意思?在我的场景II中,我发现id(dataloader.generator)!=id(dataloader.sampler.generator)。实际上,我检查了RandomSampler()的源代码,并发现采样器会自己生成一个生成器。这正是我场景II中发生的情况。 - liaoming999
是的,那就是我想表达的意思。默认情况下,它会创建一个生成器,而这不是你想要的。 - TQCH

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接