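I'm new to PyTorch and machine learning. I'm following this tutorial: https://www.learnopencv.com/image-classification-using-transfer-learning-in-pytorch/ with my own custom dataset, and I ran into the same problem the tutorial describes. However, I don't know how to implement early stopping in PyTorch. If there's a better approach that doesn't require writing my own early-stopping routine, please let me know.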
This is what I do in each epoch:
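import copy

# Before the epoch loop (n_epochs_stop is the patience; 5 is only an example value)
min_val_loss = float('inf')
epochs_no_improve = 0
n_epochs_stop = 5

# Inside the validation batch loop (reset val_loss = 0.0 at the start of each epoch):
val_loss += loss.item()

# After the validation pass of each epoch, average over the validation batches
# (valloader is assumed to be your validation DataLoader):
val_loss = val_loss / len(valloader)
if val_loss < min_val_loss:
    # Validation loss improved: keep a copy of the current weights
    best_model = copy.deepcopy(loaded_model.state_dict())
    print('Min val loss %0.2f' % val_loss)
    epochs_no_improve = 0
    min_val_loss = val_loss
else:
    epochs_no_improve += 1
    # Check early stopping condition
    if epochs_no_improve == n_epochs_stop:
        print('Early stopping!')
        # Restore the best weights and leave the epoch loop
        loaded_model.load_state_dict(best_model)
        break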
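I'm not sure this code is completely correct. I saw similar code on another site but forgot where, so I can't give a link; I only modified it slightly. I hope you find it useful, and please point out any mistakes. Thanks.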
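Why copy.deepcopy matters in the code above: according to the PyTorch documentation, Module.state_dict() returns a shallow copy that holds references to the module's parameters and buffers, so a saved state_dict keeps changing as the optimizer updates the weights in place. The sketch below shows the same aliasing with plain Python lists standing in for tensors. It is an analogy under that assumption, not PyTorch code:

```python
import copy

# Plain-Python stand-in for a model's parameters (real state_dict values are tensors)
weights = {"layer.weight": [0.5, -0.2], "layer.bias": [0.1]}

best_by_reference = weights         # another name for the very same dict
best_shallow = dict(weights)        # new dict, but the inner lists are shared (like state_dict())
best_deep = copy.deepcopy(weights)  # independent copy of every value

# Simulate an in-place optimizer update made after the "best" epoch
weights["layer.weight"][0] = 9.0

print(best_by_reference["layer.weight"][0])  # 9.0: the saved "best" changed too
print(best_shallow["layer.weight"][0])       # 9.0: the inner list is shared
print(best_deep["layer.weight"][0])          # 0.5: the snapshot survived
```

Only the deep copy still holds the weights from the best epoch, which is why best_model is built with copy.deepcopy(loaded_model.state_dict()).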
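The best_score variable stores the best validation loss seen so far, counter counts the consecutive epochs in which the validation loss has not improved, and patience sets how many such epochs are allowed. Once counter reaches patience, we stop training. The pseudocode is as follows: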
import os
import torch

# Define best_score, counter, and patience for early stopping:
best_score = None
counter = 0
patience = 10
os.makedirs('./checkpoints', exist_ok=True)
path = './checkpoints/best_model.pt'  # user-defined path to save the model

# Training loop (model, optimizer, num_epochs, evaluate, evaluate_test, features,
# labels and the masks are assumed to be defined by your own code):
for epoch in range(num_epochs):
    # Compute training loss and update the weights
    model.train()
    optimizer.zero_grad()
    loss = model(features, labels, train_mask)
    loss.backward()
    optimizer.step()

    # Compute validation loss
    val_loss = evaluate(model, features, labels, val_mask)

    # Check if val_loss improves or not (the first epoch always counts as an improvement)
    if best_score is None or val_loss < best_score:
        # val_loss improves: update best_score, reset the counter,
        # and save the current model
        best_score = val_loss
        counter = 0
        torch.save({'state_dict': model.state_dict()}, path)
    else:
        # val_loss does not improve: increase the counter and
        # stop training once it reaches the patience
        counter += 1
        if counter >= patience:
            break

# Load best model
print('loading model before testing.')
model_checkpoint = torch.load(path)
model.load_state_dict(model_checkpoint['state_dict'])
acc = evaluate_test(model, features, labels, test_mask)
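The same best_score/counter/patience logic can be packaged into a small reusable class so the training loop stays short. This is a minimal sketch of a hypothetical helper: EarlyStopping is my own name, not a PyTorch class. It is framework-free, snapshots whatever state you pass in with copy.deepcopy, and adds an assumed min_delta option for ignoring tiny improvements:

```python
import copy
from typing import Any, Optional


class EarlyStopping:
    """Hypothetical helper (not part of PyTorch): signal a stop once the
    validation loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_score: Optional[float] = None
        self.best_state: Any = None
        self.best_epoch: Optional[int] = None
        self.counter = 0

    def step(self, val_loss: float, state: Any = None, epoch: Optional[int] = None) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if self.best_score is None or val_loss < self.best_score - self.min_delta:
            # Improvement: remember the score, reset the counter, snapshot the state
            self.best_score = val_loss
            self.best_state = copy.deepcopy(state)
            self.best_epoch = epoch
            self.counter = 0
            return False
        # No improvement: count it and stop once the counter reaches patience
        self.counter += 1
        return self.counter >= self.patience


# Demo with a made-up validation curve; a dict stands in for model.state_dict()
stopper = EarlyStopping(patience=3)
stopped_epoch = None
for epoch, val_loss in enumerate([0.9, 0.7, 0.65, 0.66, 0.64, 0.70, 0.71, 0.72]):
    if stopper.step(val_loss, state={"epoch": epoch}, epoch=epoch):
        stopped_epoch = epoch
        break
print(stopped_epoch, stopper.best_epoch, stopper.best_score)  # prints: 7 4 0.64
```

In a PyTorch loop you would call stopper.step(val_loss, model.state_dict(), epoch) once per epoch and, after the loop, restore the weights with model.load_state_dict(stopper.best_state). Keeping the snapshot in memory avoids writing a checkpoint file, at the cost of holding a second copy of the weights.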
Try the following code:
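# n_epochs, trainloader, epochs_no_improve and n_epochs_stop come from your own training code
early_stop = False
for epoch in range(n_epochs):
    for inputs, labels in trainloader:
        # ... training step that updates epochs_no_improve ...

        # Check early stopping condition
        if epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            early_stop = True
            break  # leaves only the inner (batch) loop
    if early_stop:
        print("Stopped")
        break  # leaves the outer (epoch) loop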
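The early_stop flag above exists only because break leaves just the innermost loop. A common Python alternative, suggested here and not taken from the answer, is to move the loops into a function and return when the condition fires, which exits every loop at once. This sketch assumes validation runs after every batch; find_stop_point and its nested list of validation losses are hypothetical stand-ins for real training:

```python
from typing import List, Optional, Tuple


def find_stop_point(val_losses_per_epoch: List[List[float]],
                    patience: int) -> Optional[Tuple[int, int]]:
    """Return (epoch, batch) where early stopping fires, or None if it never does."""
    best = float("inf")
    checks_no_improve = 0
    for epoch, batch_losses in enumerate(val_losses_per_epoch):
        for batch, val_loss in enumerate(batch_losses):
            # A real training step for this batch would go here
            if val_loss < best:
                best = val_loss
                checks_no_improve = 0
            else:
                checks_no_improve += 1
            if checks_no_improve == patience:
                return epoch, batch  # exits both loops; no flag needed
    return None


print(find_stop_point([[0.9, 0.8], [0.85, 0.7], [0.75, 0.72, 0.71]], patience=3))
```

With patience=3 the three non-improving checks after 0.7 trigger the stop at epoch 2, batch 2; with a larger patience the function falls through and returns None.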