Pytorch:如何为语义分割计算IoU(Jaccard指数)

10

有人能提供一个用pytorch计算语义分割中IoU(交并比)的玩具示例吗?


https://stackoverflow.com/questions/43261072/jaccards-distance-matrix-with-tensorflow 展示了如何在tensorflow中实现它。将其移植到PyTorch应该很容易。 - nemo
3个回答

20
截至2021年,无需自行实现IoU,因为torchmetrics已经内置了它 - 这是链接。 它被命名为torchmetrics.JaccardIndex(之前是torchmetrics.IoU),可以计算你想要的结果。 它适用于PyTorch和PyTorch Lightning,也适用于分布式训练。
从文档中可以看到:

torchmetrics.JaccardIndex(num_classes, ignore_index=None, absent_score=0.0, threshold=0.5, multilabel=False, reduction='elementwise_mean', compute_on_step=None, **kwargs)

计算交并比,或Jaccard指数:

J(A,B) = \frac{|A\cap B|}{|A\cup B|}

其中: AB 是具有相同大小的张量,包含整数类别值。它们可能需要从输入数据进行转换(请参阅下面的描述)。请注意,它与框的IoU不同。

适用于二进制、多类和多标签数据。接受模型输出的概率或预测中的整数类别值。适用于多维预测和目标。

Forward接受以下参数:

  • preds(浮点或长整型张量):(N, ...)(N, C, ...),其中C是类别数
  • target(长整型张量):(N, ...) 如果preds和target具有相同的形状,并且preds是浮点张量,则使用self.threshold参数将其转换为整数标签。这适用于二进制和多标签概率。

如果preds具有额外的维度,例如多类别分数的情况下,我们在dim=1上执行argmax。

官方示例:
from torchmetrics import JaccardIndex
target = torch.randint(0, 2, (10, 25, 25))
pred = torch.tensor(target)
pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
jaccard = JaccardIndex(task='multiclass', num_classes=2)
jaccard(pred, target)

返回tensor(0.9660)

1
这应该是被接受的答案,因为使用经过测试的实现比自定义实现更加清晰和安全。 - CharlesG
2
请查看JaccardIndex - irudyak
@dominik-filipiak,您能否更新您的答案中的链接为irudyak提供的更新后的链接? - Saurav Maheshkar
1
@SauravMaheshkar 感谢您的评论,我已更新链接。 - Dominik Filipiak

10
我在某处找到了这个并为自己做了一些改动,如果我能再次找到它,我会发布链接的。如果有重复,请原谅。
这里的关键函数是名为 iou 的函数。包装函数 evaluate_performance 不是通用的,但它显示了在计算 IoU 之前需要迭代所有结果。
import torch 
import pandas as pd  # For filelist reading
import myPytorchDatasetClass  # Custom dataset class, inherited from torch.utils.data.dataset


def iou(pred, target, n_classes = 12):
  ious = []
  pred = pred.view(-1)
  target = target.view(-1)

  # Ignore IoU for background class ("0")
  for cls in xrange(1, n_classes):  # This goes from 1:n_classes-1 -> class "0" is ignored
    pred_inds = pred == cls
    target_inds = target == cls
    intersection = (pred_inds[target_inds]).long().sum().data.cpu()[0]  # Cast to long to prevent overflows
    union = pred_inds.long().sum().data.cpu()[0] + target_inds.long().sum().data.cpu()[0] - intersection
    if union == 0:
      ious.append(float('nan'))  # If there is no ground truth, do not include in evaluation
    else:
      ious.append(float(intersection) / float(max(union, 1)))
  return np.array(ious)


def evaluate_performance(net):
    # Dataloader for test data
    batch_size = 1  
    filelist_name_test = '/path/to/my/test/filelist.txt'
    data_root_test = '/path/to/my/data/'
    dset_test = myPytorchDatasetClass.CustomDataset(filelist_name_test, data_root_test)
    test_loader = torch.utils.data.DataLoader(dataset=dset_test,  
                                              batch_size=batch_size,
                                              shuffle=False,
                                              pin_memory=True)
    data_info = pd.read_csv(filelist_name_test, header=None)
    num_test_files = data_info.shape[0]  
    sample_size = num_test_files

    # Containers for results
    preds = Variable(torch.zeros((sample_size, 60, 36, 60)))
    gts = Variable(torch.zeros((sample_size, 60, 36, 60)))

    dataiter = iter(test_loader) 
    for i in xrange(sample_size):
        images, labels, filename = dataiter.next()
        images = Variable(images).cuda()
        labels = Variable(labels)
        gts[i:i+batch_size, :, :, :] = labels
        outputs = net(images)
        outputs = outputs.permute(0, 2, 3, 4, 1).contiguous()
        val, pred = torch.max(outputs, 4)
        preds[i:i+batch_size, :, :, :] = pred.cpu()
    acc = iou(preds, gts)
    return acc

6

假设你的输出形状为[32, 256, 256],其中32是小批量大小,256x256是图像的高度和宽度,并且标签也具有相同的形状。

然后你可以在一些重塑之后使用sklearn的jaccard_similarity_score

如果两者都是torch张量,则:

lbl = labels.cpu().numpy().reshape(-1)
target = output.cpu().numpy().reshape(-1)

现在:

from sklearn.metrics import jaccard_similarity_score as jsc
print(jsc(target,lbl))

1
在最新的Sklearn版本中,例如0.24.1,函数名称已更改为jaccard_score。https://scikit-learn.org/stable/modules/model_evaluation.html#jaccard-similarity-score - zong fan

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接