Pytorch Dataloader如何处理可变大小的数据?

27

我有一个数据集,看起来像下面这样。第一项是用户ID,后面是用户点击过的项目集合。

0   24104   27359   6684
0   24104   27359
1   16742   31529   31485
1   16742   31529
2   6579    19316   13091   7181    6579    19316   13091
2   6579    19316   13091   7181    6579    19316
2   6579    19316   13091   7181    6579    19316   13091   6579
2   6579    19316   13091   7181    6579
4   19577   21608
4   19577   21608
4   19577   21608   18373
5   3541    9529
5   3541    9529
6   6832    19218   14144
6   6832    19218
7   9751    23424   25067   12606   26245   23083   12606

我定义了一个自定义数据集来处理我的点击日志数据。

import torch.utils.data as data
class ClickLogDataset(data.Dataset):
    def __init__(self, data_path):
        self.data_path = data_path
        self.uids = []
        self.streams = []

        with open(self.data_path, 'r') as fdata:
            for row in fdata:
                row = row.strip('\n').split('\t')
                self.uids.append(int(row[0]))
                self.streams.append(list(map(int, row[1:])))

    def __len__(self):
        return len(self.uids)

    def __getitem__(self, idx):
        uid, stream = self.uids[idx], self.streams[idx]
        return uid, stream

然后我使用DataLoader从数据中检索小批量进行训练。
from torch.utils.data.dataloader import DataLoader
clicklog_dataset = ClickLogDataset(data_path)
clicklog_data_loader = DataLoader(dataset=clicklog_dataset, batch_size=16)

for uid_batch, stream_batch in stream_data_loader:
    print(uid_batch)
    print(stream_batch)

上面的代码返回的结果与我期望的不同,我希望stream_batch是一个长度为16的整数类型的二维张量。然而,我得到的是一个长度为16的一维张量列表,而且这个列表只有一个元素,就像下面这样。为什么会这样呢?

#stream_batch
[tensor([24104, 24104, 16742, 16742,  6579,  6579,  6579,  6579, 19577, 19577,
        19577,  3541,  3541,  6832,  6832,  9751])]

跨贴:Pytorch Dataloader 如何处理可变大小的数据? - Charlie Parker
3个回答

18

那么,如何处理样本长度不同的问题呢?torch.utils.data.DataLoader有一个collate_fn参数,用于将一组样本转化成一个批次。默认情况下,它会对列表进行处理,实现方式可以参考此处此处。您可以编写自己的collate_fn函数,例如对输入进行0填充,将其截断为某个预定义长度,或者应用任何您选择的其他操作。


如果我不想填充额外的数字怎么办?我的意思是,如果我有一个完全卷积神经网络,并且我不需要相同大小的输入,特别是我不想通过填充来改变输入(我正在进行可解释的AI实验)? - Black Jack 21
@RedFloyd 没问题,只是你需要进行一些调整并且会失去一些性能。在 PyTorch(以及大多数其他框架)中,CNN 操作(例如 Conv2d)以“向量化”的方式在第一个维度上执行(通常称为批处理维度)。在你的情况下,你只需要将这个维度设置为1,并且根据你有多少张图片来调用你的网络,而不是将它们堆叠成一个大张量,然后在所有图片上执行一次网络。这可能会损失一些性能,但不会有更多的影响。 - Jatentaki
谢谢回复。只是为了澄清,这样做本质上就是 SGD,会产生噪声和训练困难(即可能无法收敛)吗? - Black Jack 21

16
这是我处理的方式:
def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.
    '''
    ## get sequence lengths
    lengths = torch.tensor([ t.shape[0] for t in batch ]).to(device)
    ## padd
    batch = [ torch.Tensor(t).to(device) for t in batch ]
    batch = torch.nn.utils.rnn.pad_sequence(batch)
    ## compute mask
    mask = (batch != 0).to(device)
    return batch, lengths, mask

然后我将其作为collate_fn传递给dataloader类。


在pytorch论坛中有一个巨大的不同帖子列表,让我把它们都链接起来。它们都有自己的答案和讨论。对我来说,似乎没有一种"标准方法",但如果有权威参考,请分享。

理想答案应该提到以下内容:

  • 效率,例如在GPU中使用torch在收集函数中处理与numpy相比

这样的事情。

列表:

分桶: - https://discuss.pytorch.org/t/tensorflow-esque-bucket-by-sequence-length/41284


2
在collate中将张量放在GPU上是惯例吗?我认为如果这样做,就不能在数据加载器中使用多个worker。我很想知道哪种方法通常具有更好的性能。 - Tahlor
@Pinocchio 为什么要计算序列长度和掩码?如果我理解正确,一旦批次传递到网络中,网络就没有使用掩码或修剪输入的方法,对吗? - financial_physician
如果有人偶然发现这篇文章,我认为David Ng提供的答案是实现此操作的最佳方式。https://dev59.com/KlUK5IYBdhLWcg3w_z20 - financial_physician

9
正如 @Jatentaki 建议的那样,我编写了自定义排序函数,并且它运行良好。
def get_max_length(x):
    return len(max(x, key=len))

def pad_sequence(seq):
    def _pad(_it, _max_len):
        return [0] * (_max_len - len(_it)) + _it
    return [_pad(it, get_max_length(seq)) for it in seq]

def custom_collate(batch):
    transposed = zip(*batch)
    lst = []
    for samples in transposed:
        if isinstance(samples[0], int):
            lst.append(torch.LongTensor(samples))
        elif isinstance(samples[0], float):
            lst.append(torch.DoubleTensor(samples))
        elif isinstance(samples[0], collections.Sequence):
            lst.append(torch.LongTensor(pad_sequence(samples)))
    return lst

stream_dataset = StreamDataset(data_path)
stream_data_loader = torch.utils.data.dataloader.DataLoader(dataset=stream_dataset,                                                         
                                                            batch_size=batch_size,                                            
                                                        collate_fn=custom_collate,
                                                        shuffle=False)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接