使用批处理预处理编写自定义的PyTorch数据加载器迭代器

3
一个典型的自定义PyTorch数据集如下所示,
class TorchCustomDataset(torch.utils.data.Dataset):

    def __init__(self, filenames, speech_labels):
        pass

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return 1, 0

在这里,使用 __getitem__,我可以读取任何文件,并针对特定文件应用任何预处理。

如果我想对整个数据批次应用一些张量级别的预处理怎么办?从技术上讲,只需要迭代数据加载器以获取批次样本并对其进行预处理即可。

但如何使用自定义数据加载器实现呢?简言之,数据加载器的 __getitem__ 等效于什么,以便在整批数据上应用某些操作?

1个回答

4

您可以重写DataLoadercollate_fn函数:该函数从基础的Dataset中获取各个项并形成批次。通过修改collate_fn,您可以在此处添加自定义预处理。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接