我开始通过将
thresholds = np.arange(0,1,0.1)
转换为更智能、二分查找的方式来改进解决方案。
然后我意识到,在工作了2个小时后,获取
所有准确性要比仅查找最大值便宜得多!(是的,这完全是违反直觉的。)
我在下面写了很多注释来解释我的代码。随意删除所有这些以使代码更易读。
import numpy as np
def ROC_curve_data(y_true, y_score):
y_true = np.asarray(y_true, dtype=np.bool_)
y_score = np.asarray(y_score, dtype=np.float_)
assert(y_score.size == y_true.size)
order = np.argsort(y_score)
y_true = y_true[order]
thresholds = np.insert(y_score[order],0,0)
TP = [sum(y_true)]
FP = [sum(~y_true)]
TN = [0]
FN = [0]
for i in range(1, thresholds.size) :
TP.append(TP[-1] - int(y_true[i-1]))
FN.append(FN[-1] + int(y_true[i-1]))
FP.append(FP[-1] - int(~y_true[i-1]))
TN.append(TN[-1] + int(~y_true[i-1]))
TP = np.asarray(TP, dtype=np.int_)
FP = np.asarray(FP, dtype=np.int_)
TN = np.asarray(TN, dtype=np.int_)
FN = np.asarray(FN, dtype=np.int_)
accuracy = (TP + TN) / (TP + FP + TN + FN)
sensitivity = TP / (TP + FN)
specificity = TN / (FP + TN)
return((thresholds, TP, FP, TN, FN))
整个过程只是一个单一的循环,算法非常简单。
实际上,这个愚蠢简单的函数比我之前提出的解决方案(计算
thresholds = np.arange(0,1,0.1)
的准确性)快10倍,比我以前的聪明二分算法快30倍...
然后你可以轻松地计算
任何你想要的KPI,例如:
def max_accuracy(thresholds, TP, FP, TN, FN) :
accuracy = (TP + TN) / (TP + FP + TN + FN)
return(max(accuracy))
def max_min_sensitivity_specificity(thresholds, TP, FP, TN, FN) :
sensitivity = TP / (TP + FN)
specificity = TN / (FP + TN)
return(max(np.minimum(sensitivity, specificity)))
如果您想进行测试:
如果您想进行测试:
y_score = np.random.uniform(size = 100)
y_true = [np.random.binomial(1, p) for p in y_score]
data = ROC_curve_data(y_true, y_score)
%matplotlib inline
import matplotlib.pyplot as plt
plt.step(data[0], data[1])
plt.step(data[0], data[2])
plt.step(data[0], data[3])
plt.step(data[0], data[4])
plt.show()
print("Max accuracy is", max_accuracy(*data))
print("Max of Min(Sensitivity, Specificity) is", max_min_sensitivity_specificity(*data))
祝愉快 ;)
accuracy = np.array(accuracy)
应该改为accuracy = np.array(accuracies)
或类似的代码 :) - Geeocode