PyTorch DataLoader 使用MongoDB

4

我想知道是否将DataLoader连接到MongoDB是明智的选择,以及如何实现。

背景

我有大约2000万个文档存在(本地)MongoDB中。比内存容量更多。我想在数据上训练深度神经网络。目前,我先将数据导出到文件系统中,子文件夹命名为文档类别。但我觉得这种方法毫无意义。如果数据已经妥善保管在数据库中,为什么要先导出(后来再删除)呢?

问题1:

我对吗?直接连接MongoDB是否有意义?或者有不这样做的原因(例如,数据库通常太慢等)?如果数据库太慢(为什么?),是否可以以某种方式预取数据?

问题2:

如何实现PyTorch的 ? 我在网上只找到了很少的代码片段([1][2]) ,这让我对我的方法产生了怀疑。

代码片段

我访问MongoDB的一般方式如下所示。我认为这没有什么特别之处。

import pymongo
from pymongo import MongoClient

myclient = pymongo.MongoClient("mongodb://localhost:27017/")
mydb = myclient["xyz"]
mycol = mydb["xyz_documents"]

query = {
    # some filters
}

results = mycol.find(query)

# results is now a cursor that can run through all docs
# Assume, for the sake of this example, that each doc contains a class name and some image that I want to train a classifier on

1
如果您能提供在Python中访问MongoDB文档的最小代码,那将非常有帮助。 - Ivan
嗨,伊万,我添加了一段代码片段。 - pascal
此外,结果数量是否已知? - Ivan
大约,我们知道。可以通过mycol.estimated_document_count()非常快速地获得近似数字。更慢地,我也可以精确地计算结果 - 使用mycol.count_documents(filter)。 - pascal
第一个链接已经失效,我进行了修正。 - pascal
1个回答

4

介绍

这个问题有点开放性,但是让我们试一试,并请在我哪里有错误的时候纠正我。

到目前为止,我已经首先将数据导出到文件系统中,使用文档类别命名子文件夹。

我认为这不明智,因为:

  • 你实际上在复制数据
  • 每当您想要仅通过代码和数据库进行新的训练时,这个操作就必须重复执行
  • 您可以同时访问多个数据点,并将它们缓存在RAM中以供以后重用,而无需多次从硬盘读取(这非常耗费资源)

我说得对吗?直接连接到MongoDB是否有意义?

考虑到上面的情况,可能是的(特别是涉及清晰和可移植的实现时)。

或者有理由不这样做吗(例如,DB通常太慢等)?

据我所知,在这种情况下,DB不应该更慢,因为它将缓存访问它的数据,但不幸的是我不是DB专家。数据库已经默认实现了许多更快访问的技巧。

能否以某种方式预取数据?

是的,如果您只想获取数据,则可以一次加载更多数据(例如 1024 条记录)并从中返回数据批次(例如 batch_size=128)。

实现

如何实现PyTorch DataLoader?我在网上只找到了很少的代码片段([1]和[2]),这让我对我的方法产生了怀疑。

我不确定为什么您要这样做。您应该使用如您所列举的示例中所示的 torch.utils.data.Dataset

我会从简单的非优化方法开始,类似于此处的方法:

  • __init__中打开与数据库的连接,并在使用期间保持连接处于活动状态(我会从torch.utils.data.Dataset创建上下文管理器,以便在训练轮次完成后关闭连接)。
  • 我不会将结果转换为list(特别是由于显而易见的内存限制),因为这违背了生成器的初衷。
  • 我会在此数据集内部执行批处理(这里有一个batch_size参数,在这里可以查看)。
  • 我不确定__getitem__函数,但似乎它可以一次返回多个数据点,因此我会使用它,并且应该允许我们使用num_workers>0(假设mycol.find(query)每次都以相同的顺序返回数据)。

鉴于这些,以下是我会采取的一些措施:

class DatabaseDataset(torch.utils.data.Dataset):
    def __init__(self, query, batch_size, path: str, database: str):
        self.batch_size = batch_size

        client = pymongo.MongoClient(path)
        self.db = client[database]
        self.query = query
        # Or non-approximate method, if the approximate method
        # returns smaller number of items you should be fine
        self.length = self.db.estimated_document_count()

        self.cursor = None

    def __enter__(self):
        # Ensure that this find returns the same order of query every time
        # If not, you might get duplicated data
        # It is rather unlikely (depending on batch size), shouldn't be a problem
        # for 20 million samples anyway
        self.cursor = self.db.find(self.query)
        return self

    def shuffle(self):
        # Find a way to shuffle data so it is returned in different order
        # If that happens out of the box you might be fine without it actually
        pass

    def __exit__(self, *_, **__):
        # Or anything else how to close the connection
        self.cursor.close()

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        # Read takes long, hence if you can load a batch of documents it should speed things up
        examples = self.cursor[index * batch_size : (index + 1) * batch_size]
        # Do something with this data
        ...
        # Return the whole batch
        return data, labels

现在批处理由DatabaseDataset处理,因此torch.utils.data.DataLoader可以使用batch_size=1。您可能需要挤压附加维度。
由于MongoDB使用锁定(这并不奇怪,但请参见此处),num_workers>0不应该是一个问题。
可能的使用方法(示意性):
with DatabaseDataset(...) as e:
    dataloader = torch.utils.data.DataLoader(e, batch_size=1)
    for epoch in epochs:
        for batch in dataloader:
            # And all the stuff
            ...
        dataset.shuffle() # after each epoch

记住在这种情况下的洗牌实现!(同时可以在上下文管理器中进行洗牌,并且您可能需要手动关闭连接或类似操作)。


我之前不知道,但是我现在爱你了。感谢你提供详细的回复! - pascal
1
@pascal 这个问题比较普遍,但是希望它能让你有一个好的开始。当你实际实现它时,你可以发布更具体的问题,祝你好运。 - Szymon Maszke
@SzymonMaszke 我尝试了类似的方法,将mongodb连接添加到“Dataset”构造函数中。但是,如果num_workers> 0,则会导致“无法pickle_thread.lock”错误。这是因为需要对“Dataset”对象进行pickling以进行多处理,而mongodb连接无法被pickled。我在这里错过了什么吗? - kyc12
这是参考 pytorch 论坛上的问题 - https://discuss.pytorch.org/t/dataset-with-unpicklable-objects-breaks-dataloader-where-num-workers-0/149580 - kyc12
1
@kyc12 请查看 此答案。你可能可以在类外部创建连接(例如,通过工厂模式创建对象),并且对象本身保留连接参数(并重新创建连接)。 - Szymon Maszke

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接