如何在Tensorflow中通过tf.cond将参数传递给函数?

12

我有以下简单的占位符:

x = tf.placeholder(tf.float32, shape=[1])
y = tf.placeholder(tf.float32, shape=[1])
z = tf.placeholder(tf.float32, shape=[1])

有两个函数fn1fn2,定义如下:

def fn1(a, b): 
    return tf.mul(a, b)
def fn2(a, b): 
    return tf.add(a, b)

现在我想根据预测条件计算结果:

pred = tf.placeholder(tf.bool, shape=[1])
result = tf.cond(pred, fn1(x,y), fn2(y,z))

但它会报错,提示fn1和fn2必须可调用

我该如何编写fn1fn2,使它们能在运行时接收参数?我想要调用以下内容:

sess.run(result, feed_dict={x:1,y:2,z:3,pred:True})
2个回答

23

你可以使用lambda将参数传递给函数,代码如下。

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)

def fn1(a, b):
  return tf.mul(a, b)

def fn2(a, b):
  return tf.add(a, b)

pred = tf.placeholder(tf.bool)
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))

然后您可以按照以下方式进行调用:

with tf.Session() as sess:
  print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True})
  # The result is 2.0

4
最简单的方法是在调用时定义您的函数:
result = tf.cond(pred, lambda: tf.mul(a, b), lambda: tf.add(a, b))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接