Fix seq2seq example test (#7518)

* Fix seq2seq example test

* Fix bad copy-paste

* Also save the state
This commit is contained in:
Sylvain Gugger
2020-10-01 14:13:29 -04:00
committed by GitHub
parent 29baa8fabe
commit bdcc4b78a2
3 changed files with 5 additions and 4 deletions

View File

@@ -4,11 +4,10 @@ import tempfile
from unittest.mock import patch
from transformers.testing_utils import slow
from transformers.trainer_utils import set_seed
from transformers.trainer_utils import TrainerState, set_seed
from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY
from .utils import load_json
set_seed(42)
@@ -17,7 +16,7 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
def test_finetune_trainer():
output_dir = run_trainer(1, "12", MBART_TINY, 1)
logs = load_json(os.path.join(output_dir, "log_history.json"))
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
first_step_stats = eval_metrics[0]
assert "eval_bleu" in first_step_stats
@@ -30,7 +29,7 @@ def test_finetune_trainer_slow():
output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
# Check metrics
logs = load_json(os.path.join(output_dir, "log_history.json"))
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
first_step_stats = eval_metrics[0]
last_step_stats = eval_metrics[-1]