在 while 循环中更改张量的单个值

3

我如何在while循环内更改张量的单个值?我知道可以使用tf.scatter_update(variable, index, value)来操作tf.Variable的单个值,但在循环内无法访问变量。有没有一种方法/解决方法来在while循环内操作给定的Tensor值。

以下是我的当前代码:

my_variable = tf.Variable()

def body(i, my_variable):
    [...]
    return tf.add(i, 1), tf.scatter_update(my_variable, [index], value)


loop = tf.while_loop(lambda i, _: tf.less(i, 5), body, [0, my_variable])

请查看以下链接:https://stackoverflow.com/questions/51419333/modify-a-tensorflow-variable-inside-a-loop - TassosK
@TassosK while_loop将我的tf.Variable转换为tf.Tensor,因此当我按照那篇文章中所说的去做时,会出现TypeError: 'Tensor' object does not support item assignment - Ian Rehwinkel
1个回答

3

这篇文章的启发,您可以使用稀疏张量来存储要分配的值的增量,然后使用加法来“设置”该值。例如,像这样(我假设这里有一些形状/值,但应该很容易将其推广到更高秩的张量):

import tensorflow as tf

my_variable = tf.Variable(tf.ones([5]))

def body(i, v):
    index = i
    new_value = 3.0
    delta_value = new_value - v[index:index+1]
    delta = tf.SparseTensor([[index]], delta_value, (5,))
    v_updated = v + tf.sparse_tensor_to_dense(delta)
    return tf.add(i, 1), v_updated


_, updated = tf.while_loop(lambda i, _: tf.less(i, 5), body, [0, my_variable])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(my_variable))
    print(sess.run(updated))

这会打印出来

[1. 1. 1. 1. 1.]
[3. 3. 3. 3. 3.]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接