如何从 PyTorch DataLoader 中获取特定的样本?

6
在Pytorch中,是否有一种使用torch.utils.data.DataLoader类来加载特定的单个样本的方法?我想对它进行一些测试。
教程中使用了...
trainloader = torch.utils.data.DataLoader(...)
images, labels = next(iter(trainloader))

想要获取一批随机样本。是否可以使用DataLoader来获取一个特定的样本?

谢谢!


https://dev59.com/blQJ5IYBdhLWcg3wkGy-#61389393 - Dishin H Goyani
2个回答

8
  • 关闭DataLoader中的shuffle
  • 使用batch_size计算所需样本所在的批次
  • 迭代到所需批次

代码

import torch 
import numpy as np
import itertools

X= np.arange(100)
batch_size = 2

dataloader = torch.utils.data.DataLoader(X, batch_size=batch_size, shuffle=False)
sample_at = 5
k = int(np.floor(sample_at/batch_size))

my_sample = next(itertools.islice(dataloader, k, None))
print (my_sample)

输出:

tensor([4, 5])

感谢您的回答@mujjiga,非常好用! - MJimitater
很棒的答案,正是所需之物。 - Rishabh Gupta

3

如果您想从数据集中获取特定的单个样本,可以使用Subset类进行检查。(https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset) 可以像这样:

indices =  [0,1,2]  # select your indices here as a list  
subset = torch.utils.data.Subset(train_set, indices)
trainloader = DataLoader(subset , batch_size =  16  , shuffle =False) #set shuffle to False 

for image , label in trainloader:
   print(image.size() , '\t' , label.size())
   print(image[0], '\t' , label[0]) # index the specific sample 

这是一个有用的链接,如果你想了解更多关于Pytorch数据加载工具的内容 (https://pytorch.org/docs/stable/data.html

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接