Tensorflow 尝试使用未初始化的值AUC/AUC/auc/false_positives。

9
我正在使用卷积神经网络进行图像分类的训练。由于我的数据集大小有限,我正在使用迁移学习。基本上,我正在使用Google在其重新训练示例中提供的预训练网络(https://www.tensorflow.org/tutorials/image_retraining)。
该模型表现出色,给出了非常好的准确性。但是我的数据集高度不平衡,这意味着准确性不是评估模型性能的最佳指标。
通过研究不同的解决方案,一些人建议更改采样方法或使用的性能指标。我选择后者。
Tensorflow提供了很多指标,包括AUC、精确度、召回率等。
现在,以下是重新训练模型的代码: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py 我正在向add_evaluation_step(result_tensor, ground_truth_tensor)函数添加以下内容:
  with tf.name_scope('AUC'):
    with tf.name_scope('prediction'):
        prediction = tf.argmax(result_tensor, 1)
    with tf.name_scope('AUC'):
        auc_value = tf.metrics.auc(tf.argmax(ground_truth_tensor, 1), prediction, curve='ROC')


  tf.summary.scalar('accuracy', evaluation_step)
  tf.summary.scalar('AUC', auc_value)

但是我得到了这个错误:
Traceback (most recent call last): File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 1135, in tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run _sys.exit(main(_sys.argv[:1] + flags_passthrough)) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 911, in main ground_truth_input: train_ground_truth}) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 767, in run run_metadata_ptr) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 965, in _run feed_dict_string, options, run_metadata) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 1015, in _do_run target_list, options, run_metadata) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 1035, in _do_call raise type(e)(node_def, op, message) tensorflow.python.framework.errors_impl.FailedPreconditionError: 尝试使用未初始化的值 AUC/AUC/auc/false_positives
[[Node: AUC/AUC/auc/false_positives/read = IdentityT=DT_FLOAT, _class=["loc:@AUC/AUC/auc/false_positives"], _device="/job:localhost/replica:0/task:0/cpu:0"]]

操作 'AUC/AUC/auc/false_positives/read' 的原因,定义于: File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 1135, in tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run _sys.exit(main(_sys.argv[:1] + flags_passthrough)) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 874, in main final_tensor, ground_truth_input) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 806, in add_evaluation_step auc_value, update_op = tf.metrics.auc(tf.argmax(ground_truth_tensor, 1), prediction, curve='ROC') File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 555, in auc labels, predictions, thresholds, weights) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 473, in _confusion_matrix_at_thresholds false_p = _create_local('false_positives', shape=[num_thresholds]) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 177, in _create_local validate_shape=validate_shape) File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/variables.py", line 226, in

但我不明白为什么会这样,因为我的主要代码是这样的:
init = tf.global_variables_initializer()
sess.run(init)
1个回答

24

试试这个:

init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init)

谢谢,这个有效!但是为什么?tf.local_variables_initializer()是做什么的? - Andreas Storvik Strauman
1
除了使用 tf.Variabletf.get_variable 创建的可训练变量之外,还有一些不可训练的变量,例如全局步骤(global step),它在每个训练步骤后增加一,这些变量都被称为本地变量。如果您想计算AUC,则TensorFlow会隐式地为其创建一些本地变量,因此您也需要对其进行初始化。 - Jie.Zhou
啊,好的。谢谢 :) - Andreas Storvik Strauman

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接