How to determine the validation loss for PyTorch Faster RCNN?

I followed this tutorial for object detection: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html, along with their GitHub repository containing the train_one_epoch and evaluate functions: https://github.com/pytorch/vision/blob/main/references/detection/engine.py. However, I also want to compute the loss during validation. I implemented it as follows; in this case the model has to be put into model.train() mode to get the losses.
@torch.no_grad()
def evaluate_loss(model, data_loader, device):
    val_loss = 0
    model.train()  # the detection model only returns the loss dict while in train mode
    for images, targets in data_loader:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        val_loss += losses_reduced
  
    validation_loss = val_loss / len(data_loader)
    return validation_loss

I would then put it after the learning-rate scheduler step inside my loop:

for epoch in range(args.num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)

    # update the learning rate
    lr_scheduler.step()

    validation_loss = evaluate_loss(model, valid_data_loader, device=device)

    # evaluate on the test dataset
    evaluate(model, valid_data_loader, device=device)

Does this look correct? Will it interfere with training or produce an inaccurate validation loss?

If it is fine, is there a simple way to apply early stopping based on this validation loss?

I was thinking of adding something like the following after the evaluate_loss function shown above:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'validation_loss': validation_loss,
            }, PATH)

I would also like to save the model every epoch for checkpointing. However, I need the validation loss to decide which "best" model to save.
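Concretely, I had something along these lines in mind, combining the per-epoch checkpoint above with early stopping (the patience value and file names are just placeholders):

best_val_loss = float("inf")
epochs_without_improvement = 0
patience = 5  # placeholder: stop after 5 epochs without improvement

for epoch in range(args.num_epochs):
    train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
    lr_scheduler.step()

    validation_loss = evaluate_loss(model, valid_data_loader, device=device)
    evaluate(model, valid_data_loader, device=device)

    # checkpoint every epoch
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'validation_loss': validation_loss,
    }, f"checkpoint_epoch_{epoch}.pth")

    # keep the best model and stop early if the validation loss stops improving
    if validation_loss < best_val_loss:
        best_val_loss = validation_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Stopping early after epoch {epoch}")
            break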


If there are no batch norm or dropout layers, you can get away with just using model.train(); otherwise those layers will be updated / turned on (respectively) during evaluation. - jhso
You can create your own forward function for evaluation by copying the pytorch one. In that link you only need to change line 110 to return the detections and the losses. - jhso
I looked at the PyTorch function you pointed to, and it already returns the detections and losses. I don't think I fully understand what you mean. I believe you have the right answer... could you post a more detailed answer below? This is something I'm struggling to understand, and it seems to be a common issue with the torchvision RCNN models. - Ze0ruso
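
Regarding the batch norm / dropout caveat in the comments: if the model.train() route is used with a model that has live batch norm or dropout layers, one workaround (a sketch, not from the thread) is to flip only those layers back to eval mode so their running statistics are not updated and dropout stays off during validation. Note that torchvision's pretrained Faster R-CNN backbone uses FrozenBatchNorm2d, so this mainly matters for custom backbones.

import torch

def train_mode_without_bn_dropout(model):
    # train mode is needed so the detection heads return the loss dict
    model.train()
    # keep normalisation statistics frozen and dropout disabled during validation
    for module in model.modules():
        if isinstance(module, (torch.nn.BatchNorm2d, torch.nn.Dropout)):
            module.eval()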
3 Answers

So it turns out that none of the heads of the pytorch fasterrcnn return losses when model.eval() is set. However, you can reproduce the forward code manually to generate the losses in evaluation mode:
from typing import Tuple, List, Dict, Optional
import torch
from torch import Tensor
from collections import OrderedDict
from torchvision.models.detection.roi_heads import fastrcnn_loss
from torchvision.models.detection.rpn import concat_box_prediction_layers
def eval_forward(model, images, targets):
    # type: (torch.nn.Module, List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
    """
    Computes the training losses and the detections in a single forward pass,
    without putting the model into train mode.

    Args:
        model: a torchvision Faster R-CNN model
        images (list[Tensor]): images to be processed
        targets (list[Dict[str, Tensor]]): ground-truth boxes present in the images
    Returns:
        losses (dict[str, Tensor]): the same loss dictionary returned in train mode
            (`loss_classifier`, `loss_box_reg`, `loss_objectness`, `loss_rpn_box_reg`).
        detections (list[dict[str, Tensor]]): the post-processed detections, with
            fields like `boxes`, `labels` and `scores`.
    """
    model.eval()

    original_image_sizes: List[Tuple[int, int]] = []
    for img in images:
        val = img.shape[-2:]
        assert len(val) == 2
        original_image_sizes.append((val[0], val[1]))

    images, targets = model.transform(images, targets)

    # Check for degenerate boxes
    # TODO: Move this to a function
    if targets is not None:
        for target_idx, target in enumerate(targets):
            boxes = target["boxes"]
            degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
            if degenerate_boxes.any():
                # print the first degenerate box
                bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                degen_bb: List[float] = boxes[bb_idx].tolist()
                raise ValueError(
                    "All bounding boxes should have positive height and width."
                    f" Found invalid box {degen_bb} for target at index {target_idx}."
                )

    features = model.backbone(images.tensors)
    if isinstance(features, torch.Tensor):
        features = OrderedDict([("0", features)])
    # flag the RPN as training so that the proposal filtering below uses its
    # training-time settings; it is switched back to False at the end
    model.rpn.training = True
    # model.roi_heads.training = True


    #####proposals, proposal_losses = model.rpn(images, features, targets)
    features_rpn = list(features.values())
    objectness, pred_bbox_deltas = model.rpn.head(features_rpn)
    anchors = model.rpn.anchor_generator(images, features_rpn)

    num_images = len(anchors)
    num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
    num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
    objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
    # apply pred_bbox_deltas to anchors to obtain the decoded proposals
    # note that we detach the deltas because Faster R-CNN do not backprop through
    # the proposals
    proposals = model.rpn.box_coder.decode(pred_bbox_deltas.detach(), anchors)
    proposals = proposals.view(num_images, -1, 4)
    proposals, scores = model.rpn.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

    proposal_losses = {}
    assert targets is not None
    labels, matched_gt_boxes = model.rpn.assign_targets_to_anchors(anchors, targets)
    regression_targets = model.rpn.box_coder.encode(matched_gt_boxes, anchors)
    loss_objectness, loss_rpn_box_reg = model.rpn.compute_loss(
        objectness, pred_bbox_deltas, labels, regression_targets
    )
    proposal_losses = {
        "loss_objectness": loss_objectness,
        "loss_rpn_box_reg": loss_rpn_box_reg,
    }

    #####detections, detector_losses = model.roi_heads(features, proposals, images.image_sizes, targets)
    image_shapes = images.image_sizes
    proposals, matched_idxs, labels, regression_targets = model.roi_heads.select_training_samples(proposals, targets)
    box_features = model.roi_heads.box_roi_pool(features, proposals, image_shapes)
    box_features = model.roi_heads.box_head(box_features)
    class_logits, box_regression = model.roi_heads.box_predictor(box_features)

    result: List[Dict[str, torch.Tensor]] = []
    detector_losses = {}
    loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
    detector_losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
    boxes, scores, labels = model.roi_heads.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
    num_images = len(boxes)
    for i in range(num_images):
        result.append(
            {
                "boxes": boxes[i],
                "labels": labels[i],
                "scores": scores[i],
            }
        )
    detections = result
    detections = model.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]
    # restore the evaluation-mode flags
    model.rpn.training = False
    model.roi_heads.training = False
    losses = {}
    losses.update(detector_losses)
    losses.update(proposal_losses)
    return losses, detections

Running this code returns the following for me:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
losses, detections = eval_forward(
    model,
    torch.randn([1, 3, 300, 300]),
    [{'boxes': torch.tensor([[100, 100, 200, 200]]), 'labels': torch.tensor([0])}],
)

{'loss_classifier': tensor(0.6594, grad_fn=<NllLossBackward0>),
 'loss_box_reg': tensor(0., grad_fn=<DivBackward0>),
 'loss_objectness': tensor(0.5108, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
 'loss_rpn_box_reg': tensor(0.0160, grad_fn=<DivBackward0>)}

I've given you a function you can use... - jhso
So the function above is modified from generalized_rcnn. Have you tried it? - jhso
So the code I wrote... you can just copy it into your local directory/script. You don't need to modify any of the torch code because I've already done that for you. - jhso
No, because I've made the model one of the inputs to the function, so you need to iterate over your data loader and feed in each batch one at a time while incrementing your loss accumulation variable. - jhso
@Perry45 Don't compare this with model(imgs), because this code should only be used to generate the model's losses on a test/validation set. Predictions generated with this method will have foreground/background sampling and other training-time behaviour applied in order to produce the losses. - jhso

Thank you very much for your patience. Below is a snippet that iterates over the data loader. I think I've understood what you meant, but the losses printed by the code below come out as an empty dictionary:
@torch.no_grad()
def evaluate_loss(model, data_loader, device):
    val_loss = 0
    for images, targets in data_loader:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        #USE PROVIDED CODE to get losses and detections
        losses, detections = eval_forward(model, images, targets)

        print(losses) # empty {}

        val_loss += sum(loss for loss in losses.values())

    validation_loss = val_loss/ len(data_loader)    
    return validation_loss

When I print the losses and detections I get the following:
{} [{'boxes': tensor([[  0.0000, 430.0531, 364.2619, 512.0000],
        [  6.8726, 455.9226, 256.0113, 509.0516],
        [  5.7750, 227.0236, 138.1525, 503.0216],
        [  0.0000, 275.2110,  87.6766, 512.0000],
        [ 55.3590, 484.3553, 311.3914, 512.0000],
        [ 41.9545, 370.1071, 431.6385, 500.5055],
        [  0.0000, 391.8048, 187.7228, 512.0000],
        [501.2419, 187.9812, 511.2767, 201.9233],
        [507.1944, 195.7916, 511.5490, 216.8658],
        [173.8539, 460.3328, 448.6479, 506.3229],
        [  0.0000, 200.4993, 224.5978, 455.6439],
        [432.5095, 107.3605, 448.2870, 123.3097],
        [  0.0000, 484.3896, 181.2187, 512.0000],
        [252.8410, 352.4666, 269.2491, 364.2188],
        [141.6757, 485.4147, 439.0354, 512.0000],
        [252.6323, 341.7145, 267.7503, 353.9413],
        [134.9624, 314.2813, 474.5851, 492.6868],
        [505.2639, 237.3413, 511.8117, 262.1838],
        [  0.0000, 297.2654, 370.9958, 492.1260],
        [506.8980, 181.4306, 511.8102, 204.6986],
        [171.3477, 413.2979, 487.6665, 512.0000],
        [507.0528, 298.5904, 511.8441, 309.8073],
        [336.4479, 267.7834, 499.2108, 496.2349],
        [178.1360, 341.3546, 367.1203, 504.6978],
        [244.6255, 218.8507, 257.6999, 231.4108],
        [504.0644, 254.3425, 511.8181, 268.0185],
        [  0.0000, 365.2629,  39.0588, 512.0000],
        [258.7524, 340.9509, 271.9611, 353.5555],
        [507.1984, 443.6097, 511.7004, 455.8767],
        [346.1955, 170.9065, 358.2302, 184.1580],
        [ 50.2086, 324.4587, 251.0680, 512.0000],
        [198.5728, 322.8210, 209.8158, 330.6772],
        [498.2428, 141.8683, 511.1887, 224.0274],
        [297.8328, 483.9214, 500.6504, 512.0000],
        [383.7580, 302.3506, 406.5758, 328.4388],
        [190.7700, 319.5901, 203.9809, 330.4897],
        [248.1737, 341.2397, 272.0346, 364.2649],
        [ 41.9480, 182.3307, 309.7350, 511.4400],
        [507.6814, 465.5771, 511.6959, 478.4059],
        [  0.0000, 414.7599,  16.6887, 512.0000],
        [  0.0000, 495.9020,   9.1763, 512.0000],
        [506.0956, 484.8349, 511.6204, 508.3524],
        [  0.0000, 484.2805,  14.1195, 512.0000],
        [186.2599, 231.2097, 451.8763, 466.7952],
        [465.1697, 499.5819, 508.8633, 512.0000],
        [359.1404, 416.1848, 416.8053, 512.0000],
        [444.5928, 200.7507, 457.7525, 216.0354],
        [348.6382, 146.4818, 362.1615, 155.7809],
        [288.0855, 181.4522, 306.9987, 202.8014],
        [138.3017, 199.5426, 152.1866, 214.0261],
        [ 54.3134, 322.8700,  66.6056, 339.6511],
        [236.9178, 176.1253, 256.1872, 195.2987],
        [183.0305, 224.6637, 198.1654, 238.4647],
        [255.3874, 337.9686, 452.8956, 505.8088],
        [195.6607, 342.5625, 207.6055, 351.6043],
        [478.7965, 262.2610, 510.4778, 512.0000],
        [507.0534,  62.8041, 511.7828,  83.3675],
        [506.9258, 247.0326, 511.7821, 269.0636],
        [  0.0000, 482.6279,  39.7247, 512.0000],
        [  0.0000, 400.6234,  62.0636, 497.9158],
        [504.7887, 295.1768, 511.6837, 314.4619],
        [503.7539, 444.5576, 511.6874, 469.6237],
        [420.8303, 139.0130, 435.5850, 155.6219],
        [  0.0000, 169.4536,  35.6173, 512.0000],
        [505.5238, 216.9875, 511.8623, 244.7741],
        [493.3357, 183.2157, 510.4757, 225.7995],
        [283.5856, 184.4567, 294.6422, 199.1284],
        [506.1086, 172.9610, 511.7372, 195.6782],
        [421.7606, 478.9979, 506.9432, 512.0000],
        [  0.0000, 128.1171, 182.0242, 372.1508],
        [266.6456, 212.4419, 285.0941, 230.3711],
        [242.4399, 337.2843, 292.0536, 369.6913],
        [490.5333, 151.4534, 511.3717, 199.9196],
        [195.0700, 317.0647, 208.6026, 328.3253],
        [506.5237, 166.3083, 511.7383, 186.4610],
        [285.0119, 210.5486, 302.8143, 227.0892],
        [507.7259, 159.7037, 511.7627, 177.6721],
        [507.2086, 409.5898, 511.7660, 443.1966],
        [486.4733,   1.5067, 511.0473,  32.8377],
        [499.7045, 410.5609, 511.2081, 495.3992],
        [381.5405, 282.1667, 394.4013, 292.7220],
        [398.5074,  97.8511, 408.5006, 109.4040],
        [286.4212,  66.7245, 305.3555,  84.7535],
        [ 53.2904, 198.9514,  72.6522, 218.6958],
        [  0.0000, 119.1250, 352.9160, 404.2254],
        [305.2835, 262.8656, 322.0334, 282.8750],
        [ 67.7342, 107.0263,  79.3835, 116.1997],
        [504.5052, 328.6933, 511.7248, 354.2790],
        [505.5066, 454.7970, 511.6003, 479.1691],
        [297.2463, 179.5240, 459.4996, 500.3919],
        [505.9551, 116.8015, 511.8934, 139.2066],
        [ 51.7288, 143.0008,  70.2031, 162.0272],
        [281.4141, 178.7466, 292.6686, 195.8384],
        [329.5997, 233.1259, 344.1964, 247.8056],
        [308.4427, 105.4068, 324.9741, 120.8449],
        [173.9055, 208.1558, 187.9732, 223.4990],
        [506.5709, 396.8288, 511.6976, 427.8991],
        [281.4510, 187.4271, 317.5686, 229.1852],
        [395.2721, 351.2404, 407.8893, 365.8526],
        [501.4947, 463.5199, 511.3037, 476.1774]]), 'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1]), 'scores': tensor([0.7932, 0.7808, 0.7726, 0.7688, 0.7644, 0.7624, 0.7563, 0.7557, 0.7481,
        0.7428, 0.7417, 0.7415, 0.7414, 0.7403, 0.7378, 0.7354, 0.7293, 0.7268,
        0.7256, 0.7235, 0.7196, 0.7195, 0.7192, 0.7175, 0.7163, 0.7160, 0.7130,
        0.7126, 0.7122, 0.7120, 0.7120, 0.7095, 0.7095, 0.7094, 0.7083, 0.7065,
        0.7048, 0.7042, 0.7041, 0.7038, 0.7006, 0.7005, 0.6998, 0.6997, 0.6974,
        0.6974, 0.6969, 0.6963, 0.6958, 0.6950, 0.6949, 0.6946, 0.6946, 0.6936,
        0.6925, 0.6915, 0.6897, 0.6897, 0.6884, 0.6880, 0.6862, 0.6861, 0.6858,
        0.6855, 0.6853, 0.6848, 0.6844, 0.6836, 0.6827, 0.6823, 0.6814, 0.6808,
        0.6797, 0.6784, 0.6770, 0.6769, 0.6766, 0.6764, 0.6764, 0.6755, 0.6754,
        0.6735, 0.6733, 0.6720, 0.6715, 0.6713, 0.6712, 0.6697, 0.6693, 0.6687,
        0.6673, 0.6671, 0.6670, 0.6669, 0.6663, 0.6658, 0.6658, 0.6658, 0.6657,
        0.6654])}]

The computed losses are not shown in the first dictionary, so they still need to be computed somehow.


OK, now we're getting somewhere! Can you add print(detections, detector_losses) inside eval_forward and see what it outputs? - jhso
I have fixed my mistake, please see my revised code above. - jhso

By following the code provided by @jhso, I was able to determine the validation loss by looking at the loss dictionary returned by eval_forward, summing those losses for each batch, and averaging over the length of the data loader at the end:
def evaluate_loss(model, data_loader, device):
    val_loss = 0
    with torch.no_grad():
        for images, targets in data_loader:
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict, detections = eval_forward(model, images, targets)

            losses = sum(loss for loss in loss_dict.values())

            val_loss += losses

    # average over the number of batches in the loader
    validation_loss = val_loss / len(data_loader)
    return validation_loss

I then put it into the following loop for training and evaluation:

import utils
from engine import train_one_epoch, evaluate


for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # new function that determines the validation loss
    validation_loss = evaluate_loss(model, valid_data_loader, device=device)
    print(validation_loss)

    # evaluate on the test dataset
    evaluate(model, valid_data_loader, device=device)

I believe this is correct.

