如何将值注入TensorFlow图的中间?

5

请考虑以下代码:

x = tf.placeholder(tf.float32, (), name='x')
z = x + tf.constant(5.0)
y = tf.mul(z, tf.constant(0.5))

with tf.Session() as sess:
    print(sess.run(y, feed_dict={x: 30}))

生成的图表为x -> z -> y。有时我想从x计算y,但有时我有z作为起点,并希望将此值注入到图表中。因此,z需要像部分占位符一样运行。我该如何做?
(对于任何有兴趣了解我为什么需要这个的人。我正在使用自动编码器网络,该网络观察图像x,生成中间压缩表示z,然后计算图像y的重构。当我为z注入不同的值时,我想看到网络重构的内容。)
2个回答

5

以下是在默认情况下使用占位符的方式:

x = tf.placeholder(tf.float32, (), name='x')
# z is a placeholder with default value
z = tf.placeholder_with_default(x+tf.constant(5.0), (), name='z')
y = tf.mul(z, tf.constant(0.5))

with tf.Session() as sess:
    # and feed the z in
    print(sess.run(y, feed_dict={z: 5}))

我真是太傻了。


3

我不能在你的帖子中进行评论,@iramusa,所以我会给出一个答案。你不需要使用placeholder_with_default。你可以直接向任何节点输入值:

import tensorflow as tf

x = tf.placeholder(tf.float32,(), name='x')
z = x + tf.constant(5.0)
y = z*tf.constant(0.5)

with tf.Session() as sess:
    print(sess.run(y, feed_dict={x: 2}))  # get 3.5
    print(sess.run(y, feed_dict={z: 5}))  # get 2.5
    print(sess.run(y, feed_dict={y: 5}))  # get 5

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接