我有一个张量,使用tf.map_fn
逐行处理。现在,我想将索引作为参数包含在传递给tf.map_fn
的函数中。在numpy中,我可以使用enumerate
获取该信息,并将其传递给lambda函数。以下是一个numpy示例,其中我将0添加到第一行,将1添加到第二行,以此类推:
a = np.array([[2, 1], [4, 2], [-1, 2]])
def some_function(x, i):
return x + i
res = map(lambda (i, row): some_function(row, i), enumerate(a))
print(res)
> [array([2, 1]), array([5, 3]), array([1, 4])]
尽管我没有找到 Tensorflow 中与 enumerate
相对应的函数,但我不知道如何在 Tensorflow 中实现相同的结果。有没有人知道要使用什么来使其在 Tensorflow 中工作呢?这是一个示例代码,在其中每行加 1:
import tensorflow as tf
a = tf.constant([[2, 1], [4, 2], [-1, 2]])
with tf.Session() as sess:
res = tf.map_fn(lambda row: some_function(row, 1), a)
print(res.eval())
> [[3 2]
[5 3]
[0 3]]
感谢任何能够帮助我解决这个问题的人。