用TensorFlow从NumPy数组创建数据集。

21

TensorFlow提供了一种很好的数据存储方式。例如,它被用来存储MNIST数据:

>>> mnist
<tensorflow.examples.tutorials.mnist.input_data.read_data_sets.<locals>.DataSets object at 0x10f930630>

假设有一个输入和输出的numpy数组。
假設有一個輸入和輸出的numpy數組。
>>> x = np.random.normal(0,1, (100, 10))
>>> y = np.random.randint(0, 2, 100)

我如何将它们转换为tf数据集呢?我想使用像next_batch这样的函数。
3个回答

9
数据集对象只是MNIST教程的一部分,而不是主要的TensorFlow库。
您可以在此处查看其定义: GitHub链接 构造函数接受图像和标签参数,因此您可以在那里传递自己的值。

好的,谢谢。我一直怀疑这个问题。我认为它作为主库的一部分会是一个有用的工具。据我所知,对numpy数组进行任何批量操作都需要复制数据。这可能会导致算法变慢。 - Donbeo
哲学是,TensorFlow 应该只是一个核心数学库,但其他开源库可以提供用于机器学习的额外抽象。类似于 Theano,它有像 Pylearn2 这样的库构建在其上。如果您想避免复制操作,则可以使用基于队列的数据访问功能,而不是馈送占位符。 - Ian Goodfellow

3

最近,Tensorflow添加了一个功能到其数据集API中,可以使用numpy数组。有关详细信息,请参见此处

这里是我从那里复制的片段:

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

在tf2中,这种方法不再起作用。您知道tf2中推荐的方法是什么吗? - Anatoly Alekseev
有关TF2的内容,请查看此链接:https://www.tensorflow.org/guide/data#consuming_numpy_arrays - MajidL
@MajidL 你知道如果整个数据集都无法放入内存中,该怎么做吗? - StopReadingThisUsername
你的数据集是NumPy格式,但无法加载到内存中?如果是这种情况,这个解决方案可能会有所帮助。 - MajidL

0
作为替代方案,您可以使用函数tf.train.batch()来创建数据批次,并同时消除对tf.placeholder的使用。有关更多详细信息,请参阅文档。
>>> images = tf.constant(X, dtype=tf.float32) # X is a np.array
>>> labels = tf.constant(y, dtype=tf.int32)   # y is a np.array
>>> batch_images, batch_labels = tf.train.batch([images, labels], batch_size=32, capacity=300, enqueue_many=True)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接