Tensorflow 2.0数据集和数据加载器

21

我是一个PyTorch用户,习惯使用PyTorch中的data.dataset和data.dataloader API。我正在尝试使用TensorFlow 2.0构建相同的模型,并想知道是否有类似于PyTorch中这些API的功能。

如果没有这样的API,你们中的任何人可以告诉我人们通常如何在TensorFlow中实现数据加载部分吗?我以前用过TensorFlow 1,但从未使用过dataset API。我之前进行了硬编码。我希望有一些类似于只输入索引的覆盖getitem方法的东西。

非常感谢您提前的帮助。


1
查找 tf.data.* API。 - GPhilo
2个回答

19
使用 tf.data API 时,通常也会使用 map 函数。
在 PyTorch 中,您的 __getitem__ 调用基本上从给定的 __init__ 中获取数据结构中的一个元素,并在必要时进行转换。
在 TF2.0 中,您可以使用其中一个 Dataset.from_... 函数(请参见 from_generatorfrom_tensor_slicesfrom_tensors)初始化一个 Dataset,这实际上相当于 PyTorch 的 Dataset__init__ 部分。然后,您可以调用 map 来执行您在 __getitem__ 中要进行的逐个元素操作。
TensorFlow 数据集基本上是高级迭代器,因此按设计,您不会使用索引访问它们的元素,而是通过遍历它们来访问它们。

tf.data 的使用指南非常实用,提供了丰富的示例。


2
这正是我在寻找的答案,非常感谢Mat。 - piljae.chae
@piljae.chae,我很好奇你是怎么发现 tf.data API 的呢?我喜欢 PyTorch 的 data API,但我最近被一家 TensorFlow 公司聘用了,现在需要重新学习一些东西。我真的想知道,这两个API是否可以混合使用? - rocksNwaves

5

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接