在PyTorch中加载csv和图像数据集

3

我正在使用PyTorch进行图像分类。我有一个单独的Images文件夹,以及包含图像ID和标签的train和test csv文件。我不知道如何将这些图像和ID组合起来,并转换为张量。

  1. train.csv:包含所有图像ID(例如4325.jpg、2345.jpg等),以及包含猫、狗等标签。
  2. Image_data:包含所有带有ID名称的图像。

我回答的内容类似于这里的一个答案:https://stackoverflow.com/a/72337742/19173781 - Ahmed Eisa
1个回答

8
你可以通过继承 PyTorch 的 torch.utils.data.Dataset 来创建自定义数据集类。
以下自定义数据集类的假设为:
  • csv 文件格式为

filename label
4325.jpg cat
2345.jpg dog
  • 所有图片都在images文件夹中。
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, csv_path, images_folder, transform = None):
        self.df = pd.read_csv(csv_path)
        self.images_folder = images_folder
        self.transform = transform
        self.class2index = {"cat":0, "dog":1}

    def __len__(self):
        return len(self.df)
    def __getitem__(self, index):
        filename = self.df[index, "FILENAME"]
        label = self.class2index[self.df[index, "LABEL"]]
        image = PIL.Image.open(os.path.join(self.images_folder, filename))
        if self.transform is not None:
            image = self.transform(image)
        return image, label
        

现在,您可以使用这个类来加载训练和测试数据集,无论是使用csv文件还是图像文件夹。

train_dataset = CustomDataset("path - to - train.csv", "path - to - images - folder"  )
test_dataset = CustomDataset("path - to - test.csv", "path - to - images - folder"  )


image, label = train_dataset[0]

先生,为什么第14行......如果self.transform不是None,则显示语法错误。 - rts
@RitikKumar,是的,我在第13行忘记了。我已经修复了。 - Mitiku
2
伟大的答案 - 对我唯一的问题是 df[index, "SDFSDF"] 引发了错误 - 相反我使用了 df.loc[index]["SDFSDF"],它可以正常工作。 - aaronsnoswell

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接