介绍
这个问题有点开放性,但是让我们试一试,并请在我哪里有错误的时候纠正我。
到目前为止,我已经首先将数据导出到文件系统中,使用文档类别命名子文件夹。
我认为这不明智,因为:
- 你实际上在复制数据
- 每当您想要仅通过代码和数据库进行新的训练时,这个操作就必须重复执行
- 您可以同时访问多个数据点,并将它们缓存在RAM中以供以后重用,而无需多次从硬盘读取(这非常耗费资源)
我说得对吗?直接连接到MongoDB是否有意义?
考虑到上面的情况,可能是的(特别是涉及清晰和可移植的实现时)。
或者有理由不这样做吗(例如,DB通常太慢等)?
据我所知,在这种情况下,DB不应该更慢,因为它将缓存访问它的数据,但不幸的是我不是DB专家。数据库已经默认实现了许多更快访问的技巧。
能否以某种方式预取数据?
是的,如果您只想获取数据,则可以一次加载更多数据(例如 1024
条记录)并从中返回数据批次(例如 batch_size=128
)。
实现
如何实现PyTorch DataLoader?我在网上只找到了很少的代码片段([1]和[2]),这让我对我的方法产生了怀疑。
我不确定为什么您要这样做。您应该使用如您所列举的示例中所示的 torch.utils.data.Dataset
。
我会从简单的非优化方法开始,类似于此处的方法:
- 在
__init__
中打开与数据库的连接,并在使用期间保持连接处于活动状态(我会从torch.utils.data.Dataset
创建上下文管理器,以便在训练轮次完成后关闭连接)。
- 我不会将结果转换为
list
(特别是由于显而易见的内存限制),因为这违背了生成器的初衷。
- 我会在此数据集内部执行批处理(这里有一个
batch_size
参数,在这里可以查看)。
- 我不确定
__getitem__
函数,但似乎它可以一次返回多个数据点,因此我会使用它,并且应该允许我们使用num_workers>0
(假设mycol.find(query)
每次都以相同的顺序返回数据)。
鉴于这些,以下是我会采取的一些措施:
class DatabaseDataset(torch.utils.data.Dataset):
def __init__(self, query, batch_size, path: str, database: str):
self.batch_size = batch_size
client = pymongo.MongoClient(path)
self.db = client[database]
self.query = query
self.length = self.db.estimated_document_count()
self.cursor = None
def __enter__(self):
self.cursor = self.db.find(self.query)
return self
def shuffle(self):
pass
def __exit__(self, *_, **__):
self.cursor.close()
def __len__(self):
return len(self.examples)
def __getitem__(self, index):
examples = self.cursor[index * batch_size : (index + 1) * batch_size]
...
return data, labels
现在批处理由
DatabaseDataset
处理,因此
torch.utils.data.DataLoader
可以使用
batch_size=1
。您可能需要挤压附加维度。
由于
MongoDB
使用锁定(这并不奇怪,但请参见
此处),
num_workers>0
不应该是一个问题。
可能的使用方法(示意性):
with DatabaseDataset(...) as e:
dataloader = torch.utils.data.DataLoader(e, batch_size=1)
for epoch in epochs:
for batch in dataloader:
...
dataset.shuffle()
记住在这种情况下的洗牌实现!(同时可以在上下文管理器中进行洗牌,并且您可能需要手动关闭连接或类似操作)。