如何使用FASTAI库更改多标签分类预测的阈值

4

我有一个多标签数据集,我正在使用Python的fast-ai库对其进行训练,并使用精度函数作为指标,例如:

def accuracy_multi1(inp, targ, thresh=0.5, sigmoid=True):
    "Compute accuracy when 'inp' and 'targ' are the same size"
    if sigmoid: inp=inp.sigmoid()
    return ((inp>thresh) == targ.bool()).float().mean()

我的学习者就像:

learn = cnn_learner(dls, resnet50, metrics=partial(accuracy_multi1,thresh=0.1))
learn.fine_tune(2,base_lr=3e-2,freeze_epochs=2)

训练模型后,我想预测一张图像,并考虑使用的阈值作为参数,但是方法learn.predict('img.jpg')只考虑默认的thres=0.5。在以下示例中,我的预测应该对“红色”、“衬衫”和“鞋子”返回True,因为它们的概率超过了0.1(但是“鞋子”的概率小于0.5,所以不被视为True):

def printclasses(prediction,classes):
    print('Prediction:',prediction[0])
    for i in range(len(classes)):
        print(classes[i],':',bool(prediction[1][i]),'|',float(prediction[2][i]))

printclasses(learn.predict('rose.jpg'),dls.vocab)

输出:

Prediction: ['red', 'shirt']
black : False | 0.007274294272065163
blue : False | 0.0019288889598101377
brown : False | 0.005750810727477074
dress : False | 0.0028723080176860094
green : False | 0.005523672327399254
hoodie : False | 0.1325301229953766
pants : False | 0.009496113285422325
pink : False | 0.0037188702262938023
red : True | 0.9839697480201721
shirt : True | 0.5762518644332886
shoes : False | 0.2752271890640259
shorts : False | 0.0020902694668620825
silver : False | 0.0009014935349114239
skirt : False | 0.0030087409541010857
suit : False | 0.0006510693347081542
white : False | 0.001247694599442184
yellow : False | 0.0015280473744496703

在进行图像预测时,是否有一种方法可以强制设置阈值,类似于以下内容:

learn.predict('img.jpg',thresh=0.1)
1个回答

1
我遇到了同样的问题。我仍然对更好的解决方案感兴趣,但是由于 `accuracy_mult` 似乎只在训练过程中提供用户友好的模型评估(并不参与预测),因此我为我的数据创建了一个解决方法。
基本思路是取出实际预测的张量(这是 `predict()` 函数返回的三元组中的第三个条目),应用阈值并从词汇表中获取相应的标签。
def predict_labels(x, model, thresh=0.5):
  '''
  function to predict multi-labels in text (x)

  arguments:
  ----------
  x: the text to predict
  model: the trained learner
  thresh: thresh to indicate which labels should be included, fastai default is 0.5

  return:
  -------
  (str) predictions separated by blankspace
  '''

  # getting categories according to threshold
  preds = model.predict(x)[2] > thresh
  labels = model.dls.multi_categorize.vocab[preds]

  return ' '.join(labels) 


1
这正是我一直在使用的变通方法,直到我得到一个明确的解决方案! - brenodacosta

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接