我正在 Kaggle 上参加一个竞赛,评估指标被定义为
该竞赛的评估基于在不同交集联合(IoU)阈值下的平均精度。所提出的一组物体像素与一组真实物体像素的 IoU 计算如下:
IoU(A,B)=(A∩B)/(A∪B)
该度量标准涵盖了一系列的IoU阈值,每个阈值都会计算平均精度值。阈值范围从0.5到0.95,步长为0.05:
(0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95)
。换句话说,在0.5的阈值下,如果预测对象与真实对象的交并比大于0.5,则被视为“命中”。在每个阈值t下,基于真正例(TP)
、假负例(FN)
和假正例(FP)
的数量,计算出一个精度值,用于比较预测对象与所有真实对象的结果。 TP(t)/TP(t)+FP(t)+FN(t).
当一个预测对象与一个IoU阈值以上的基准真实对象匹配时,就会计算为真正例。假阳性表示预测对象没有关联的基准真实对象。假阴性表示基准真实对象没有相关的预测对象。然后,单个图像的平均精度是通过计算每个IoU阈值处的上述精度值的平均值来计算的:
(1/|thresholds|)*∑tTP(t)/TP(t)+FP(t)+FN(t)
现在,我已经使用纯numpy编写了这个函数,因为在那里编码更容易,并且我已经用tf.py_fucn()
修饰它,以便与Keras一起使用。以下是示例代码:
def iou_metric(y_true_in, y_pred_in, fix_zero=False):
labels = y_true_in
y_pred = y_pred_in
true_objects = 2
pred_objects = 2
if fix_zero:
if np.sum(y_true_in) == 0:
return 1 if np.sum(y_pred_in) == 0 else 0
intersection = np.histogram2d(labels.flatten(), y_pred.flatten(), bins=(true_objects, pred_objects))[0]
# Compute areas (needed for finding the union between all objects)
area_true = np.histogram(labels, bins = true_objects)[0]
area_pred = np.histogram(y_pred, bins = pred_objects)[0]
area_true = np.expand_dims(area_true, -1)
area_pred = np.expand_dims(area_pred, 0)
# Compute union
union = area_true + area_pred - intersection
# Exclude background from the analysis
intersection = intersection[1:,1:]
union = union[1:,1:]
union[union == 0] = 1e-9
# Compute the intersection over union
iou = intersection / union
# Precision helper function
def precision_at(threshold, iou):
matches = iou > threshold
true_positives = np.sum(matches, axis=1) == 1 # Correct objects
false_positives = np.sum(matches, axis=0) == 0 # Missed objects
false_negatives = np.sum(matches, axis=1) == 0 # Extra objects
tp, fp, fn = np.sum(true_positives), np.sum(false_positives), np.sum(false_negatives)
return tp, fp, fn
# Loop over IoU thresholds
prec = []
for t in np.arange(0.5, 1.0, 0.05):
tp, fp, fn = precision_at(t, iou)
if (tp + fp + fn) > 0:
p = tp / (tp + fp + fn)
else:
p = 0
prec.append(p)
return np.mean(prec)
我试图将它转换为纯的tf
函数,但由于我无法弄清楚控制依赖关系
的工作原理,所以无法完成。有人可以帮助我吗?
pyfunc
吗? - Jonas Adlertf.py_func()
。 - enterMLtf.metrics.mean_iou
。 - Zaccharie Ramzi