I have an issue with how ClearML logs checkpoints.
We have a training setup with pytorch-lightning + ClearML, where we use lightning.pytorch.callbacks.ModelCheckpoint for model checkpointing. I would like to use clearml.OutputModel objects to store the model configuration and weights. With the plain ModelCheckpoint callback and save_top_k=1 (keep only the best checkpoint, under a filename that changes per epoch), Lightning removes the old checkpoint (say epoch_001.ckpt) and stores the new one (epoch_002.ckpt) in its place. In ClearML, however, these two checkpoints show up as separate OutputModels, and our fileserver gets overloaded.
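To make the mismatch concrete, here is a toy simulation in plain Python. It does not touch Lightning or ClearML; the `simulate` function and the per-filename registry are hypothetical stand-ins, and the assumption that one model entry is registered per distinct checkpoint filename is based only on the behaviour described above.

```python
# Toy simulation only: no Lightning or ClearML involved.
# Assumption: the tracker registers one model entry per distinct checkpoint filename.

def simulate(val_losses):
    on_disk = set()      # what is kept locally when only the best checkpoint is saved
    registered = []      # hypothetical per-filename model registry
    single_slot = None   # one model entry whose weights are overwritten in place
    best = float("inf")
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            filename = f"epoch_{epoch:03d}.ckpt"
            on_disk = {filename}         # the previous best checkpoint is deleted
            registered.append(filename)  # but a new registry entry still appears
            single_slot = filename       # update-in-place keeps a single entry
    return on_disk, registered, single_slot


on_disk, registered, single_slot = simulate([0.9, 0.7, 0.8, 0.5])
print(on_disk)
print(registered)
print(single_slot)
```

The local directory ends up with one file, while the per-filename registry grows with every improvement; the single-slot variant is what I am after.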
I have written this simple extension of ModelCheckpoint that supports the single-checkpoint case, but I would rather have this built into ClearML. Custom code is cool, but it creates friction when you have to override lots of hidden functions:
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from clearml import Task, OutputModel


class ClearMLModelCheckpoint(ModelCheckpoint):
    """
    Callback that extends the functionality of the `ModelCheckpoint` callback
    for saving the best model during training using ClearML.

    Args:
        The same as `ModelCheckpoint`.

    Notes:
        - Currently only supports saving a single model.
    """

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_train_start(trainer, pl_module)
        task: Task = Task.current_task()
        # Our setup stores the run configuration as a task artifact.
        model_config = task.artifacts["configuration"].get()["model"]
        # A single OutputModel is created once and reused for every checkpoint.
        self.output_model = OutputModel(task, config_dict=model_config)

    def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
        super()._save_checkpoint(trainer, filepath)
        # Overwrite the weights of the same OutputModel instead of creating a new one.
        self.output_model.update_weights(filepath)
Has anyone been working on something similar, or found a good solution for model checkpointing that doesn't involve a parallel implementation of ModelCheckpoint?