如何在Tensorflow Estimator的input_fn中执行数据增强

3
使用Tensorflow的Estimator API,在管道的哪个阶段应该执行数据增强?
根据Tensorflow官方指南,可以在input_fn中执行数据增强。具体请参考这里
def parse_fn(example):
  "Parse TFExample records and perform simple data augmentation."
  example_fmt = {
    "image": tf.FixedLengthFeature((), tf.string, ""),
    "label": tf.FixedLengthFeature((), tf.int64, -1)
  }
  parsed = tf.parse_single_example(example, example_fmt)
  image = tf.image.decode_image(parsed["image"])

  # augments image using slice, reshape, resize_bilinear
  #         |
  #         |
  #         |
  #         v
  image = _augment_helper(image)

  return image, parsed["label"]

def input_fn():
  files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
  dataset = files.interleave(tf.data.TFRecordDataset)
  dataset = dataset.map(map_func=parse_fn)
  # ...
  return dataset

我的问题

如果我在input_fn内执行数据增强,那么parse_fn会返回一个只包含原始输入图像和所有增强变体的批次还是单个[增强的]示例?如果它只应返回单个[增强的]示例,那么如何确保数据集中的所有图像都以其未增强的形式使用,以及所有变体?


将一个随机函数放到.map中 请参阅https://dev59.com/ylMI5IYBdhLWcg3w2_fm - churan lin
2个回答

1
如果您在数据集上使用迭代器,每次迭代数据集时,_augment_helper函数都会被调用,并跨越每个输入的数据块(因为您正在dataset.map中调用parse_fn)。将您的代码更改为:
  ds_iter = dataset.make_one_shot_iterator()
  ds_iter = ds_iter.get_next()
  return ds_iter

我已经使用简单的增强函数进行了测试。
  def _augment_helper(image):
       print(image.shape)
       image = tf.image.random_brightness(image,255.0, 1)
       image = tf.clip_by_value(image, 0.0, 255.0)
       return image

将255.0更改为您的数据集中的最大值,我使用255.0作为示例,因为我的数据集是以8位像素值表示的。

0

每次调用parse_fn函数时,它将返回单个示例;如果您使用.batch()操作,则会返回一批解析图像。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接