当模型具有多个输出时,ModelCheckpoint如何监测值

3

我的模型有两个输出,我想要监控其中一个以保存我的模型。 以下是我的代码的一部分。TensorFlow 的版本为 2.0。

model = MobileNetBaseModel()()
model.compile(optimizer=tf.keras.optimizers.Adam(),
              metrics={"pitch_yaw_roll": "mae"},
              loss={"pitch_yaw_roll": compute_mse_loss, # or "mse"
                    "total_logits": compute_cross_entropy_loss(num_classes=num_classes)},
              loss_weights= {"pitch_yaw_roll":mse_weight, "total_logits":cross_entropy_weight})
file_path = os.path.join(checkpoint_path, "model.{epoch:2d}-{val_loss:.2f}.h5")
tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
                                   monitor="val_loss",
                                   verbose=1,
                                   save_freq=save_freq,
                                   save_best_only=True)

默认情况下,在ModelCheckpoint回调函数中,monitor='val_loss',我该如何选择我需要的内容?我想监测 {"pitch_yaw_roll": "mae"}

1
你想要达到什么目标?你只想保存具有最低“pitch_yaw_roll”值的时期吗? - bluesummers
是的,也许我想要每几个批次保存对应最低值的模型。正如我所描述的,在tf.keras.callbacks.ModelCheckpoint中,我只能选择monitor = val_loss吗?感谢您的帮助!@bluesummers - Pandas
2个回答

2

如果您希望ModelCheckpoint根据其他指标保存结果,请在.compile(metrics={...}, ...)度量字典中使用该指标的键。

例如,如果您想仅保存最佳的"pitch_yaw_roll"时代结果(最佳值为最小值),则应使用:

tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
                                   monitor="val_pitch_yaw_roll",
                                   verbose=1,
                                   mode="min",
                                   save_freq=save_freq,
                                   save_best_only=True)

如果选择 "pitch_yaw_roll" 而不是 "val_pitch_yaw_roll",它将根据训练损失而不是验证损失进行保存。

1
根据您的意思,我知道如何按损失保存,那么如果我想使用验证集的mae而不是验证集的损失呢?就像monitor="val_picth_yaw_roll_mae"一样? - Pandas
就像我写的那样,不要包括 mae,只使用 val_pitch_yaw_roll - 因为这个键指向 mae,所监测的是 mae - bluesummers

0

仅补充上面的评论,我认为您的检查点不起作用是因为要监视的值名称不正确。 通常,在这里的解决方案可能是查看您的拟合创建的历史记录。

history = model.fit(...)
pd.DataFrame(history.history)

在那里,你会找到应该在监控语句中使用的指标名称。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接