Unanswered
Hi all,
I'm trying to save my model checkpoints during training, but I'm running into a confusing snag.
I'm using the Hugging Face Transformers architecture for an image-classification model, and I use their `Trainer` module to control training. In the training arguments, I have the …
Training function:
def training(checkpoint, image_processor):
    """Fine-tune an image-classification model and return its best checkpoint.

    Hyperparameters are read from the module-level ``config`` mapping
    (``eval_metric``, ``learning_rate``, ``train_batch_size``,
    ``gradient_steps``, ``eval_batch_size``, ``epochs``, ``warmup_ratio``,
    ``logging_steps``, ``save_total``).

    Args:
        checkpoint: Model identifier or path passed to
            ``AutoModelForImageClassification.from_pretrained``.
        image_processor: Processor handed to the ``Trainer`` via its
            ``tokenizer`` argument.

    Returns:
        ``trainer.state.best_model_checkpoint`` after training. Presumably this
        is the directory of the best-scoring checkpoint under ``output_dir``.
        TODO confirm it can be ``None`` when no evaluation ran.
    """
    data_test_train, labels, label_to_id, id_to_label = pre_process()

    # Exactly one closing paren, no trailing comma. A trailing comma here
    # would silently make `model` a 1-tuple instead of the model itself.
    model = AutoModelForImageClassification.from_pretrained(
        checkpoint,
        num_labels=len(labels),
        id2label=id_to_label,
        label2id=label_to_id,
    )

    # Resolve and load the metric once per training run rather than on every
    # evaluation callback invocation.
    metric_name = config.get("eval_metric")
    metric = evaluate.load(metric_name)

    def metrics(eval_pred):
        """Score one evaluation pass.

        ``eval_pred`` is unpacked into (logits, reference labels); the
        predicted class is the argmax over axis 1. Any metric other than
        ``"accuracy"`` is computed with ``average="weighted"`` so multi-class
        precision/recall/F1-style metrics reduce to a single number.
        """
        logits, references = eval_pred
        predictions = np.argmax(logits, axis=1)
        if metric_name == "accuracy":
            return metric.compute(predictions=predictions, references=references)
        return metric.compute(
            predictions=predictions, references=references, average="weighted"
        )

    data_collator = DefaultDataCollator()

    # Evaluation and saving both happen once per epoch. Per the Transformers
    # docs, matching strategies are required for load_best_model_at_end.
    training_args = TrainingArguments(
        output_dir="somefolder",
        remove_unused_columns=False,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.get("learning_rate"),
        per_device_train_batch_size=config.get("train_batch_size"),
        gradient_accumulation_steps=config.get("gradient_steps"),
        per_device_eval_batch_size=config.get("eval_batch_size"),
        num_train_epochs=config.get("epochs"),
        warmup_ratio=config.get("warmup_ratio"),
        logging_steps=config.get("logging_steps"),
        save_total_limit=config.get("save_total"),
        load_best_model_at_end=True,
        report_to="tensorboard",
        metric_for_best_model=metric_name,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model.to(device),
        args=training_args,
        data_collator=data_collator,
        train_dataset=data_test_train["train"],
        eval_dataset=data_test_train["test"],
        tokenizer=image_processor,
        compute_metrics=metrics,
    )

    print("training")
    trainer.train()
    print("save_model")
    trainer.save_model()
    best_model = trainer.state.best_model_checkpoint
    print(best_model)
    classifification_report(data_test_train["test"], model, best_model)
    return best_model
33 Views
0
Answers
3 months ago
3 months ago