TensorFlow:检查标量布尔张量是否为True

9
我想使用占位符来控制函数的执行,但是一直收到错误提示“不允许将tf.Tensor用作Python布尔值”。以下是产生此错误的代码:
import tensorflow as tf
def foo(c):
  if c:
    print('This is true')
    #heavy code here
    return 10
  else:
    print('This is false')
    #different code here
    return 0

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()

我将if c修改为if c is not None,但没有成功。如何通过开启和关闭占位符a来控制foo

更新:正如@nessuno和@nemo指出的那样,我们必须使用tf.cond而不是if..else。我的问题的答案是像这样重新设计我的函数:

import tensorflow as tf
def foo(c):
  return tf.cond(c, func1, func2)

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close() 

大约5年前的版本,不适用于TF2.x。 - EngrStudent
3个回答

11
你需要使用 tf.cond 来定义一个条件操作,从而改变张量的流动。
import tensorflow as tf

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = tf.cond(tf.equal(a, tf.constant(True)), lambda: tf.constant(10), lambda: tf.constant(0))
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()
print(res)

10


实际上foo函数非常复杂。我只想通过打开/关闭 a 来更改该函数中的某些操作。我该如何保留foo函数?我怀疑问题来自于 {a: True}if c: 的其中之一。 - Tu Bui
2
你只需要定义两个不同的函数,在条件被评估时执行。唯一的限制是两个函数必须返回相同数量和类型的值。因此,你可以定义自己的函数并使用它们,而不是使用lambda函数。 - nessuno

1
实际执行不是在Python中进行的,而是在您提供的TensorFlow后端中执行计算图。这意味着您想要应用的每个条件和流程控制都必须被表述为计算图中的节点。
对于if条件,有cond操作:
b = tf.cond(c, 
           lambda: tf.constant(10), 
           lambda: tf.constant(0))

0
一个更简单的解决方法是:
In [50]: a = tf.placeholder(tf.bool)                                                                                                                                                                                 

In [51]: is_true = tf.count_nonzero([a])                                                                                                                                                                             

In [52]: sess.run(is_true, {a: True})                                                                                                                                                                                
Out[52]: 1

In [53]: sess.run(is_true, {a: False})
Out[53]: 0

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接