如何从DataLoader中获取样本的文件名?

11

我需要用我训练的卷积神经网络的数据测试结果来写入一个文件。这些数据包括说话人的语音数据收集。文件的格式应该是“文件名,预测结果”,但我很难提取文件名。我是这样加载数据的:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

我正在尝试以下方式向文件写入:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()

os.listdir(TESTH_DATA_PATH + "/all")[i] 的问题在于它与 test_loader 加载的文件顺序不同步。我该怎么办?

3个回答

10

这取决于您的Dataset如何实现。例如,在torchvision.datasets.MNIST(...)的情况下,您无法简单地检索文件名,因为不存在单个样本的文件名(MNIST样本以不同的方式加载)。

由于您没有展示您的Dataset实现,我将告诉您如何使用torchvision.datasets.ImageFolder(...)(或任何torchvision.datasets.DatasetFolder(...))来完成此操作:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))
f.close()
你可以看到文件路径是在__getitem__(self, index)中被检索出来的,具体来说就是这里
如果你实现了自己的Dataset(并且可能想支持shuffle和batch_size > 1),那么在__getitem__(...)调用时,我会返回sample_fname并执行类似以下操作的代码:
for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]

这样你就不需要关心shuffle。如果batch_size大于1,则需要将循环的内容更改为更通用的东西,例如:

这样做,您就无需担心shuffle。如果batch_size大于1,则需要更改循环内的内容以适应更通用的情况,例如:

f = open("test_y", "w")
for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
    outputs = model(images)
    pred = torch.max(outputs, 1)[1]
    f.write("\n".join([
        ", ".join(x)
        for x in zip(map(str, pred.cpu().tolist()), samples_fname)
    ]) + "\n")
f.close()

1
谢谢提示!我可以从datasets.ImageFolder.samples[i][0]中读取文件名列表。 - sailfish009
我使用了你的代码,但是我很难为我的所有图像获取文件名。你能否请看一下这里 https://stackoverflow.com/questions/71430015/accessing-the-filename-as-well-as-prediction-for-each-of-images-inside-the-test?非常感谢。 - Mona Jalal
1
@MonaJalal 这个第一个解决方案只适用于您的情况与 OP 相同的情况:batch_size=1, shuffle=False。否则,您必须使用自定义数据集进行包装,就像我在答案中建议的那样,在 __getitem__(...) 中返回 sample_fname - Berriel
即使batch_size=1且使用ImageFolder,这对我来说也不起作用。在“sample_fname,_=test_loader.dataset.samples[i]”这一行上,我会收到错误“AttributeError:'Subset'对象没有属性'samples'”。您有什么想法我做错了什么? - Ferdinando Randisi
1
@FerdinandoRandisi 根据错误提示,你的数据集不是 DatasetFolder,而是一个 Subset。在这种情况下,你需要访问额外的 .dataset 属性,并根据你的 Subset 索引使用正确的索引。 - Berriel

1
通常情况下,DataLoader 用于提供数据集中的批次。
正如@Barriel所提到的,对于单/多标签分类问题,DataLoader 没有图像文件名,只有表示图像和类别/标签的张量。
然而,DataLoader 构造函数在加载对象时可以处理一些小事情(与数据集一起打包目标/标签和文件名,如果需要还可以是一个 dataframe)。
这样,DataLoader 可以更好地获取您所需的内容。

1

如果你使用的是 PyCharm 或任何具有调试工具的 IDE,请使用它来查看你的 data_loader 内部,希望你可以看到一个文件名列表,就像我的情况一样。

在我的情况下, 我的 data_loader 是由 mmsegmentation 创建的。 Data_loader created by mmsegmentation


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接