用户 @meTchaikovsky 解释了 tf.reduce_mean
的一般情况。在你们两个的情况下,tf.reduce_mean
就像任何平均值计算器一样工作,即你不是沿着张量的任何特定轴取平均值,而是将张量中元素的总和除以元素数量。
让我们解码两种情况中到底发生了什么。对于这两种情况,假设 batch_size = 2
和 num_classes = 5
,意味着每个批次有两个示例。现在对于第一种情况,tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y)
返回一个形状为 (2,)
的数组。
>>import numpy as np
>>import tensorflow as tf
>>sess= tf.InteractiveSession()
>>batch_size = 2
>>num_classes = 5
>>logits = np.random.rand(batch_size,num_classes)
>>print(logits)
[[0.94108451 0.68186329 0.04000461 0.25996487 0.50391948]
[0.22781201 0.32305269 0.93359371 0.22599208 0.05942905]]
>>labels = np.array([[1,0,0,0,0],[0,1,0,0,0]])
>>print(labels)
[[1 0 0 0 0]
[0 1 0 0 0]]
>>logits_ = tf.placeholder(dtype=tf.float32,shape=(batch_size,num_classes))
>>Y_ = tf.placeholder(dtype=tf.int32,shape=(batch_size,num_classes))
>>loss_op = tf.nn.softmax_cross_entropy_with_logits(logits=logits_, labels=Y_)
>>loss_per_example = sess.run(loss_op,feed_dict={Y_:labels,logits_:logits})
>>print(loss_per_example)
array([1.2028817, 1.6912657], dtype=float32)
您可以看到loss_per_example
的形状为(2,)
。如果我们对此变量取平均值,那么我们就可以近似计算整个批次的平均损失。因此,我们进行如下计算:
>>loss_per_example_holder = tf.placeholder(dtype=tf.float32,shape=(batch_size))
>>final_loss_per_batch = tf.reduce_mean(loss_per_example_holder)
>>final_loss = sess.run(final_loss_per_batch,feed_dict={loss_per_example_holder:loss_per_example})
>>print(final_loss)
1.4470737
来看你的第二个情况:
>>predictions_holder = tf.placeholder(dtype=tf.float32,shape=(batch_size,num_classes))
>>labels_holder = tf.placeholder(dtype=tf.int32,shape=(batch_size,num_classes))
>>prediction_tf = tf.equal(tf.argmax(predictions_holder, 1), tf.argmax(labels_holder, 1))
>>labels_match = sess.run(prediction_tf,feed_dict={predictions_holder:logits,labels_holder:labels})
>>print(labels_match)
[ True False]
上面的输出是预期的,因为变量
logits
的第一个示例仅表示具有最高激活(
0.9410
)的神经元是零号,与标签相同。现在我们想计算准确性,这意味着我们必须取变量
labels_match
的平均值。
>>labels_match_holder = tf.placeholder(dtype=tf.float32,shape=(batch_size))
>>accuracy_calc = tf.reduce_mean(tf.cast(labels_match_holder, tf.float32))
>>accuracy = sess.run(accuracy_calc, feed_dict={labels_match_holder:labels_match})
>>print(accuracy)
0.5