[examples] better model example (#10427)

* refactors

* typo
This commit is contained in:
Stas Bekman
2021-02-26 17:01:01 -08:00
committed by GitHub
parent a85eb616f7
commit ee04b69822
3 changed files with 46 additions and 20 deletions

View File

@@ -572,7 +572,6 @@ def main():
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
all_metrics = {}
# Training
if training_args.do_train:
if last_checkpoint is not None:
@@ -589,13 +588,10 @@ def main():
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
if trainer.is_world_process_zero():
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
all_metrics.update(metrics)
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# Evaluation
results = {}
@@ -608,10 +604,8 @@ def main():
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
if trainer.is_world_process_zero():
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
all_metrics.update(metrics)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.do_predict:
logger.info("*** Test ***")
@@ -626,11 +620,10 @@ def main():
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
if trainer.is_world_process_zero():
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
all_metrics.update(metrics)
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
if trainer.is_world_process_zero():
if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode(
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
@@ -640,9 +633,6 @@ def main():
with open(output_test_preds_file, "w") as writer:
writer.write("\n".join(test_preds))
if trainer.is_world_process_zero():
trainer.save_metrics("all", metrics)
return results