PyTorch的DataLoader中,__getitem__的idx是如何起作用的?

15

我目前正在尝试使用PyTorch的DataLoader来处理数据以喂给我的深度学习模型,但是遇到了一些困难。

我需要的数据形状为(minibatch_size=32, rows=100, columns=41)。我编写的自定义Dataset类中的__getitem__代码大致如下:

def __getitem__(self, idx):
    x = np.array(self.train.iloc[idx:100, :])
    return x

我之所以这样写是因为我希望DataLoader每次处理形状为(100, 41)的输入实例,并且我们有32个这样的单个实例。

然而,我注意到与我最初的想法相反,DataLoader传递给函数的idx参数不是顺序的(这很关键,因为我的数据是时序数据)。例如,打印这些值会得到像这样的结果:

idx = 206000
idx = 113814
idx = 80597
idx = 3836
idx = 156187
idx = 54990
idx = 8694
idx = 190555
idx = 84418
idx = 161773
idx = 177725
idx = 178351
idx = 89217
idx = 11048
idx = 135994
idx = 15067

这是正常的行为吗?我发布这个问题是因为返回的数据批次不是我最初想要的。

在使用DataLoader之前,我使用的原始逻辑是:

  1. txtcsv文件中读取数据。
  2. 计算数据中有多少批次,并相应地切片数据。例如,由于一个输入实例的形状为(100, 41),其中32个实例组成一个小批量,我们通常会得到大约100个左右的批次,并相应地重塑数据。
  3. 一个输入的形状为(32, 100, 41)

我不确定我应该如何处理DataLoader钩子方法。任何提示或建议都将不胜感激。提前致谢。


你能详细说明一下你的“2”吗?“我们通常最终会得到大约100个”,你是指你的数据集有32 * 100个样本吗? - enamoria
嗨。不,我的意思是模型的一个输入形状为(100, 40),而有32个这样的输入组成一个小批量。 - Sean
1
@Seankala,我试图帮你理解DataLoader的代码。如果有帮助,请告诉我。 - Berriel
@Berriel 是的,它帮了很大的忙。非常感谢您花费时间和精力进行详细的解释! - Sean
1个回答

22
idx 的定义由 samplerbatch_sampler 决定,你可以在这里(开源项目是你的好朋友)看到。在这个 code(以及注释/文档字符串)中,你可以看到 samplerbatch_sampler 之间的区别。如果你在这里看一下,你会看到索引是如何选择的:
def __next__(self):
    index = self._next_index()

# and _next_index is implemented on the base class (_BaseDataLoaderIter)
def _next_index(self):
    return next(self._sampler_iter)

# self._sampler_iter is defined in the __init__ like this:
self._sampler_iter = iter(self._index_sampler)

# and self._index_sampler is a property implemented like this (modified to one-liner for simplicity):
self._index_sampler = self.batch_sampler if self._auto_collation else self.sampler

请注意,这是_SingleProcessDataLoaderIter实现;你可以在这里找到_MultiProcessingDataLoaderIter的实现(当然,使用哪个取决于num_workers值,如这里所示)。回到采样器,假设你的数据集不是_DatasetKind.Iterable,并且你没有提供自定义采样器,那么这意味着你正在使用(dataloader.py#L212-L215):
if shuffle:
    sampler = RandomSampler(dataset)
else:
    sampler = SequentialSampler(dataset)

if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

让我们来看一下默认的BatchSampler如何构建批次

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch

非常简单:它从采样器中获取索引,直到达到所需的batch_size为止。
现在,“PyTorch的DataLoader中的__getitem__的idx如何工作?”这个问题可以通过查看每个默认采样器的工作方式来回答。
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
def __iter__(self):
    n = len(self.data_source)
    if self.replacement:
        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    return iter(torch.randperm(n).tolist())

因此,由于您没有提供任何代码,我们只能假设:
  1. 您在DataLoader中使用了shuffle=True 或者
  2. 您使用了自定义的采样器 或者
  3. 您的数据集是_DatasetKind.Iterable

一个杰出的答案! - Aidos

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接