Tensorflow数据集.map() API

7

关于此事还有几个问题。

有时我想在Tensorflow中执行以下操作(假设我通过加载WAV文件创建训练示例):

import tensorflow as tf 

def _some_audio_preprocessing_func(filename):
   # ... some logic here which mostly uses Tensorflow ops ...
   with tf.Session(graph=tf.Graph()) as sess:
        wav_filename_placeholder = tf.placeholder(tf.string, [])
        wav_loader = io_ops.read_file(wav_filename_placeholder)
        wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
        data = sess.run(
                [wav_decoder],
                feed_dict={wav_filename_placeholder: filename})
        return data

dataset = tf.data.Dataset.list_files('*.wav')
dataset = dataset.map(_some_preprocessing_func)
  1. 如果我有一个使用张量运算的parse_image()函数,它应该属于主图吗?根据Google自己的音频TF教程所设定的示例,看起来他们创建了一个单独的图!这难道不会破坏使用Tensorflow加速计算的目的吗?
  2. 只要任何一行代码不来自Tensorflow库,我就要使用tf.py_func()吗?再次想知道性能影响和何时使用它...

谢谢!

1个回答

16
当你使用 Dataset.map(map_func) 时,TensorFlow会为函数 map_func 中创建的所有操作定义一个子图,并安排在与其余图形相同的会话中高效执行。几乎不需要在 map_func 中创建 tf.Graphtf.Session:如果您的解析函数由 TensorFlow 操作组成,则这些操作可以直接嵌入定义输入管道的图中。
使用 tf.data 的修改后代码将如下所示:
import tensorflow as tf 
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio

def _some_audio_preprocessing_func(filename):
    wav_loader = tf.read_file(filename)
    return contrib_audio.decode_wav(wav_loader, desired_channels=1)

dataset = tf.data.Dataset.list_files('*.wav')
dataset = dataset.map(_some_preprocessing_func)

如果您的map_func包含非TensorFlow操作,并且您希望将其应用于每个元素,则应在tf.py_func()(或Dataset.from_generator(),如果数据生成过程是由Python逻辑定义的)中对它们进行包装。主要性能影响是任何在tf.py_func()中运行的代码都受全局解释器锁的影响,因此我通常建议尝试为任何重要的性能内容寻找本机的TensorFlow实现。


嗨。我有一个关于这个问题的后续问题。如果我使用这个数据集来训练模型,那么在训练之后,如何保存它,以便TensorFlow操作从**_some_audio_preprocessing_func**也包含在最终模型中?谢谢。 - Anuj

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接