我正在遵循这个指南。
它展示了如何使用
我可以通过向
它展示了如何使用
tfds.load()
方法从新的TensorFlow数据集中下载数据集。import tensorflow_datasets as tfds
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)
(raw_train, raw_validation, raw_test), metadata = tfds.load(
'cats_vs_dogs', split=list(splits),
with_info=True, as_supervised=True)
下一步展示了如何使用 map 方法将函数应用于数据集中的每个项:def format_example(image, label):
image = tf.cast(image, tf.float32)
image = image / 255.0
# Resize the image if required
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label
train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)
然后我们可以使用以下方式访问元素:
for features in ds_train.take(1):
image, label = features["image"], features["label"]
for example in tfds.as_numpy(train_ds):
numpy_images, numpy_labels = example["image"], example["label"]
然而,该指南未提及任何关于数据增强的内容。我想使用实时数据增强,类似于Keras的ImageDataGenerator类。我尝试使用:
if np.random.rand() > 0.5:
image = tf.image.flip_left_right(image)
在format_example()
中还有其他类似的增强函数,但是我该如何验证它是实时增强而不是替换数据集中的原始图像?我可以通过向
tfds.load()
传递batch_size=-1
并使用tfds.as_numpy()
将整个数据集转换为Numpy数组,但这会加载所有图像到内存中,这是不必要的。我应该能够使用train = train.prefetch(tf.data.experimental.AUTOTUNE)
仅加载足够下一次训练循环所需的数据。