在TensorFlow中,与Python中的enumerate等效的函数是tf.map_fn中使用索引的功能。

4

我有一个张量,使用tf.map_fn逐行处理。现在,我想将索引作为参数包含在传递给tf.map_fn的函数中。在numpy中,我可以使用enumerate获取该信息,并将其传递给lambda函数。以下是一个numpy示例,其中我将0添加到第一行,将1添加到第二行,以此类推:

a = np.array([[2, 1], [4, 2], [-1, 2]])

def some_function(x, i):
    return x + i

res = map(lambda (i, row): some_function(row, i), enumerate(a))
print(res)

> [array([2, 1]), array([5, 3]), array([1, 4])]

尽管我没有找到 Tensorflow 中与 enumerate 相对应的函数,但我不知道如何在 Tensorflow 中实现相同的结果。有没有人知道要使用什么来使其在 Tensorflow 中工作呢?这是一个示例代码,在其中每行加 1:

import tensorflow as tf

a = tf.constant([[2, 1], [4, 2], [-1, 2]])

with tf.Session() as sess:
    res = tf.map_fn(lambda row: some_function(row, 1), a)
    print(res.eval())

> [[3 2]
   [5 3]
   [0 3]]

感谢任何能够帮助我解决这个问题的人。

1个回答

9

tf.map_fn()可以有许多输入/输出。因此,您可以使用tf.range()来构建行索引的张量,并将其与之一起使用:

import tensorflow as tf

def some_function(x, i):
    return x + i

a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)

res, _ = tf.map_fn(lambda x: (some_function(x[0], x[1]), x[1]), 
                   (a, a_rows), dtype=(tf.int32, tf.int32))

with tf.Session() as sess:
    print(res.eval())
    # [[2 1]
    #  [5 3]
    #  [1 4]]

注意:在许多情况下,“逐行处理矩阵”可以一次完成,例如通过广播,而不是使用循环:
import tensorflow as tf

a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)
res = a + a_rows

with tf.Session() as sess:
    print(res.eval())
    # [[2 1]
    #  [5 3]
    #  [1 4]]

1
非常好,谢谢!至于那个备注,不是我想要的,但还是谢谢你分享这个技巧。我需要索引来从数据库中检索一个值。 - kluu

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接