如何在 TensorFlow 2.0 中使用数据增强功能(data augmentation)来处理 tfds.load() 加载的数据?

16
我正在遵循这个指南
它展示了如何使用tfds.load()方法从新的TensorFlow数据集中下载数据集。
import tensorflow_datasets as tfds    
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs', split=list(splits),
    with_info=True, as_supervised=True)
下一步展示了如何使用 map 方法将函数应用于数据集中的每个项:
def format_example(image, label):
    image = tf.cast(image, tf.float32)
    image = image / 255.0
    # Resize the image if required
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

然后我们可以使用以下方式访问元素:

for features in ds_train.take(1):
  image, label = features["image"], features["label"]
for example in tfds.as_numpy(train_ds):
  numpy_images, numpy_labels = example["image"], example["label"]

然而,该指南未提及任何关于数据增强的内容。我想使用实时数据增强,类似于Keras的ImageDataGenerator类。我尝试使用:

if np.random.rand() > 0.5:
    image = tf.image.flip_left_right(image)
format_example()中还有其他类似的增强函数,但是我该如何验证它是实时增强而不是替换数据集中的原始图像?
我可以通过向tfds.load()传递batch_size=-1并使用tfds.as_numpy()将整个数据集转换为Numpy数组,但这会加载所有图像到内存中,这是不必要的。我应该能够使用train = train.prefetch(tf.data.experimental.AUTOTUNE)仅加载足够下一次训练循环所需的数据。

您可能也想查看这个答案,它展示了增强后的数据,这样您就可以更加__确信__它正在工作(而且这个例子更有说服力)。 - Szymon Maszke
1个回答

20

你的思路方向有误。

首先,使用tfds.load来下载数据,以cifar10为例(为了简单起见,我们将使用默认的TRAINTEST拆分):

import tensorflow_datasets as tfds

dataloader = tfds.load("cifar10", as_supervised=True)
train, test = dataloader["train"], dataloader["test"]

(您可以使用自定义tfds.Split对象创建验证数据集或其他数据集,有关详细信息,请参见文档

traintesttf.data.Dataset对象,因此您可以对每个对象使用mapapplybatch等函数。

以下是一个示例,我将主要使用tf.image

  • 将每个图像转换为0-1范围内的tf.float64(不要使用官方文档中的愚蠢代码片段,这种方式确保了正确的图像格式)
  • cache()结果,因为可以在每个repeat之后重复使用
  • 随机翻转每个图像的左右镜像
  • 随机更改图像对比度
  • 随机打乱数据并进行批处理
  • 重要提示:当数据集用尽时,需要重复所有步骤。这意味着在一个epoch之后,所有上述变换都将再次应用(除了被缓存的变换)。

以下是执行上述操作的代码(您可以将lambda更改为函数或functors):

train = train.map(
    lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
).cache().map(
    lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
    lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
    100
).batch(
    64
).repeat()

这样的tf.data.Dataset可以直接传递给Keras的fitevaluatepredict方法。

验证它是否真的像那样工作

我看你对我的解释非常怀疑,让我们通过一个例子来说明:

1. 获取小数据集

以下是一种获取单个元素的方式,虽然阅读起来不太可读和晦涩,但如果你使用过Tensorflow,应该会理解:

# Horrible API is horrible
element = tfds.load(
    # Take one percent of test and take 1 element from it
    "cifar10",
    as_supervised=True,
    split=tfds.Split.TEST.subsplit(tfds.percent[:1]),
).take(1)

2. 重复数据并检查其是否相同:

使用 Tensorflow 2.0 ,可以几乎不用愚蠢的解决方法就能完成此操作:

element = element.repeat(2)
# You can iterate through tf.data.Dataset now, finally...
images = [image[0] for image in element]
print(f"Are the same: {tf.reduce_all(tf.equal(images[0], images[1]))}")

它毫不意外地返回:

Are the same: True

3. 检查在随机增广后每次重复的数据是否有所不同

以下代码片段将单个元素重复5次并检查哪些相等,哪些不同。

element = (
    tfds.load(
        # Take one percent of test and take 1 element
        "cifar10",
        as_supervised=True,
        split=tfds.Split.TEST.subsplit(tfds.percent[:1]),
    )
    .take(1)
    .map(lambda image, label: (tf.image.random_flip_left_right(image), label))
    .repeat(5)
)

images = [image[0] for image in element]

for i in range(len(images)):
    for j in range(i, len(images)):
        print(
            f"{i} same as {j}: {tf.reduce_all(tf.equal(images[i], images[j]))}"
        )

输出结果(在我的情况下,每次运行的结果都会不同):

0 same as 0: True
0 same as 1: False
0 same as 2: True
0 same as 3: False
0 same as 4: False
1 same as 1: True
1 same as 2: False
1 same as 3: True
1 same as 4: True
2 same as 2: True
2 same as 3: False
2 same as 4: False
3 same as 3: True
3 same as 4: True
4 same as 4: True

你还可以将这些图像强制转换为 numpy 并使用 skimage.io.imshowmatplotlib.pyplot.imshow 或其他替代方法来查看这些图像。

另一个实时数据增强可视化的示例

这个答案 提供了一个更全面和易读的数据增强使用 TensorboardMNIST 的展示,也许你会想要检查一下(是的,有点自我推销,但我认为很有用)。


这里的map函数文档中可以得知:此转换将map_func应用于此数据集的每个元素,并返回一个包含转换后元素的新数据集,其顺序与它们在输入中出现的顺序相同。 - himanshurawlani
确实如此。请检查我刚刚添加的__IMPORTANT:__部分。基本上,每个增强都应用于数据的每个部分(在这种情况下是单个元素,如果之前使用了batch(),则可以是批处理,这样应该会更快),并且在运行时返回带有或不带有增强的数据(如果是随机的)。当tf.data.Dataset耗尽并且使用repeat(为了训练多个时期/无限期)时,所有操作都会重复,除了我们在第一次传递期间缓存的操作。这是否解决了混淆? - Szymon Maszke
好的,那么我如何验证在使用“repeat”时所有操作都被重复执行了呢? - himanshurawlani
2
我看到你对tensorflow没有太多信心,我可以理解。我添加了一个示例,比较了random_flip_left_right之前和之后的图像。如果您愿意,您可以以这种方式进行更广泛的测试。 - Szymon Maszke
2
谢谢提供示例!经过验证步骤后,事情变得更加清晰了。 - himanshurawlani

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接