如何使用Tensorflow和scikit-learn绘制ROC曲线?

7
我正在尝试从tensorflow提供的CIFAR-10示例的修改版本中绘制ROC曲线。现在它是针对2个类而不是10个类。
网络的输出称为logits,其形式如下:
[[-2.57313061 2.57966399] [ 0.04221377 -0.04033273] [-1.42880082 1.43337202] [-2.7692945 2.78173304] [-2.48195744 2.49331546] [ 2.0941515 -2.10268974] [-3.51670194 3.53267646] [-2.74760485 2.75617766] ...]
首先,这些logits实际上代表什么?网络中的最后一层是具有WX+b形式的“softmax linear”。
该模型能够通过调用计算准确度。
top_k_op = tf.nn.in_top_k(logits, labels, 1)

然后一旦图表被初始化:
predictions = sess.run([top_k_op])
predictions_int = np.array(predictions).astype(int)
true_count += np.sum(predictions) 
...
precision = true_count / total_sample_count

这个工作得很好。

但现在我该如何从这里绘制一个ROC曲线呢?

我一直在尝试使用"sklearn.metrics.roc_curve()"函数(http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#sklearn.metrics.roc_curve),但我不知道该使用什么作为我的"y_score"参数。

非常感谢任何帮助!


请查看此链接(http://stackoverflow.com/questions/35811446/classification-accuracy-after-recall-and-precision/37275638#37275638),其中有一个计算和绘制ROC曲线的代码。 - Nikolas Rieble
2个回答

1
“y_score”应该是一个数组,对应于每个样本被分类为正例的概率(如果在您的y_true数组中标记为1,则为正例)。
实际上,如果您的网络使用Softmax作为最后一层,则模型应输出此实例每个类别的概率。但是,您提供的数据不符合此格式。我检查了示例代码:https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/image/cifar10/cifar10.py,似乎使用了称为softmax_linear的层。我对这个示例知之甚少,但我猜测您应该使用类似于Logistic函数的东西来处理输出,以将其转换为概率。
然后,只需将其与真实标签“y_true”一起提供给scikit-learn函数即可:
y_score = np.array(output)[:,1]
roc_curve(y_true, y_score)

我也遇到了同样的情况...如果模型的输出值太大,它会使逻辑函数饱和。我该怎么办? - Helder

1
import tensorflow as tf
tp = [] # the true positive rate list
fp = [] # the false positive rate list
total = len(fp)
writer = tf.train.SummaryWriter("/tmp/tensorboard_roc")
for idx in range(total):
    summt = tf.Summary()
    summt.value.add(tag="roc", simple_value = tp[idx])
    writer.add_summary (summt, tp[idx] * 100) #act as global_step
    writer.flush ()

然后启动一个TensorBoard:
tensorboard --logdir=/tmp/tensorboard_roc

tensorboard_roc

关于详情和代码,您可以访问我的博客:http://blog.csdn.net/mao_feng/article/details/54731098


我在我的模型中使用了这段代码,但是在tensorboard上我只看到了一条从(0,0)直线。 - Kyrol
是的,即使我从(0,0)到(1,1)也看到了一条对角线。这不起作用。 - London guy

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接