如何保存PyTorch的DataLoader实例?

4
我想保存PyTorch的torch.utils.data.dataloader.DataLoader实例,以便我可以在之前保留的情况下继续训练(包括随机种子、状态等)。

请点击此处查看:https://discuss.pytorch.org/t/how-to-save-dataloader/62813/3 - Muhammad Hamza
谢谢,但是torch.save()不会保存状态。如果我保存并重新加载它,它将从新的shuffle种子开始。 - edoost
只需使用相同的洗牌种子即可从上一个时期重新开始训练。我认为您无法在时期之间重新启动。 - akshayk07
3个回答

5
你需要自定义实现采样器。 可以使用以下无麻烦的内容:https://gist.github.com/usamec/1b3b4dcbafad2d58faa71a9633eea6a5 你可以像这样保存和恢复:
sampler = ResumableRandomSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler, pin_memory=True)

for x in loader:
    print(x)
    break

sampler2 = ResumableRandomSampler(dataset)
torch.save(sampler.get_state(), "test_samp.pth")
sampler2.set_state(torch.load("test_samp.pth"))
loader2 = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler2, pin_memory=True)

for x in loader2:
    print(x)

2
很简单。一个人应该设计自己的采样器,它会获取起始索引并自行随机打乱数据:
import random
from torch.utils.data.dataloader import Sampler


random.seed(224)  # use a fixed number


class MySampler(Sampler):
    def __init__(self, data, i=0):
        random.shuffle(data)
        self.seq = list(range(len(data)))[i * batch_size:]

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

现在将最后的索引i保存在某个地方,下次使用它实例化DataLoader

train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, last_i)
train_data_loader = DataLoader(dataset=train_dataset,                                                         
                               batch_size=batch_size, 
                               sampler=train_sampler,
                               shuffle=False)  # don't forget to set DataLoader's shuffle to False

在Colab上进行培训非常有用。


1
我认为在洗牌后需要进行子选择索引。否则,您将不知道已经覆盖了哪些索引。 - fsociety
@fsociety 在这个例子中不需要。我在这里使用一个恒定的随机种子,并且每个时期不会洗牌数据。因此,只需知道最后一个索引即可。然而,如果在训练时进行洗牌(或没有固定的随机种子),他们应该保留所有已覆盖的索引,正如你所建议的那样。 - edoost
我看不出有什么区别。如果您想在一个 epoch 中恢复采样器,这里就是使用案例,您不知道上次停下来的位置。所以假设您保存了 i,它是10。然后您只取大于该值的所有索引,对其进行洗牌,并选择接下来的10个。但您不知道这10个是哪些。这很容易解决,在固定种子的情况下,在洗牌后进行子选择。 - fsociety
@fsociety 正如代码所示,洗牌只在Sampler实例化时进行一次。由于随机种子是相同的数字,每次我们洗牌data(这就是固定随机种子的含义),我们将得到相同的索引顺序。在这种情况下,i始终表示相同的索引。因此,在选择第i个元素之前总是进行洗牌。 - edoost
也许我眼瞎了,但你是在选择索引或者说在对要考虑的索引进行子采样之后才洗牌数据。 - fsociety
1
@fsociety 对此表示抱歉。我很久之前写了这个答案。我已经修复了代码。感谢你指出错误。 - edoost

1

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接