Fix seq2seq example test (#7518)
* Fix seq2seq example test * Fix bad copy-paste * Also save the state
This commit is contained in:
@@ -4,11 +4,10 @@ import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers.testing_utils import slow
|
||||
from transformers.trainer_utils import set_seed
|
||||
from transformers.trainer_utils import TrainerState, set_seed
|
||||
|
||||
from .finetune_trainer import main
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
from .utils import load_json
|
||||
|
||||
|
||||
set_seed(42)
|
||||
@@ -17,7 +16,7 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
|
||||
def test_finetune_trainer():
|
||||
output_dir = run_trainer(1, "12", MBART_TINY, 1)
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
assert "eval_bleu" in first_step_stats
|
||||
@@ -30,7 +29,7 @@ def test_finetune_trainer_slow():
|
||||
output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||
|
||||
# Check metrics
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
last_step_stats = eval_metrics[-1]
|
||||
|
||||
Reference in New Issue
Block a user