PyTorch: Dropout (?) causes different model convergence for training+validation vs. training alone

We are facing a very strange problem. We test exactly the same model in two different "execution" settings. In the first case, given a certain number of epochs, we train on mini-batches for one epoch and then, following the same criterion, test on the validation set; then we move on to the next epoch. Naturally, before every training epoch we call model.train(), and before validation we switch to model.eval().

We then take the very same model (same init, same dataset, same epochs, etc.) and simply train it, without validation after each epoch.

Looking only at training-set performance, the two procedures diverge and produce very different metrics (loss, accuracy, and so on), even though we fix all the seeds. Specifically, the training-only procedure performs worse.

We have also noticed the following:

- It is not a reproducibility issue, because multiple runs of the same procedure produce exactly the same results (this is intentional);
- If we remove the dropout, the problem disappears;
- The BatchNorm1d layers, which also behave differently between train and eval, seem to work correctly;
- The problem persists if we move from TPUs to CPUs. We are trying PyTorch 1.6, the PyTorch nightly build, and XLA 1.6.

We have spent a whole day on this issue (and no, we cannot avoid using dropout). Does anyone have an idea of how to resolve this?

Thank you very much!

P.S. Here is the code used for training on the CPU.
import random
import time

import numpy as np
import torch
import torch.nn as nn


def sigmoid(x):
    # numerically equivalent to torch.sigmoid(x)
    return 1 / (1 + torch.exp(-x))


def _run(model, EPOCHS, training_data_in, validation_data_in=None):
    
    def train_fn(train_dataloader, model, optimizer, criterion):

        running_loss = 0.
        running_accuracy = 0.
        running_tp = 0.
        running_tn = 0.
        running_fp = 0.
        running_fn = 0.
        
        model.train()

        for batch_idx, (ecg, spo2, labels) in enumerate(train_dataloader, 1):

            optimizer.zero_grad() 
                
            outputs = model(ecg)

            loss = criterion(outputs, labels)
                        
            loss.backward() # calculate the gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step() # update the network weights
                                                
            running_loss += loss.item()
            predicted = torch.round(sigmoid(outputs.data)) # here determining the sigmoid, not included in the model
            
            running_accuracy += (predicted == labels).sum().item() / labels.size(0)   
            
            fp = ((predicted - labels) == 1.).sum().item() 
            fn = ((predicted - labels) == -1.).sum().item()
            tp = ((predicted + labels) == 2.).sum().item()
            tn = ((predicted + labels) == 0.).sum().item()
            running_tp += tp
            running_fp += fp
            running_tn += tn
            running_fn += fn
            
        retval = {'loss':running_loss / batch_idx,
                'accuracy':running_accuracy / batch_idx,
                'tp':running_tp,
                'tn':running_tn,
                'fp':running_fp,
                'fn':running_fn
                }
            
        return retval
            

        
    def valid_fn(valid_dataloader, model, criterion):

        running_loss = 0.
        running_accuracy = 0.
        running_tp = 0.
        running_tn = 0.
        running_fp = 0.
        running_fn = 0.

        model.eval()
        
        for batch_idx, (ecg, spo2, labels) in enumerate(valid_dataloader, 1):

            outputs = model(ecg)

            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            predicted = torch.round(sigmoid(outputs.data)) # here determining the sigmoid, not included in the model

            running_accuracy += (predicted == labels).sum().item() / labels.size(0)  
            
            fp = ((predicted - labels) == 1.).sum().item()
            fn = ((predicted - labels) == -1.).sum().item()
            tp = ((predicted + labels) == 2.).sum().item()
            tn = ((predicted + labels) == 0.).sum().item()
            running_tp += tp
            running_fp += fp
            running_tn += tn
            running_fn += fn
            
        retval = {'loss':running_loss / batch_idx,
                'accuracy':running_accuracy / batch_idx,
                'tp':running_tp,
                'tn':running_tn,
                'fp':running_fp,
                'fn':running_fn
                }
            
        return retval
    
    
    
    # Defining data loaders

    train_dataloader = torch.utils.data.DataLoader(training_data_in, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
    
    if validation_data_in is not None:
        validation_dataloader = torch.utils.data.DataLoader(validation_data_in, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)


    # Defining the loss function
    criterion = nn.BCEWithLogitsLoss()
    
    
    # Defining the optimizer
    import torch.optim as optim
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, amsgrad=False, eps=1e-07) 


    # Training code
    
    metrics_history = {"loss":[], "accuracy":[], "precision":[], "recall":[], "f1":[], "specificity":[], "accuracy_bis":[], "tp":[], "tn":[], "fp":[], "fn":[],
                    "val_loss":[], "val_accuracy":[], "val_precision":[], "val_recall":[], "val_f1":[], "val_specificity":[], "val_accuracy_bis":[], "val_tp":[], "val_tn":[], "val_fp":[], "val_fn":[],}
    
    train_begin = time.time()
    for epoch in range(EPOCHS):
        start = time.time()

        print("EPOCH:", epoch+1)

        train_metrics = train_fn(train_dataloader=train_dataloader, 
                                model=model,
                                optimizer=optimizer, 
                                criterion=criterion)
        
        metrics_history["loss"].append(train_metrics["loss"])
        metrics_history["accuracy"].append(train_metrics["accuracy"])
        metrics_history["tp"].append(train_metrics["tp"])
        metrics_history["tn"].append(train_metrics["tn"])
        metrics_history["fp"].append(train_metrics["fp"])
        metrics_history["fn"].append(train_metrics["fn"])
        
        precision = train_metrics["tp"] / (train_metrics["tp"] + train_metrics["fp"]) if train_metrics["tp"] > 0 else 0
        recall = train_metrics["tp"] / (train_metrics["tp"] + train_metrics["fn"]) if train_metrics["tp"] > 0 else 0
        specificity = train_metrics["tn"] / (train_metrics["tn"] + train_metrics["fp"]) if train_metrics["tn"] > 0 else 0
        f1 = 2*precision*recall / (precision + recall) if precision*recall > 0 else 0
        metrics_history["precision"].append(precision)
        metrics_history["recall"].append(recall)
        metrics_history["f1"].append(f1)
        metrics_history["specificity"].append(specificity)
        
        
        
        if validation_data_in is not None:
            # Calculate the metrics on the validation data, in the same way as done for training
            with torch.no_grad(): # don't keep track of the info necessary to calculate the gradients

                val_metrics = valid_fn(valid_dataloader=validation_dataloader, 
                                    model=model,
                                    criterion=criterion)

                metrics_history["val_loss"].append(val_metrics["loss"])
                metrics_history["val_accuracy"].append(val_metrics["accuracy"])
                metrics_history["val_tp"].append(val_metrics["tp"])
                metrics_history["val_tn"].append(val_metrics["tn"])
                metrics_history["val_fp"].append(val_metrics["fp"])
                metrics_history["val_fn"].append(val_metrics["fn"])

                val_precision = val_metrics["tp"] / (val_metrics["tp"] + val_metrics["fp"]) if val_metrics["tp"] > 0 else 0
                val_recall = val_metrics["tp"] / (val_metrics["tp"] + val_metrics["fn"]) if val_metrics["tp"] > 0 else 0
                val_specificity = val_metrics["tn"] / (val_metrics["tn"] + val_metrics["fp"]) if val_metrics["tn"] > 0 else 0
                val_f1 = 2*val_precision*val_recall / (val_precision + val_recall) if val_precision*val_recall > 0 else 0
                metrics_history["val_precision"].append(val_precision)
                metrics_history["val_recall"].append(val_recall)
                metrics_history["val_f1"].append(val_f1)
                metrics_history["val_specificity"].append(val_specificity)


            print("  > Training/validation loss:", round(train_metrics['loss'], 4), round(val_metrics['loss'], 4))
            print("  > Training/validation accuracy:", round(train_metrics['accuracy'], 4), round(val_metrics['accuracy'], 4))
            print("  > Training/validation precision:", round(precision, 4), round(val_precision, 4))
            print("  > Training/validation recall:", round(recall, 4), round(val_recall, 4))
            print("  > Training/validation f1:", round(f1, 4), round(val_f1, 4))
            print("  > Training/validation specificity:", round(specificity, 4), round(val_specificity, 4))
        else:
            print("  > Training loss:", round(train_metrics['loss'], 4))
            print("  > Training accuracy:", round(train_metrics['accuracy'], 4))
            print("  > Training precision:", round(precision, 4))
            print("  > Training recall:", round(recall, 4))
            print("  > Training f1:", round(f1, 4))
            print("  > Training specificity:", round(specificity, 4))


        print("Completed in:", round(time.time() - start, 1), "seconds \n")

    print("Training completed in:", round((time.time()- train_begin)/60, 1), "minutes")    

    
    
    # Save the model weights
    torch.save(model.state_dict(), './nnet_model.pt')
    
    
    # Save the metrics history
    torch.save(metrics_history, 'training_history')

Here is the function that initializes the model and sets the seeds; it is called before every execution of the `_run` code above:

def reinit_model():
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    net = Net() # the model
    return net
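
For reference, the two experiments described above are presumably launched along these lines; BATCH_SIZE is read as a global inside _run, and the dataset names and epoch count below are placeholders, not taken from the original post:

BATCH_SIZE = 64                  # placeholder; _run reads this as a global
EPOCHS = 20                      # placeholder
training_ds = ...                # the training Dataset
validation_ds = ...              # the validation Dataset

# Run 1: training + validation every epoch
model = reinit_model()
_run(model, EPOCHS, training_ds, validation_data_in=validation_ds)

# Run 2: training only, same seeds, same data, same number of epochs
model = reinit_model()
_run(model, EPOCHS, training_ds)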

Even if you fix all the seeds, set up how your code deals with the seeds. - prosti
Although I don't think it is a seed-related problem, since each of the two different executions is reproducible on its own. - Andrea
To me this does not look like a weight-initialization issue either, since they say that executions are reproducible as long as you look at the same type of run (training-only or training+evaluation). - nsacco
1 Answer


OK, I have found the problem.

The problem is that, apparently, running the evaluation changes some of the random seed state, and this affects the training phase.
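
A quick way to confirm this on the CPU (my own sketch, not part of the original answer) is to snapshot the global RNG state around the validation call inside _run():

state_before = torch.get_rng_state()          # global CPU RNG state as a ByteTensor
with torch.no_grad():
    val_metrics = valid_fn(valid_dataloader=validation_dataloader,
                           model=model,
                           criterion=criterion)
state_after = torch.get_rng_state()
# If this prints True, the evaluation pass consumed random numbers, so the next
# training epoch starts from a different RNG state than in a training-only run.
print("RNG state changed by validation:", not torch.equal(state_before, state_after))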

So the solution is the following:

  • At the beginning of the _run() function, set all the random seed states to the desired value, e.g. 42, and then save them to disk.
  • At the beginning of the train_fn() function, read the seed states from disk and set them.
  • At the end of the train_fn() function, save the seed states back to disk.

For example, when running on TPUs with XLA, the following instructions have to be used (a plain-CPU sketch of the same pattern follows after this list):

  • At the beginning of the _run() function: xm.set_rng_state(42) and xm.save(xm.get_rng_state(), 'xm_seed')
  • At the beginning of the train_fn() function: xm.set_rng_state(torch.load('xm_seed'), device=device) (you can also print the seed state here with xm.master_print(xm.get_rng_state()) for verification)
  • At the end of the train_fn() function: xm.save(xm.get_rng_state(), 'xm_seed')
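
For the plain-CPU run, a minimal sketch of the same pattern, assuming the global torch RNG is the state that matters, would use torch.get_rng_state / torch.set_rng_state instead of the xm.* calls (the file name is arbitrary):

import torch

SEED_STATE_FILE = 'torch_rng_state.pt'   # arbitrary file name

# At the beginning of _run(): fix the seed and persist the starting RNG state.
torch.manual_seed(42)
torch.save(torch.get_rng_state(), SEED_STATE_FILE)

def train_fn(train_dataloader, model, optimizer, criterion):
    # At the beginning of train_fn(): restore the state saved at the end of the
    # previous training epoch, undoing whatever the validation pass consumed.
    torch.set_rng_state(torch.load(SEED_STATE_FILE))

    # ... original training loop from the question, unchanged ...

    # At the end of train_fn(): save the state reached by training alone.
    torch.save(torch.get_rng_state(), SEED_STATE_FILE)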
