从 PyTorch lightning 模型中检索 PyTorch 模型

5

我已经训练了一个类似于这样的PyTorch Lightning模型:

In [16]: MLP
Out[16]:
DecoderMLP(
  (loss): RMSE()
  (logging_metrics): ModuleList(
    (0): SMAPE()
    (1): MAE()
    (2): RMSE()
    (3): MAPE()
    (4): MASE()
  )
  (input_embeddings): MultiEmbedding(
    (embeddings): ModuleDict(
      (LCLid): Embedding(5, 4)
      (sun): Embedding(5, 4)
      (day_of_week): Embedding(7, 5)
      (month): Embedding(12, 6)
      (year): Embedding(3, 3)
      (holidays): Embedding(2, 1)
      (BusinessDay): Embedding(2, 1)
      (day): Embedding(31, 11)
      (hour): Embedding(24, 9)
    )
  )
  (mlp): FullyConnectedModule(
    (sequential): Sequential(
      (0): Linear(in_features=60, out_features=435, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.13371112461182535, inplace=False)
      (3): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (4): Linear(in_features=435, out_features=435, bias=True)
      (5): ReLU()
      (6): Dropout(p=0.13371112461182535, inplace=False)
      (7): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (8): Linear(in_features=435, out_features=435, bias=True)
      (9): ReLU()
      (10): Dropout(p=0.13371112461182535, inplace=False)
      (11): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (12): Linear(in_features=435, out_features=435, bias=True)
      (13): ReLU()
      (14): Dropout(p=0.13371112461182535, inplace=False)
      (15): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (16): Linear(in_features=435, out_features=435, bias=True)
      (17): ReLU()
      (18): Dropout(p=0.13371112461182535, inplace=False)
      (19): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (20): Linear(in_features=435, out_features=435, bias=True)
      (21): ReLU()
      (22): Dropout(p=0.13371112461182535, inplace=False)
      (23): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (24): Linear(in_features=435, out_features=435, bias=True)
      (25): ReLU()
      (26): Dropout(p=0.13371112461182535, inplace=False)
      (27): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (28): Linear(in_features=435, out_features=435, bias=True)
      (29): ReLU()
      (30): Dropout(p=0.13371112461182535, inplace=False)
      (31): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (32): Linear(in_features=435, out_features=435, bias=True)
      (33): ReLU()
      (34): Dropout(p=0.13371112461182535, inplace=False)
      (35): LayerNorm((435,), eps=1e-05, elementwise_affine=True)
      (36): Linear(in_features=435, out_features=1, bias=True)
    )
  )
)

我需要相应的PyTorch模型来在我的其他应用程序中使用。

有没有简单的方法可以做到这一点?

我想保存检查点,但我不知道如何保存。

请帮忙一下,谢谢。


您是想简单地加载Lightning检查点以使用训练好的模型,还是需要在没有Lightning依赖的情况下使用训练好的模型? - Matthew R.
@MatthewR。我想在没有Lightning依赖的情况下使用训练好的模型。 - lalaland
1个回答

5
你可以手动保存LightningModuletorch.nn.Module的权重。像这样:
trainer.fit(model, trainloader, valloader)

torch.save(
    model.input_embeddings.state_dict(),
    "input_embeddings.pt"
)
torch.save(model.mlp.state_dict(), "mlp.pt")

然后可以在不需要Lightning的情况下加载:
# create the "blank" networks like they
# were created in the Lightning Module
input_embeddings = MultiEmbedding(...)
mlp = FullyConnectedModule(...)

# Load the models for inference
input_embeddings.load_state_dict(
    torch.load("input_embeddings.pt")
)
input_embeddings.eval()

mlp.load_state_dict(
    torch.load("mlp.pt")
)
mlp.eval()

如需了解有关保存和加载 PyTorch 模块的更多信息,请参阅 PyTorch 文档中的 Saving and Loading Models: Saving & Loading Model for Inference

由于 Lightning 自动将检查点保存到磁盘(如果使用默认的 Tensorboard 记录器,则请检查 lightning_logs 文件夹),因此您还可以加载预训练的 LightningModule,然后保存状态字典,而无需重复所有训练。与先前的代码中调用 trainer.fit 不同,尝试:

model = DecoderMLP.load_from_checkpoint("path/to/checkpoint.ckpt")

谢谢您的帮助。问题是Lightning上的模型并不是我实现的。另外,我有很多模型。MLP是较简单的一种,但我有一些像TFT这样的庞大模型,我没有时间去实现它们。是否有一种方法可以在Lightning中解析架构,然后在torch中实现它呢?希望这有意义。 - lalaland
1
是的,通过对“LightningModule”属性和参数进行某种迭代可能是可能的,但对于大型模型来说非常困难。不幸的是,考虑到您的用例,似乎唯一简单的解决方案是保留Lightning依赖项并以这种方式加载/使用模型。是否有避免使用Lightning的特定原因? - Matthew R.
好的,谢谢你的回答。我很感激。 - lalaland

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接