运行时错误:为BertModel加载state_dict时出现错误。

3

我使用hugging face transformer库微调一款BERT模型,并在云GPU上进行训练。然后像下面这样保存模型和分词器:

model.save_pretrained('/saved_model/')
torch.save(best_model.state_dict(), '/saved_model/model')
tokenizer.save_pretrained('/saved_model/')

我在电脑上下载了 saved_model 目录。然后我在电脑上按照以下方式加载模型/令牌器。
import torch
from transformers import *
tokenizer = BertTokenizer.from_pretrained('./saved_model/')
config = BertConfig('./saved_model/config.json')
model = BertModel(config)
model.load_state_dict(torch.load('./saved_model/pytorch_model.bin', map_location=torch.device('cpu')))
model.eval()

但是在model.load_state_dict这行代码中,它会抛出下面的错误:

RuntimeError: Error(s) in loading state_dict for BertModel:
    Missing key(s) in state_dict:

它列出了一堆在state_dict中显然缺失的键。

我是pytorch的新手,不确定发生了什么。很可能我没有正确保存模型。

请提供建议。

1个回答

3
如您所知,PyTorch模块的state_dict是一个有序字典。当您尝试从state_dict加载模块的权重时,会出现缺少键的错误提示,这意味着state_dict不包含这些键。在这种情况下,我建议您采取以下行动:
  1. 检查state_dict中存在哪些键。保存部分键看起来是不可能的。
  2. 确保加载了正确的配置。否则,如果您训练的BertModel和要加载权重的新BertModel不同,则会出现此错误。
  3. 最后,在保存模型时,请确保将所有层的参数保存在文件中。语句torch.save(best_model.state_dict(), '/saved_model/model')在我看来是正确的,但请确保best_model.state_dict()包含所有预期的键。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接