将tf.data.Dataset.from_generator并行化

42

我有一个不平凡的输入管道,from_generator非常适合它...

dataset = tf.data.Dataset.from_generator(complex_img_label_generator,
                                        (tf.int32, tf.string))
dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()

complex_img_label_generator动态生成图像,并返回一个表示(H,W,3)图像和简单的string标签的numpy数组。这个处理过程不是像从文件中读取和tf.image操作那样可以表示的。

我的问题是如何并行化生成器?我如何让N个这样的生成器在自己的线程中运行。

一种想法是使用dataset.mapnum_parallel_calls来处理线程;但是map操作是针对张量的。。。另一个想法是创建多个生成器,每个都有自己的prefetch,然后以某种方式将它们连接起来,但我看不出如何连接N个生成器流?

有任何我可以参考的典型例子吗?

3个回答

29

原来我可以使用Dataset.map,只需让生成器非常轻量级(仅生成元数据),然后将实际的重要工作移动到无状态函数中。这样,我就可以使用.mappy_func并行处理重要的工作部分。

虽然可行,但感觉有些笨拙...如果能够在from_generator中添加num_parallel_calls就好了:)

def pure_numpy_and_pil_complex_calculation(metadata, label):
  # some complex pil and numpy work nothing to do with tf
  ...

dataset = tf.data.Dataset.from_generator(lightweight_generator,
                                         output_types=(tf.string,   # metadata
                                                       tf.string))  # label

def wrapped_complex_calulation(metadata, label):
  return tf.py_func(func = pure_numpy_and_pil_complex_calculation,
                    inp = (metadata, label),
                    Tout = (tf.uint8,    # (H,W,3) img
                            tf.string))  # label
dataset = dataset.map(wrapped_complex_calulation,
                      num_parallel_calls=8)

dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()

5
请注意,仅使用tf.py_func()实现并行化可能不会提高速度,请参考此答案 - mikkola
好的观点。虽然从经验上来说,我可以说这个改进大大提高了速度。 - mat kelcey
5
自从你的提问以后,TensorFlow在from_generator中是否添加了num_parallel_calls参数? - Rylan Schaeffer
2
@mikkola 如果不能加速事情,还有其他建议吗?谢谢 - crafet
你知道为什么会获得“巨大加速”吗?这是否意味着尽管@mikkola指出的那点,你的代码实际上正在并行运行? - hipoglucido
显示剩余4条评论

10

我正在为tf.data.Dataset编写一个from_indexablehttps://github.com/tensorflow/tensorflow/issues/14448

from_indexable的优点是可以并行处理,而Python生成器不能并行处理。

from_indexable函数创建一个tf.data.range,将可索引对象包装在通用的tf.py_func中,并调用map函数。

对于那些现在想使用from_indexable的人,这里是库代码:

import tensorflow as tf
import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import nest

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args):
            nonlocal output_shapes

            flat_output_types = nest.flatten(output_types)
            flat_values = tf.py_func(
                func, 
                inp=args, 
                Tout=flat_output_types,
                stateful=stateful, name=name
            )
            if output_shapes is not None:
                # I am not sure if this is nessesary
                output_shapes = nest.map_structure_up_to(
                    output_types, tensor_shape.as_shape, output_shapes)
                flattened_shapes = nest.flatten_up_to(output_types, output_shapes)
                for ret_t, shape in zip(flat_values, flattened_shapes):
                    ret_t.set_shape(shape)
            return nest.pack_sequence_as(output_types, flat_values)
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)

这里有一个例子(注意:from_indexable有一个num_parallel_calls参数)

class PyDataSet:
    def __len__(self):
        return 20

    def __getitem__(self, item):
        return np.random.normal(size=(item+1, 10))

ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
it = ds.make_one_shot_iterator()
entry = it.get_next()
with tf.Session() as sess:
    print(sess.run(entry).shape)
    print(sess.run(entry).shape)

更新 2018年6月10日: 自https://github.com/tensorflow/tensorflow/pull/15121合并后,from_indexable的代码变得更简单:

import tensorflow as tf

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args, **kwargs):
            return tf.contrib.framework.py_func(
                func=func, 
                args=args, kwargs=kwargs, 
                output_types=output_types, output_shapes=output_shapes, 
                stateful=stateful, name=name
            )
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)

2
不幸的是,由于tf2没有contrib,而py_func已被py_function取代,因此它无法经受时间的考验,后者缺少output_shapes、args、kwargs和stateful。最后,py_function的输出返回未知形状,无法在图形内使用。 - Anton
确实,tf 2.x不再有contrib,但您始终可以使用set_shape函数在函数中设置张量的形状。您可以在文档中看到一个示例:https://www.tensorflow.org/guide/data#applying_arbitrary_python_logic - Zaccharie Ramzi

7

将在generator中完成的工作最小化,并使用map并行处理昂贵的操作是明智的。

另外,您可以使用parallel_interleave“连接”多个生成器,如下所示:

def generator(n):
  # 返回第 n 个生成器函数
def dataset(n): return tf.data.Dataset.from_generator(generator(n))
ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))
# 其中 N 是使用的生成器数量

1
你的代码不是有效的Python代码,而且你一开始也没有定义ds - Merlin1896
2
我真的很喜欢这个。但是 generator(n) 应该返回第 n 个生成器,这里的 n 是一个张量。如何获得第 n 个生成器? - Derk
1
现在您可以向 from_generator 提供 args:https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator - Derk

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接