I think the model state is just post training loop (not inside the loop), no?
trainer_state.json gets updated every time a "checkpoint" gets saved. I've got that set to once an epoch.
My testing indicates that if the training gets interrupted, I can resume training from a saved checkpoint folder that includes trainer_state.json. It uses the info to determine which data to skip, where to pick back up again, etc.
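To see what a checkpoint would resume from, you can peek into that file yourself. This is only a minimal sketch: the helper name `read_resume_point` is made up for illustration, and it assumes the JSON carries the `global_step` and `epoch` fields that `TrainerState` stores.

```python
import json
import os


def read_resume_point(checkpoint_dir):
    """Return (global_step, epoch) recorded in a checkpoint's trainer_state.json.

    Hypothetical helper, not part of transformers or ClearML.
    """
    state_path = os.path.join(checkpoint_dir, "trainer_state.json")
    with open(state_path, "r", encoding="utf-8") as f:
        state = json.load(f)
    # "global_step" and "epoch" are fields of TrainerState.
    return state["global_step"], state["epoch"]
```

The resume itself is then just `trainer.train(resume_from_checkpoint=checkpoint_dir)`, pointing at the same folder.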
Which is defined, it seems, here: https://github.com/huggingface/transformers/blob/040283170cd559b59b8eb37fe9fe8e99ff7edcbc/src/transformers/trainer_tf.py#L459
If you cannot change the "TrainerState" (i.e. inherit and pass it into the code),
you could also monkey-patch it, something like
` class OurTrainerState(TrainerState):
      @classmethod
      def load_from_json(cls, json_path: str):
          ...  # custom loading logic goes here
  trainer.state = OurTrainerState(**vars(trainer.state)) `
OK, I added
Task.current_task().upload_artifact(name='trainer_state', artifact_object=os.path.join(output_dir, "trainer_state.json"))
after this line:
And it seems to be working.
OK, neat! Any advice on how to edit the training loop to do that? Because the code I'm using doesn't offer easy access to the training loop, see here: https://github.com/huggingface/transformers/blob/040283170cd559b59b8eb37fe9fe8e99ff7edcbc/examples/pytorch/language-modeling/run_mlm.py#L469
trainer.train() just does the training loop automagically, and saves a checkpoint once in a while. When it saves a checkpoint, ClearML uploads all the other files. How can I hook into... whatever triggers that, and upload this file also?
specifically called here:
Maybe after this line add:
`Task.current_task().upload_artifact(name='state', artifact_object='trainer_state.json')` wdyt?
Generally, any torch.save(...) call is logged/uploaded by
ClearML automatically. Specifically in your case, I think the only missing one is trainer_state.json, which I assume is a plain JSON file, and I imagine is part of the Hugging Face framework. You can easily upload it as an additional artifact with
Could I use "register artifact"?
I think this is somewhat deprecated and we should probably replace it with something similar to what you mentioned (i.e. watch a file change).
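Just to illustrate the "watch a file change" idea, here is a rough sketch. It is not how ClearML does (or will do) it; the class name `FileChangeWatcher` is hypothetical, and it simply compares a content hash between calls.

```python
import hashlib


class FileChangeWatcher:
    """Report whether a file's content changed since the last check.

    Hypothetical helper for illustration, not a ClearML API.
    """

    def __init__(self, path):
        self.path = path
        self._last_digest = None

    def _digest(self):
        # A missing file is treated as "nothing to upload yet".
        try:
            with open(self.path, "rb") as f:
                return hashlib.sha256(f.read()).hexdigest()
        except FileNotFoundError:
            return None

    def changed(self):
        """Return True once per new content version, False otherwise."""
        digest = self._digest()
        if digest is None or digest == self._last_digest:
            return False
        self._last_digest = digest
        return True
```

Something polling `changed()` could then call `upload_artifact` each time it returns True.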
Right now the easiest way would be to manually upload
trainer_state.json at every checkpoint:
`Task.current_task().upload_artifact(name='state', artifact_object='trainer_state.json')`
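If editing the Trainer source is not an option, one way to run that upload on every checkpoint is a `TrainerCallback` with an `on_save` hook. This is a sketch under assumptions, not a tested recipe: `trainer_state_path` and `make_upload_callback` are made-up names, the artifact name is arbitrary, and it assumes the default `checkpoint-<global_step>` folder layout under `output_dir` (hyperparameter search runs use a different layout).

```python
import os


def trainer_state_path(output_dir, global_step):
    """Path where the Trainer is assumed to write trainer_state.json for a step."""
    return os.path.join(output_dir, f"checkpoint-{global_step}", "trainer_state.json")


def make_upload_callback(artifact_name="trainer_state"):
    """Build a callback that uploads trainer_state.json after each checkpoint save."""
    # Imported lazily so the path helper above stays usable without these packages.
    from clearml import Task
    from transformers import TrainerCallback

    class UploadTrainerStateCallback(TrainerCallback):
        def on_save(self, args, state, control, **kwargs):
            path = trainer_state_path(args.output_dir, state.global_step)
            task = Task.current_task()
            # Skip quietly if there is no active task or the file is not there.
            if task is not None and os.path.isfile(path):
                task.upload_artifact(name=artifact_name, artifact_object=path)

    return UploadTrainerStateCallback()
```

In run_mlm.py this would mean calling `trainer.add_callback(make_upload_callback())` before `trainer.train()`.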
Sorry, you are correct, this is where the JSON is created:
The other links are the functions calling it. Make sense?