如何加载部分预训练的PyTorch模型?

11

我正在尝试在一个句子分类任务中运行一个pytorch模型。由于我正在处理医疗笔记,因此我正在使用ClinicalBert (https://github.com/kexinhuang12345/clinicalBERT),并希望使用其预训练权重。不幸的是,ClinicalBert模型仅将文本分类为1个二进制标签,而我有281个二进制标签。因此,我正在尝试实现这个代码 https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb 其中bert后的最终分类器为281。

如何从ClinicalBert模型中加载预训练的Bert权重而不加载分类权重?

天真地尝试从预训练的ClinicalBert权重加载权重时,我会得到以下错误:

size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).

我目前尝试使用pytorch_pretrained_bert软件包中的from_pretrained函数并弹出分类器的权重和偏置,代码如下:

def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
    ...
    if state_dict is None:
        weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
        state_dict = torch.load(weights_path, map_location='cpu')
    state_dict.pop('classifier.weight')
    state_dict.pop('classifier.bias')
    old_keys = []
    new_keys = []
    ...

我收到了以下错误信息:

INFO - modeling_diagnosis - BertForMultiLabelSequenceClassification的权重没有从预训练模型中初始化: ['classifier.weight', 'classifier.bias']

最终,我希望从clinicalBert预训练权重中加载bert嵌入,并随机初始化顶部分类器的权重。

1个回答

8

在加载之前从状态字典中删除键是一个好的开始。假设你正在使用nn.Module.load_state_dict来加载预训练的权重,那么你还需要将strict=False的参数设置为避免由于未知或丢失的键而产生错误。这将忽略在模型中不存在的状态字典中的条目(意外的键),更重要的是对于你来说,它将保留缺失的条目及其默认初始化(缺失的键)。为了安全起见,你可以检查该方法的返回值以验证所涉及的权重是否属于缺失的键,并且没有任何意外的键。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接