How to implement early stopping in PyTorch image classification?


Does PyTorch still not provide an Early Stopping feature natively? - MJimitater
3 Answers


This is what I do at the end of each epoch.

# Accumulate the validation loss per batch, then average it over the batches
val_loss += loss.item()
val_loss = val_loss / len(trainloader)

if val_loss < min_val_loss:
  # Validation loss improved: save the best weights and reset the counter
  min_val_loss = val_loss
  best_model = copy.deepcopy(loaded_model.state_dict())
  print('Min validation loss %0.2f' % min_val_loss)
  epochs_no_improve = 0
else:
  epochs_no_improve += 1
  # Check early stopping condition
  if epochs_no_improve == n_epochs_stop:
    print('Early stopping!')
    loaded_model.load_state_dict(best_model)
    break

I'm not sure this code is entirely correct (I saw similar code on another website but forgot exactly where, so I can't provide a reference link; I only modified it slightly). I hope you find it useful. If I have made mistakes, please point them out. Thanks.
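Note that the snippet above assumes the bookkeeping variables are initialized before the training loop and that val_loss is reset to 0 before each validation pass. A minimal initialization sketch (the patience value of 5 is just a placeholder):

import copy

n_epochs_stop = 5                                       # patience: epochs allowed without improvement
min_val_loss = float('inf')                             # best (lowest) validation loss seen so far
epochs_no_improve = 0                                   # consecutive epochs without improvement
best_model = copy.deepcopy(loaded_model.state_dict())   # fall back to the initial weights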


The idea of early stopping is to avoid overfitting by stopping the training process once a monitored quantity (e.g. the validation loss) has stopped improving for several iterations. A minimal implementation of early stopping needs 3 components:
  • a best_score variable to store the best value of the validation loss
  • a counter variable to keep track of how many iterations have run without improvement
  • a patience variable that defines the number of epochs training is allowed to continue without any improvement in the validation loss. If counter exceeds this value, we stop the training process.

The pseudocode looks like this:

# Define best_score, counter, and patience for early stopping:
best_score = None
counter = 0
patience = 10
path = './checkpoints'  # user-defined path to save the model

# Training loop:
for epoch in range(num_epochs):
    # Compute training loss
    loss = model(features,labels,train_mask)
    
    # Compute validation loss
    val_loss = evaluate(model, features, labels, val_mask)
    
    if best_score is None:
        best_score = val_loss
        torch.save({'state_dict': model.state_dict()}, path)
    else:
        # Check if val_loss improves or not.
        if val_loss < best_score:
            # val_loss improves: update the latest best_score,
            # save the current model, and reset the counter
            best_score = val_loss
            counter = 0
            torch.save({'state_dict': model.state_dict()}, path)
        else:
            # val_loss does not improve: increase the counter,
            # stop training once it reaches the amount of patience
            counter += 1
            if counter >= patience:
                break

# Load best model 
print('loading model before testing.')
model_checkpoint = torch.load(path)

model.load_state_dict(model_checkpoint['state_dict'])

acc = evaluate_test(model, features, labels, test_mask)    

I have implemented a generic early stopping class for PyTorch, which I use in some of my projects. It lets you choose any validation quantity of interest (loss, accuracy, etc.). If you want more advanced early stopping, you can take a look at the early-stopping repository; there is also an example notebook for reference.
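As a rough sketch of what such a class can look like (the class name EarlyStopping and the patience, min_delta, and path parameters below are illustrative, not necessarily the exact API of the linked repository):

import torch

class EarlyStopping:
    """Stops training when the monitored quantity has not improved for `patience` checks."""
    def __init__(self, patience=10, min_delta=0.0, path='checkpoint.pt'):
        self.patience = patience      # checks to wait after the last improvement
        self.min_delta = min_delta    # minimum change that counts as an improvement
        self.path = path              # where to save the best model
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_score is None or val_loss < self.best_score - self.min_delta:
            # Improvement: remember the score and checkpoint the model
            self.best_score = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
        else:
            # No improvement: count up and raise the flag once patience is exhausted
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# Usage inside the training loop:
#   early_stopping = EarlyStopping(patience=10)
#   early_stopping(val_loss, model)
#   if early_stopping.early_stop:
#       break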


Please try the following code.

     # Check early stopping condition (this block sits inside the inner batch loop)
     if epochs_no_improve == n_epochs_stop:
        print('Early stopping!')
        early_stop = True
        break   # leave the inner loop

# Back in the outer epoch loop: stop training entirely
if early_stop:
    print("Stopped")
    break
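For context, a fuller sketch of how this fragment fits into a nested training loop; the names num_epochs, train_loader, val_loader, model, optimizer, criterion, evaluate, and val_every are assumptions for illustration, not part of the original answer:

n_epochs_stop = 5
epochs_no_improve = 0
min_val_loss = float('inf')
early_stop = False

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Standard training step
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        # Validate every `val_every` batches and update the no-improvement counter
        if (i + 1) % val_every == 0:
            val_loss = evaluate(model, val_loader, criterion)  # user-defined validation pass
            if val_loss < min_val_loss:
                min_val_loss = val_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            # Check early stopping condition
            if epochs_no_improve == n_epochs_stop:
                print('Early stopping!')
                early_stop = True
                break            # leaves the batch loop only

    if early_stop:
        print("Stopped")
        break                    # stops the epoch loop, i.e. ends training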
