RuntimeError: Error(s) in loading state_dict

I have the following PyTorch model:
import math
from abc import abstractmethod

import torch.nn as nn


class AlexNet3D(nn.Module):
    @abstractmethod
    def get_head(self):
        pass

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.features = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=0),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),

            nn.Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=0),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),

            nn.Conv3d(128, 192, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(192),
            nn.ReLU(inplace=True),

            nn.Conv3d(192, 192, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(192),
            nn.ReLU(inplace=True),

            nn.Conv3d(192, 128, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),
        )

        self.classifier = self.get_head()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        xp = self.features(x)
        x = xp.view(xp.size(0), -1)
        x = self.classifier(x)
        return [x, xp]


class AlexNet3DDropoutRegression(AlexNet3D):
    def get_head(self):
        return nn.Sequential(nn.Dropout(),
                             nn.Linear(self.input_size, 64),
                             nn.ReLU(inplace=True),
                             nn.Dropout(),
                             nn.Linear(64, 1),
                             )

I initialize the model like this:
def init_model(self):
    model = AlexNet3DDropoutRegression(4608)
    if self.use_cuda:
        log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.to(self.device)
    return model

After training, I save the model like this:

    torch.save(self.model.state_dict(), self.cli_args.model_save_location)

I then try to load the saved model:

import torch
from reprex.models import AlexNet3DDropoutRegression


model_save_location = "/home/feczk001/shared/data/AlexNet/LoesScoring/loes_scoring_01.pt"

model = AlexNet3DDropoutRegression(4608)
model.load_state_dict(torch.load(model_save_location,
                                 map_location='cpu'))

But I get the following error:

RuntimeError: Error(s) in loading state_dict for AlexNet3DDropoutRegression:
    Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.4.weight", "features.4.bias", "features.5.weight", "features.5.bias", "features.5.running_mean", "features.5.running_var", "features.8.weight", "features.8.bias", "features.9.weight", "features.9.bias", "features.9.running_mean", "features.9.running_var", "features.11.weight", "features.11.bias", "features.12.weight", "features.12.bias", "features.12.running_mean", "features.12.running_var", "features.14.weight", "features.14.bias", "features.15.weight", "features.15.bias", "features.15.running_mean", "features.15.running_var", "classifier.1.weight", "classifier.1.bias", "classifier.4.weight", "classifier.4.bias". 
    Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.1.weight", "module.features.1.bias", "module.features.1.running_mean", "module.features.1.running_var", "module.features.1.num_batches_tracked", "module.features.4.weight", "module.features.4.bias", "module.features.5.weight", "module.features.5.bias", "module.features.5.running_mean", "module.features.5.running_var", "module.features.5.num_batches_tracked", "module.features.8.weight", "module.features.8.bias", "module.features.9.weight", "module.features.9.bias", "module.features.9.running_mean", "module.features.9.running_var", "module.features.9.num_batches_tracked", "module.features.11.weight", "module.features.11.bias", "module.features.12.weight", "module.features.12.bias", "module.features.12.running_mean", "module.features.12.running_var", "module.features.12.num_batches_tracked", "module.features.14.weight", "module.features.14.bias", "module.features.15.weight", "module.features.15.bias", "module.features.15.running_mean", "module.features.15.running_var", "module.features.15.num_batches_tracked", "module.classifier.1.weight", "module.classifier.1.bias", "module.classifier.4.weight", "module.classifier.4.bias". 

What is going wrong here?

3 Answers

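The problem is that you trained the model with DataParallel and are then trying to reload its weights into a non-parallel network. DataParallel is a wrapper class: it stores the original model (a torch.nn.Module object) as an attribute of the DataParallel object named module, so every key in its state_dict starts with "module.". This issue has already been discussed on the PyTorch forums, Stack Overflow, and GitHub, so I won't repeat the details here, but you can fix it in one of the following ways: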

  1. Save and load the model as a DataParallel object. This approach may break down when you want to use the model for inference.

  2. Save the state_dict of the DataParallel object's module:

    # Save the state dict of the module wrapped by the DataParallel object
    torch.save(model.module.state_dict(), path)

    # ... later ...

    # Reload the weights into a non-parallel model
    model.load_state_dict(torch.load(path))
    

Here is a simple example:

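import torch
import torch.nn as nn
from reprex.models import AlexNet3DDropoutRegression

model = AlexNet3DDropoutRegression(4608)  # on CPU
model = nn.DataParallel(model)
model = model.to("cuda")  # DataParallel object on GPU(s)

# Save only the wrapped module's weights, so the keys have no "module." prefix
torch.save(model.module.state_dict(), "example_path.pt")

del model
model = AlexNet3DDropoutRegression(4608)

ret = model.load_state_dict(torch.load("example_path.pt"))
print(ret)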

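Output:

<All keys matched successfully>

  3. More usefully, if you have already saved the state_dict and need to reload it, you can load the DataParallel model's state_dict, remap the key names to drop the "module." prefix, and then load the remapped state_dict. Something like this: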
incompatible_state_dict = torch.load("DataParallel_save_file.pt")
state_dict = {}
for key in incompatible_state_dict:
    # e.g. "module.features.0.weight" -> "features.0.weight"
    state_dict[key.split("module.")[-1]] = incompatible_state_dict[key]

ret = model.load_state_dict(state_dict)
print(ret)

Output:

<All keys matched successfully>
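For option 1, here is a minimal sketch of the loading side. It wraps the freshly built model in nn.DataParallel before calling load_state_dict, so the saved "module." keys line up, and then takes .module to get a plain network back. It assumes a small nn.Sequential (built by the hypothetical make_model) standing in for AlexNet3DDropoutRegression(4608), and an in-memory io.BytesIO buffer standing in for the checkpoint file. On a machine without GPUs, nn.DataParallel simply holds the module, so this also runs on CPU.

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in for AlexNet3DDropoutRegression(4608).
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))


# Simulate a checkpoint saved from a DataParallel-wrapped model ("module." keys).
buffer = io.BytesIO()
torch.save(nn.DataParallel(make_model()).state_dict(), buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")

# Wrap the fresh model the same way so the prefixed keys match.
wrapped = nn.DataParallel(make_model())
result = wrapped.load_state_dict(checkpoint)
print(result)  # <All keys matched successfully>

# Unwrap to get a plain nn.Module for inference or for re-saving without the prefix.
model = wrapped.module
model.eval()
```

Once you have the unwrapped module, calling torch.save(model.state_dict(), ...) produces a checkpoint that loads directly into an unwrapped model.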

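I'm getting an "Unexpected argument" warning for the map_location parameter in torch.save(). - Paul Reiners
Try removing that argument and see whether you still get an error. I believe you don't need it. - DerekG
I tried removing that argument, but then I'm back to the original error. - Paul Reiners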
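I was able to save and load the model weights using your code above plus the additional lines listed in this answer. The key point is that if your model is wrapped in a DataParallel object, you need to use model.module.state_dict() to access the parameters; otherwise, just use model.state_dict(). So, depending on whether you save and load the checkpoint as a plain nn.Module or as a DataParallel object, you need to use the matching attribute to access the model weights. - DerekG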
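Building on that comment, one way to avoid the mismatch at save time is to unwrap the model before calling state_dict(), so the checkpoint never contains the prefix. This is a sketch under stated assumptions: unwrap and save_weights are hypothetical helper names, nn.Linear stands in for the real model, and an io.BytesIO buffer stands in for self.cli_args.model_save_location.

```python
import io

import torch
import torch.nn as nn


def unwrap(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the inner network if `model` is a
    DataParallel or DistributedDataParallel wrapper, else `model` itself."""
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


def save_weights(model: nn.Module, destination) -> None:
    # Hypothetical helper: always save un-prefixed keys, wrapped or not.
    torch.save(unwrap(model).state_dict(), destination)


# Usage: save from a wrapped model, load into a plain one.
net = nn.Linear(3, 2)
wrapped = nn.DataParallel(net)

buffer = io.BytesIO()
save_weights(wrapped, buffer)
buffer.seek(0)

fresh = nn.Linear(3, 2)
result = fresh.load_state_dict(torch.load(buffer, map_location="cpu"))
print(result)  # <All keys matched successfully>
```

In the question's training code, this would mean replacing torch.save(self.model.state_dict(), ...) with save_weights(self.model, self.cli_args.model_save_location), which works whether or not init_model wrapped the model in DataParallel.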


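nn.DataParallel is a wrapper class that adds a "module." prefix to every key in the state dict. That is why the unexpected keys start with module.features and module.classifier. To fix this, you just need to remove the module. prefix when loading the model's state_dict.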

import torch
from reprex.models import AlexNet3DDropoutRegression

model = AlexNet3DDropoutRegression(4608)
model_save_location = "/home/feczk001/shared/data/AlexNet/LoesScoring/loes_scoring_01.pt"

state_dict = torch.load(model_save_location, map_location='cpu')
# Strip the "module." prefix that DataParallel added to every key
model.load_state_dict({k.replace("module.", ""): v for k, v in state_dict.items()})
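One caveat with str.replace (and with the split() variant in the first answer) is that it acts on "module." anywhere in a key, not just at the start. That is harmless for this model, but a submodule with a hypothetical name such as attention_module would have attention_module.weight turned into attention_weight. Below is a sketch that strips only a leading prefix. It assumes Python 3.9+ for str.removeprefix; strip_prefix is a hypothetical helper, and plain strings stand in for tensors so it runs without PyTorch.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: return a copy of `state_dict` with `prefix`
    removed from the start of each key.

    Keys that do not start with `prefix` are kept unchanged, and `prefix`
    appearing later in a key is left alone (unlike str.replace).
    """
    return OrderedDict(
        (key.removeprefix(prefix), value) for key, value in state_dict.items()
    )


# Plain strings stand in for tensors here.
saved = {
    "module.features.0.weight": "w0",
    "module.attention_module.weight": "w1",
}
print(list(strip_prefix(saved)))
# ['features.0.weight', 'attention_module.weight']
```

With the question's variables, this would be used as model.load_state_dict(strip_prefix(torch.load(model_save_location, map_location="cpu"))).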

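The problem is that you are loading a state dict saved from a model trained with DataParallel into a new model that does not use DataParallel. When you use DataParallel in PyTorch, every key gets the prefix module., so removing that prefix solves the problem. Unless you want to wrap the new model in DataParallel when you initialize it, it is best to simply strip the module prefix.
The following snippet does this: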
model = AlexNet3DDropoutRegression(4608)
state_dict = torch.load(model_save_location, map_location='cpu')
new_state_dict = {}
for key in state_dict.keys():
    new_key = key.replace("module.", "")
    new_state_dict[new_key] = state_dict[key]
model.load_state_dict(new_state_dict)
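Recent PyTorch versions also include a utility that does the same prefix-only stripping: torch.nn.modules.utils.consume_prefix_in_state_dict_if_present. Note that it is not present in older releases, so check your version before relying on it. It edits the dictionary in place and only touches keys that start with the given prefix. Here is a minimal sketch, with a small nn.Sequential standing in for the real model and an io.BytesIO buffer standing in for the checkpoint file:

```python
import io

import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Stand-in model and a checkpoint saved from its DataParallel wrapper.
model = nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2))
buffer = io.BytesIO()
torch.save(nn.DataParallel(model).state_dict(), buffer)
buffer.seek(0)

state_dict = torch.load(buffer, map_location="cpu")
# In place: "module.0.weight" -> "0.weight", "module.1.running_mean" -> "1.running_mean", ...
consume_prefix_in_state_dict_if_present(state_dict, "module.")

fresh = nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2))
result = fresh.load_state_dict(state_dict)
print(result)  # <All keys matched successfully>
```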
