ElegantKangaroo44 good question, that depends on where we store the score of the model itself. You can obviously parse the file name task.models['output'][-1].url and retrieve the score from it. You can also store it in the model name task.models['output'][-1].name
and you can put it as a general-purpose blob of text in what is currently model.config_text
(for convenience the model can parse JSON-like text, and then you can use model.config_dict).
It all depends how we store the metadata on the performance. You could actually retrieve it from, say, the val metric and deduce the epoch based on that.
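For illustration, a minimal sketch of the filename-parsing option, assuming the score is embedded in the name the way ignite does when a score_name is given (e.g. checkpoint_best_acc=0.9.pt, as comes up later in this thread); the helper name and regex are just an example, not Trains API:

```python
import re

def score_from_name(name):
    """Extract a float score from a checkpoint name like
    'checkpoint_best_acc=0.9.pt' (the suffix ignite appends when a
    score_name is given). Returns None if no score is embedded."""
    m = re.search(r'=([0-9]*\.?[0-9]+)\.pt$', name)
    return float(m.group(1)) if m else None

# with Trains you would call it on e.g. task.models['output'][-1].name
```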
I'd prefer to use config_dict, I think it's cleaner (as a workaround for metadata). However, since I'm using ignite, I think I have no way to actually do that, or at least I'm not aware of one.
But I think that whichever way one chooses, you will have to go through the N best models right after training and find the best one (because of the issue we're discussing on ignite).
I think the ideal would be one of these two:
1. Only store a single model_best. After training you just find the model with that name in task.models["output"]
2. Store as many model_best as you want (with n_saved from ignite). Every time a new best_model is saved, add a tag best, and remove the tag from the previous best models (not sure how straightforward that is though). So at the end one can just do best = task.get_model_best
After that you could also programmatically check against the accuracy of your production model (using metadata / JSON config) and decide what you want to do with the model.
what do you think?
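A rough sketch of the post-training "go through N models and pick the best" step, assuming each model's config blob (model.config_text / config_dict) carries the validation metric; the field names and file names here are made up for the example:

```python
import json

def pick_best(models, metric='val_acc'):
    """Given (name, config_json) pairs, return the name of the model
    whose JSON config blob holds the highest value for `metric`."""
    return max(models, key=lambda m: json.loads(m[1])[metric])[0]

# stand-in for iterating over task.models['output'] after training
candidates = [
    ('epoch_10.pt', '{"val_acc": 0.81, "epoch": 10}'),
    ('epoch_20.pt', '{"val_acc": 0.87, "epoch": 20}'),
    ('epoch_30.pt', '{"val_acc": 0.84, "epoch": 30}'),
]
```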
I'd prefer to use config_dict, I think it's cleaner
I'm definitely with you
Good news:
"new best_model is saved, add a tag best"
Already supported (you just can't see the tag, but it is there :))
My question is: what do you think would be the easiest interface to tell Trains (pre/post store) to tag/mark this model as the best so far? (BTW, obviously if we know it's not good, why do we bother storing it in the first place...)
Good news:
"new best_model is saved, add a tag best"
Already supported (you just can't see the tag, but it is there :))
Interesting! Could you point me to where the tagging happens? Also, by "see" do you mean the UI? I tried doing task.models["outputs"][-1].tag but there's no tag property (AttributeError: 'Model' object has no attribute 'tag'; it seems that only OutputModels have the tag property?)
I'll answer specifically for ignite on your question regarding the interface:
Doing something like passing the tag to TrainsSaver seems convenient. Same for the config / metadata etc.: TrainsSaver(output_uri, tag="best", metadata={})
(btw, obviously if we know it's not good, why do we bother to store it in the first place...)
During training you find a new best model every few epochs, so you save it. I agree there's no need to save all the best models, only one is enough (that's why I have n_saved=1 in my ignite checkpoints for best models). So you save a best_model, tag it, add metadata to it. Then you find a new best_model and overwrite the previous one.
But ignite gives the option to keep n_saved models, so you would need to remove the "best" tag from the previously saved best models and only add it to the last one.
I hope this makes sense 😄
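That tag bookkeeping could be sketched in plain Python like this (a dict stands in for the real Trains model-tag API, which this is not):

```python
def retag_best(tags_by_model, new_best):
    """Move the 'best' tag to new_best, stripping it from any model that
    previously carried it (a plain dict stands in for Trains model tags)."""
    for tags in tags_by_model.values():
        tags.discard('best')
    tags_by_model.setdefault(new_best, set()).add('best')

# example: an older best model currently holds the tag
model_tags = {'best_acc=0.81.pt': {'best'}, 'epoch_20.pt': set()}
retag_best(model_tags, 'best_acc=0.87.pt')
```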
task.models["outputs"][-1].tags
(plural, a list of strings) and yes I mean the UI 🙂
I get the n_saved part; what's missing for me is how you would tell the TrainsLogger/Trains that the current one is the best. Or are we assuming the last saved model is always the best? (In that case there is no need for a tag, you just take the last one in the list.)
If we are going with "I'm only saving the model if it is better than the previous checkpoint", then just always use the same name, i.e. "model.pt". Trains will override the entry and you will have a single output model (this is the same idea that was suggested with n_saved, basically "model_{:d}.pt".format(counter % n_saved))
You see what I mean by "telling"? The automagic will log the torch.save call, but we need a way to signal to it (before or after the call) that this is the best model.
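The model_{:d}.pt rotation mentioned above, spelled out in plain Python (no Trains involved); with n_saved=1 every save reuses the same name, so only one file and one output-model entry ever exist:

```python
def rotating_name(counter, n_saved=1):
    """Checkpoint filename cycling through n_saved slots, so names repeat
    and older files (and their output-model entries) get overwritten."""
    return 'model_{:d}.pt'.format(counter % n_saved)
```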
So, using ignite I do the following:

task.phases['valid'].add_event_handler(
    Events.EPOCH_COMPLETED(every=1),
    Checkpoint(
        to_save,
        TrainsSaver(output_uri=self.checkpoint_path),
        'best',
        n_saved=1,
        score_function=lambda x: task.phases['valid'].state.metrics[self.monitor]
            if self.monitor_mode == 'max'
            else -task.phases['valid'].state.metrics[self.monitor],
        score_name=self.monitor))
This means as you said: The last model saved with this checkpointing is the best model.
But since I passed a score_function, ignite will automatically append its value as a suffix, and I end up with: checkpoint_best_acc=0.9.pt. I guess it's also implied that when you use the score_function
parameter, you're saving your best model (why would you use a score otherwise?).
However, as we discussed in the issue in the ignite repo, there is no way at the moment to have a checkpoint that is simply named checkpoint_best
. So, if I am to use the ignite Checkpoint to save the models, I have no power to change the suffix as an end user. And when I use it together with Trains, I end up uploading to output_uri
all the best models ever saved, regardless of me defining n_saved=1
(because of the issue we're discussing in ignite)
Now on top of that I also have:

task.phases['train'].add_event_handler(
    Events.EPOCH_COMPLETED(every=self.save_freq),
    Checkpoint(
        to_save,
        TrainsSaver(output_uri=self.checkpoint_path),
        'epoch',
        n_saved=5,
        global_step_transform=global_step_from_engine(task.phases['train'])))
That saves checkpoints every 20 epochs as a backup (regardless of their accuracy).
That means I have no knowledge of what type of checkpoint the last torch.save call was for. It might be a best model, or it might be a backup saved every 20 epochs.
Based on all that discussion, I'm actually thinking that for the Trains and ignite integration there should maybe be a TrainsCheckpoint instead of a TrainsSaver, one that takes care of everything properly with respect to Trains. But I'm not sure if that's something the ignite people would want.
WDYT?
ElegantKangaroo44 I think TrainsCheckpoint would probably be the easiest solution. I mean, it wouldn't be a must, but another option to deepen the integration and allow us more flexibility.
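To make the idea concrete, here is a plain-Python sketch of the bookkeeping such a hypothetical TrainsCheckpoint might do: only save when the score improves, rotate names over n_saved slots, and move the "best" tag. The saver/tagger callables are stand-ins, not actual TrainsSaver or Trains tag APIs:

```python
class TrainsCheckpointSketch:
    """Hypothetical TrainsCheckpoint logic in plain Python. `saver` and
    `tagger` are stand-in callables, not real Trains/ignite APIs."""

    def __init__(self, saver, tagger, n_saved=1):
        self.saver, self.tagger = saver, tagger
        self.n_saved = n_saved
        self.best_score = float('-inf')
        self.counter = 0

    def __call__(self, state_dict, score):
        if score <= self.best_score:
            return None                    # not a new best: don't store it
        self.best_score = score
        name = 'model_{:d}.pt'.format(self.counter % self.n_saved)
        self.counter += 1
        self.saver(state_dict, name)       # would delegate to torch.save
        self.tagger(name, 'best')          # would move the 'best' tag
        return name
```

With n_saved=1 the same name is reused on every improvement, so Trains would keep a single output-model entry, matching the "always use the same name" suggestion above.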