Loading a checkpoint trained with PyTorch Lightning

I am using PyTorch Lightning version 1.4.0 and have defined the following class for the dataset:

import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader

class CustomTrainDataset(Dataset):
    '''
    Custom PyTorch Dataset for training
    
    Args:
        data (pd.DataFrame) - DF containing product info (and maybe also ratings)
        all_itemIds (list) - Python3 list containing all Item IDs
    '''
    
    def __init__(self, data, all_itemIds):
        self.users, self.items, self.labels = self.get_dataset(data, all_itemIds)
    
    def __len__(self):
        return len(self.users)
  
    def __getitem__(self, idx):
        return self.users[idx], self.items[idx], self.labels[idx]
    
    def get_dataset(self, data, all_itemIds):
        users, items, labels = [], [], []
        # Positive (user, item) interactions observed in the training data
        user_item_set = set(zip(data['CustomerID'], data['ItemCode']))

        # For every positive interaction, sample 7 items the user never interacted with
        num_negatives = 7
        for u, i in user_item_set:
            users.append(u)
            items.append(i)
            labels.append(1)
            for _ in range(num_negatives):
                negative_item = np.random.choice(all_itemIds)
                # Re-draw if the sampled item is actually a positive interaction
                while (u, negative_item) in user_item_set:
                    negative_item = np.random.choice(all_itemIds)
                users.append(u)
                items.append(negative_item)
                labels.append(0)

        return torch.tensor(users), torch.tensor(items), torch.tensor(labels)

Then the PL (PyTorch Lightning) class:

class NCF(pl.LightningModule):
    '''
    Neural Collaborative Filtering (NCF)
    
    Args:
        num_users (int): Number of unique users
        num_items (int): Number of unique items
        data (pd.DataFrame): Dataframe containing the food ratings for training
        all_itemIds (list): List containing all itemIds (train + test)
    '''
    
    def __init__(self, num_users, num_items, data, all_itemIds):
        super().__init__()
        self.user_embedding = nn.Embedding(num_embeddings = num_users, embedding_dim = 8)
        self.item_embedding = nn.Embedding(num_embeddings = num_items, embedding_dim = 8)
        self.fc1 = nn.Linear(in_features = 16, out_features = 64)
        self.fc2 = nn.Linear(in_features = 64, out_features = 64)
        self.fc3 = nn.Linear(in_features = 64, out_features = 32)
        self.output = nn.Linear(in_features = 32, out_features = 1)
        self.data = data
        self.all_itemIds = all_itemIds
        
    def forward(self, user_input, item_input):
        
        # Pass through embedding layers
        user_embedded = self.user_embedding(user_input)
        item_embedded = self.item_embedding(item_input)

        # Concat the two embedding layers
        vector = torch.cat([user_embedded, item_embedded], dim = -1)

        # Pass through dense layer
        vector = nn.ReLU()(self.fc1(vector))
        vector = nn.ReLU()(self.fc2(vector))
        vector = nn.ReLU()(self.fc3(vector))

        # Output layer
        pred = nn.Sigmoid()(self.output(vector))

        return pred
    
    def training_step(self, batch, batch_idx):
        user_input, item_input, labels = batch
        predicted_labels = self(user_input, item_input)
        loss = nn.BCELoss()(predicted_labels, labels.view(-1, 1).float())
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    def train_dataloader(self):
        return DataLoader(
            CustomTrainDataset(
                self.data, self.all_itemIds
            ),
            batch_size = 32, num_workers = 2
            # Google Colab's suggested max number of workers on the current
            # system is 2, not 4.
        )

print(f"num_users = {num_users}, num_items = {num_items} & all_itemIds = {len(all_itemIds)}")
# num_users = 12958, num_items = 511238 & all_itemIds = 9114

# Initialize NCF model-
model = NCF(num_users, num_items, train_ratings, all_itemIds)

trainer = pl.Trainer(
    max_epochs = 75, gpus = 1,
    # max_epochs = 5,
    reload_dataloaders_every_n_epochs = 1,
    # reload_dataloaders_every_epoch = True,   # deprecated!
    progress_bar_refresh_rate = 50,
    logger = False, checkpoint_callback = False)

trainer.fit(model)

# Save trained model as a checkpoint-
trainer.save_checkpoint("NCF_Trained.ckpt")

To load the saved checkpoint, I tried the following:
trained_model = NCF.load_from_checkpoint(
    "NCF_Trained.ckpt", num_users = num_users,
    num_items = train_ratings, data = train_ratings,
    all_itemIds = all_itemIds)


and also:

trained_model = NCF(num_users, num_items, train_ratings, all_itemIds).load_from_checkpoint(checkpoint_path = "NCF_Trained.ckpt")

But neither of these seems to work. How can I load the saved checkpoint file?

Thanks!


What exactly do you mean by "doesn't seem to work"? What do you want to do after .load_from_checkpoint(...)? - ayandas
Has this been resolved? @Arun - ayandas
Not yet @AyanDas - Arun
3 Answers

Add this line to your __init__ method:
self.save_hyperparameters(logger=False)

Then call:
trained_model = NCF.load_from_checkpoint("NCF_Trained.ckpt")
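
Applied to the NCF class from the question, the change is a one-liner at the top of __init__ (a minimal sketch; note that save_hyperparameters() records all __init__ arguments, including the data DataFrame, in the checkpoint, which can make the file large):

class NCF(pl.LightningModule):
    def __init__(self, num_users, num_items, data, all_itemIds):
        super().__init__()
        # Record the __init__ arguments in self.hparams and in every saved
        # checkpoint, so load_from_checkpoint() can re-instantiate the model
        # without them being passed again.
        self.save_hyperparameters(logger=False)
        # ... rest of __init__ unchanged ...

The checkpoint has to be saved after this change is in place; a checkpoint written before it will not contain the hyperparameters.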


As shown in the documentation, load_from_checkpoint is the primary way to load weights in PyTorch Lightning, and it automatically restores the hyperparameters that were used for training. You therefore do not need to pass any arguments, unless you want to override existing ones. My suggestion is to try trained_model = NCF.load_from_checkpoint("NCF_Trained.ckpt").
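
If the hyperparameters were saved, individual arguments can still be overridden at load time; and even without saved hyperparameters, the weights can be extracted from the checkpoint by hand (a hedged sketch using the names from the question; PyTorch Lightning checkpoints keep the model weights under the "state_dict" key):

# Option 1: load the checkpoint, overriding / supplying some arguments
trained_model = NCF.load_from_checkpoint(
    "NCF_Trained.ckpt",
    data = train_ratings,
    all_itemIds = all_itemIds,
)

# Option 2: build the model yourself and load only the weights
model = NCF(num_users, num_items, train_ratings, all_itemIds)
checkpoint = torch.load("NCF_Trained.ckpt", map_location = "cpu")
model.load_state_dict(checkpoint["state_dict"])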


In my case, setting the model to evaluation mode via model.eval() was critical; otherwise it produced wrong results.
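
For inference, that typically looks like the sketch below (the user and item indices are made up for illustration; note that the NCF model above has no dropout or batch-norm layers, so eval() changes nothing for it specifically, but it does matter for models that have them, and torch.no_grad() skips needless autograd bookkeeping either way):

trained_model = NCF.load_from_checkpoint("NCF_Trained.ckpt")
trained_model.eval()                      # switch dropout / batch-norm to eval behaviour
with torch.no_grad():                     # no gradients needed at inference time
    users = torch.tensor([0, 1, 2])       # hypothetical user indices
    items = torch.tensor([5, 9, 3])       # hypothetical item indices
    scores = trained_model(users, items)  # predicted interaction probabilities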
