使用logits的TensorFlow稀疏分类交叉熵

6
我是一名初学者程序员,尝试按照这个指南进行操作。然而,我遇到了一个问题。指南中说要将损失函数定义为:
def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

这让我收到以下错误信息:
“sparse_categorical_crossentropy()” 得到了一个意外的关键字参数 'from_logits'。
我的理解是,这个意味着函数中没有指定 from_logits 参数,这与文档支持的事实相符。tf.keras.losses.sparse_categorical_crossentropy() 只有两个可能的输入。
是否需要指定 logits 被使用的方法或者是否有必要这样做?

1
您可能需要验证您是否拥有正确的TensorFlow版本。 - Mateen Ulhaq
2个回答

8
我在学习教程时遇到了同样的问题。我把代码改成了...。
def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

to

def loss(labels, logits):
    return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)

这样做解决了问题,而无需安装tf-nightly。

2
from_logits 参数在 Tensorflow 1.13 中引入。
你可以通过这些 URL 来比较 1.12 和 1.13 版本:
https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/losses.py
https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/keras/losses.py

在撰写本文时,1.13版本尚未发布。这就是为什么教程从这行代码开始的原因。

!pip install -q tf-nightly

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接