数据集
您需要使用torch.utils.data.Dataset
结构来定义它。以下是在普通的pytorch
中如何执行此操作(我使用pillow
加载图像,并使用torchvision
将它们转换为torch.Tensor
对象):
import torch
import torchvision
from PIL import Image
class MyDataset(torch.utils.data.Dataset):
def __init__(self, dataframe):
self.dataframe = dataframe
def __len__(self):
return len(self.dataframe)
def __getitem__(self, index):
row = self.dataframe.iloc[index]
return (
torchvision.transforms.functional.to_tensor(Image.open(row["Path"])),
row["Score"],
)
dataset = MyDataset(dataframe)
或者,您可以使用torchdata
(免责声明:我是作者,这是自荐...),它允许您像这样解耦Path
和Scores
:
import torchvision
from PIL import Image
import torchdata
class ImageDataset(torchdata.datasets.FilesDataset):
def __getitem__(self, index):
return Image.open(self.files[index])
class Labels(torchdata.Dataset):
def __init__(self, scores):
super().__init__()
self.scores = scores
def __len__(self):
return len(self.scores)
def __getitem__(self, index):
return self.scores[index]
dataset = ImageDataset.from_folder("/folder", regex="*.jpg").map(
torchvision.transforms.ToTensor()
) | Labels(dataframe["Score"].to_numpy())
(或者您可以像在常规的pytorch
中那样实现,但是要从torchdata.Dataset
继承并在构造函数中调用super().__init__()
)。
torchdata
允许您通过.map
轻松缓存图像或应用其他转换,如下所示,请查看Github仓库以获取更多信息,或在评论中提问。
DataLoader
无论哪种方式,您都应该将数据集包装在torch.utils.data.DataLoader
中,以创建批次并进行迭代处理,如下所示:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
for images, scores in dataloader:
...
在循环中,可随意处理这些图像和分数。