Fix seq2seq example test (#7518)
* Fix seq2seq example test * Fix bad copy-paste * Also save the state
This commit is contained in:
@@ -276,6 +276,7 @@ def main():
|
||||
# For convenience, we also re-save the tokenizer to the same directory,
|
||||
# so that you can share your model easily on huggingface.co/models =)
|
||||
if trainer.is_world_process_zero():
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
|
||||
# Evaluation
|
||||
|
||||
@@ -4,11 +4,10 @@ import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers.testing_utils import slow
|
||||
from transformers.trainer_utils import set_seed
|
||||
from transformers.trainer_utils import TrainerState, set_seed
|
||||
|
||||
from .finetune_trainer import main
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
from .utils import load_json
|
||||
|
||||
|
||||
set_seed(42)
|
||||
@@ -17,7 +16,7 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
|
||||
def test_finetune_trainer():
|
||||
output_dir = run_trainer(1, "12", MBART_TINY, 1)
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
assert "eval_bleu" in first_step_stats
|
||||
@@ -30,7 +29,7 @@ def test_finetune_trainer_slow():
|
||||
output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||
|
||||
# Check metrics
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
last_step_stats = eval_metrics[-1]
|
||||
|
||||
@@ -601,6 +601,7 @@ class Trainer:
|
||||
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
||||
self.save_model(output_dir)
|
||||
if self.is_world_master():
|
||||
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user