TensorArray 和 while_loop 在 TensorFlow 中如何协同工作?

9
我正在尝试为TensorArray和while_loop的组合编写一个非常简单的示例:
# 1000 sequence in the length of 100
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix")
matrix_rows = tf.shape(matrix)[0]
ta = tf.TensorArray(tf.float32, size=matrix_rows)
ta = ta.unstack(matrix)

init_state = (0, ta)
condition = lambda i, _: i < n
body = lambda i, ta: (i + 1, ta.write(i,ta.read(i)*2))

# run the graph
with tf.Session() as sess:
    (n, ta_final) = sess.run(tf.while_loop(condition, body, init_state),feed_dict={matrix: tf.ones(tf.float32, shape=(100,1000))})
    print (ta_final.stack())

但是我遇到了以下错误:
ValueError: Tensor("while/LoopCond:0", shape=(), dtype=bool) must be from the same graph as Tensor("Merge:0", shape=(), dtype=float32).

有人知道问题是什么吗?


要获取最终的 TensorArray,您需要运行 session.run(ta.stack()),而不是直接运行循环,因为您无法运行 session.run(TensorArray) - sirfz
抱歉,我不明白你的意思。你能否写出正确的形式? - E.Asgari
1个回答

10

在你的代码中有几个需要指出的问题。首先,你不需要将矩阵解压成TensorArray才能在循环内部使用它,可以安全地在循环体内引用矩阵Tensor并使用matrix[i]进行索引。另一个问题是矩阵(tf.int32)和TensorArray(tf.float32)之间的数据类型不同,在你的代码中你正在将矩阵整数乘以2并将结果写入数组,因此它也应该是int32。最后,当你想要读取循环的最终结果时,正确的操作是TensorArray.stack(),这是你需要在session.run调用中运行的内容。

下面是一个有效的示例:

import numpy as np
import tensorflow as tf    

# 1000 sequence in the length of 100
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix")
matrix_rows = tf.shape(matrix)[0]
ta = tf.TensorArray(dtype=tf.int32, size=matrix_rows)

init_state = (0, ta)
condition = lambda i, _: i < matrix_rows
body = lambda i, ta: (i + 1, ta.write(i, matrix[i] * 2))
n, ta_final = tf.while_loop(condition, body, init_state)
# get the final result
ta_final_result = ta_final.stack()

# run the graph
with tf.Session() as sess:
    # print the output of ta_final_result
    print sess.run(ta_final_result, feed_dict={matrix: np.ones(shape=(100,1000), dtype=np.int32)}) 

在这种情况下,我可以指定输入而不使用feed dictionary,就像如果我在计算图之间使用它,我如何指定张量数组取决于某个张量? - Rahul
@Rahul 如果我理解你的问题,matrix 可以是任何类型的 Tensor,不一定是一个 placeholder - sirfz
在最后一行,我将np.ones(tf.int32, shape=(100,1000))改为np.ones(dtype=np.int32, shape=(100,1000))以便在Python 3上运行此代码。 - Saeid BK

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接