TensorFlow的while_loop将变量转换成常量?

5

我正在尝试在嵌套的while_loop()中更新一个二维张量。然而,当将变量传递给第二个循环时,我无法使用tf.assign()进行更新,因为它会抛出以下错误:

ValueError: Sliced assignment is only supported for variables

如果我在while循环外创建变量并仅在第一个循环中使用它,那么它的运行很好。

我如何在第二个while循环中修改我2D tf变量?
(我使用Python 2.7和TensorFlow 1.2)

我的代码:

import tensorflow as tf
import numpy as np

tf.reset_default_graph()

BATCH_SIZE = 10
LENGTH_MAX_OUTPUT = 31

it_batch_nr = tf.constant(0)
it_row_nr = tf.Variable(0, dtype=tf.int32)
it_col_nr = tf.constant(0)
cost = tf.constant(0)

it_batch_end = lambda it_batch_nr, cost: tf.less(it_batch_nr, BATCH_SIZE)
it_row_end = lambda it_row_nr, cost_matrix: tf.less(it_row_nr, LENGTH_MAX_OUTPUT+1)

def iterate_batch(it_batch_nr, cost):
    cost_matrix = tf.Variable(np.ones((LENGTH_MAX_OUTPUT+1, LENGTH_MAX_OUTPUT+1)), dtype=tf.float32)
    it_rows, cost_matrix = tf.while_loop(it_row_end, iterate_row, [it_row_nr, cost_matrix])
    cost = cost_matrix[0,0] # IS 1.0, SHOULD BE 100.0
    return tf.add(it_batch_nr,1), cost

def iterate_row(it_row_nr, cost_matrix):
    # THIS THROWS AN ERROR:
    cost_matrix[0,0].assign(100.0)
    return tf.add(it_row_nr,1), cost_matrix

it_batch = tf.while_loop(it_batch_end, iterate_batch, [it_batch_nr, cost])

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
out = sess.run(it_batch)
print(out)

请将您的解决方案作为答案而不是编辑问题提交,您可以在修订版3中找到您的解决方案,只需将其复制并粘贴到答案中即可。 - Petter Friberg
2个回答

2

tf.Variable对象不能在while循环中用作循环变量,因为循环变量的实现方式不同。

因此,要么在循环外创建您的变量,并在每次迭代中使用tf.assign自己更新它,要么手动跟踪更新,就像处理循环变量一样(通过从循环lambda返回其更新值,并在您的情况下使用内部循环的值作为外部循环的新值)。


谢谢,这样就解决了错误。但是,如果我将cost_matrix变量声明移到循环外,并从循环变量中删除它,cost仍然为1.0,而我期望它为100.0。 - Deruijter
显然,我没有意识到操作必须由输出变量使用才能运行。我猜这与TF解释和优化代码的方式有关。我使用了一个hacky的解决方案来克服这个问题,但我认为有更好的方法来解决这个问题(请参见问题中的编辑)。将会就此提出一个新问题。 - Deruijter
更新:可以使用 tf.control_dependencies() 来触发要运行的操作(问题中的工作解决方案已更新)。 - Deruijter

1

在@AlexandrePassos的帮助下,我成功地使这个工作起来,通过将变量放在while_loop之外。然而,我还必须使用tf.control_dependencies()来强制执行命令(因为操作不直接用于循环变量)。现在循环看起来像这样:

cost_matrix = tf.Variable(np.ones((LENGTH_MAX_OUTPUT+1, LENGTH_MAX_OUTPUT+1)), dtype=tf.float32)

def iterate_batch(it_batch_nr, cost):
    it_rows = tf.while_loop(it_row_end, iterate_row, [it_row_nr])
    with tf.control_dependencies([it_rows]):
        cost = cost_matrix[0,0] 
        return tf.add(it_batch_nr,1), cost

def iterate_row(it_row_nr):
    a = tf.assign(cost_matrix[0,0], 100.0)
    with tf.control_dependencies([a]):
        return tf.add(it_row_nr,1)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接