如何在TensorFlow中对一个批次进行切片并对每个切片应用操作

4
我是一名TensorFlow的初学者,我正在尝试实现一个以batch为输入的函数。它必须将这个batch切成几个部分,在其中应用一些操作,然后将它们连接起来构建一个新的张量并返回。通过我的阅读,我发现了一些实现的函数,如input_slice_producer和batch_join,但我没有使用它们。我附上了我找到的解决方案,但它有点慢,不太合适,并且不能检测批次的当前大小。有人知道更好的方法吗?
def model(x):

    W_1 = tf.Variable(tf.random_normal([6,1]),name="W_1")
    x_size = x.get_shape().as_list()[0]
    # x is a batch of bigger input of shape [None,6], so I couldn't 
    # get the proper size of the batch when feeding it 
    if x_size == None:
        x_size= batch_size
    #intialize the y_res
    dummy_x = tf.slice(x,[0,0],[1,6])
    result = tf.reduce_sum(tf.mul(dummy_x,W_1))
    y_res = tf.zeros([1], tf.float32)
    y_res = result
    #go throw all slices and concatenate them to get result
    for i in range(1,x_size): 
        dummy_x = tf.slice(x,[i,0],[1,6])
        result = tf.reduce_sum(tf.mul(dummy_x,W_1))
        y_res = tf.concat(0, [y_res, result])

    return y_res

你可以将所有的y_res累积到一个列表中,并一次性创建最终张量,否则你会一遍又一遍地复制数据。 - fabrizioM
谢谢,但我认为将y_res变成列表不会有太大改变,因为我将在每次迭代中使用“append”。这与使用函数“concat”并没有太大区别,因为它基本上执行的是与“append”相同的操作。我的挑战是:自动获取批处理的当前大小并在函数内部进行多线程计算。 - chikhawi9
1个回答

6

TensorFlow函数tf.map_fn(fn, elems)允许您将一个函数(fn)应用于张量(elems)的每个切片。例如,您可以按以下方式表示程序:

def model(x):
    W_1 = tf.Variable(tf.random_normal([6, 1]), name="W_1")

    def fn(x_slice):
        return tf.reduce_sum(x_slice, W_1)

    return tf.map_fn(fn, x)

使用tf.mul()运算符上的广播和NumPy广播语义,以及axis参数在tf.reduce_sum()上更简洁地实现操作也是可能的。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接