[finetune_trainer] enhancements and fixes (#9042)
* trainer and finetune_trainer enhancements and fixes * add fallback default * move the fixing of incorrect keys back into finetune trainer * s/eval/val/ to match the split * trainer can now use a different prefix than eval_ for metrics * document new arg * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * use 'eval' as the default for metric_key_prefix * complete adjust var names + disambiguate * fix logger * add clarifying comment * add clarifying comment * style * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/trainer.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * complete removal of optional for metric_key_prefix * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@@ -119,6 +120,46 @@ class DataTrainingArguments:
|
||||
)
|
||||
|
||||
|
||||
def speed_metrics(split, start_time, num_samples):
|
||||
"""
|
||||
Measure and return speed performance metrics.
|
||||
|
||||
This function requires a time snapshot `start_time` before the operation to be measured starts and this
|
||||
function should be run immediately after the operation to be measured has completed.
|
||||
|
||||
Args:
|
||||
- split: one of train, val, test
|
||||
- start_time: operation start time
|
||||
- num_samples: number of samples processed
|
||||
|
||||
"""
|
||||
runtime = time.time() - start_time
|
||||
result = {}
|
||||
|
||||
samples_per_second = 1 / (runtime / num_samples)
|
||||
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
|
||||
result[f"{split}_runtime"] = round(runtime, 4)
|
||||
|
||||
result[f"{split}_n_ojbs"] = num_samples
|
||||
return result
|
||||
|
||||
|
||||
def handle_metrics(split, metrics, output_dir):
|
||||
"""
|
||||
Log and save metrics
|
||||
|
||||
Args:
|
||||
- split: one of train, val, test
|
||||
- metrics: metrics dict
|
||||
- output_dir: where to save the metrics
|
||||
"""
|
||||
|
||||
logger.info(f"***** {split} metrics *****")
|
||||
for key, value in metrics.items():
|
||||
logger.info(f" {key} = {value}")
|
||||
save_json(metrics, os.path.join(output_dir, f"{split}_results.json"))
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
@@ -265,45 +306,56 @@ def main():
|
||||
data_args=data_args,
|
||||
)
|
||||
|
||||
all_metrics = {}
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
logger.info("*** Train ***")
|
||||
|
||||
start_time = time.time()
|
||||
trainer.train(
|
||||
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
|
||||
)
|
||||
trainer.save_model()
|
||||
# For convenience, we also re-save the tokenizer to the same directory,
|
||||
# so that you can share your model easily on huggingface.co/models =)
|
||||
metrics = speed_metrics("train", start_time, data_args.n_train)
|
||||
|
||||
trainer.save_model() # this also saves the tokenizer
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
handle_metrics("train", metrics, training_args.output_dir)
|
||||
all_metrics.update(metrics)
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
|
||||
# For convenience, we also re-save the tokenizer to the same directory,
|
||||
# so that you can share your model easily on huggingface.co/models =)
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
|
||||
# Evaluation
|
||||
eval_results = {}
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
result = trainer.evaluate()
|
||||
start_time = time.time()
|
||||
metrics = trainer.evaluate(metric_key_prefix="val")
|
||||
metrics.update(speed_metrics("val", start_time, data_args.n_val))
|
||||
metrics["val_loss"] = round(metrics["val_loss"], 4)
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in result.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
|
||||
eval_results.update(result)
|
||||
|
||||
handle_metrics("val", metrics, training_args.output_dir)
|
||||
all_metrics.update(metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
logging.info("*** Test ***")
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
test_output = trainer.predict(test_dataset=test_dataset)
|
||||
test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}
|
||||
start_time = time.time()
|
||||
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
|
||||
metrics = test_output.metrics
|
||||
metrics.update(speed_metrics("test", start_time, data_args.n_test))
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
logger.info("***** Test results *****")
|
||||
for key, value in test_metrics.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
|
||||
save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
|
||||
eval_results.update(test_metrics)
|
||||
metrics["test_loss"] = round(metrics["test_loss"], 4)
|
||||
handle_metrics("test", metrics, training_args.output_dir)
|
||||
all_metrics.update(metrics)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
test_preds = tokenizer.batch_decode(
|
||||
@@ -313,8 +365,9 @@ def main():
|
||||
write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
save_json(eval_results, "all_results.json")
|
||||
return eval_results
|
||||
save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json"))
|
||||
|
||||
return all_metrics
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
|
||||
@@ -462,7 +462,7 @@ def save_git_info(folder_path: str) -> None:
|
||||
|
||||
def save_json(content, path, indent=4, **json_dump_kwargs):
|
||||
with open(path, "w") as f:
|
||||
json.dump(content, f, indent=indent, **json_dump_kwargs)
|
||||
json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)
|
||||
|
||||
|
||||
def load_json(path):
|
||||
|
||||
Reference in New Issue
Block a user