Fix seq2seq example test (#7518)
* Fix seq2seq example test
* Fix bad copy-paste
* Also save the state
This commit is contained in:
@@ -276,6 +276,7 @@ def main():
|
|||||||
# For convenience, we also re-save the tokenizer to the same directory,
|
# For convenience, we also re-save the tokenizer to the same directory,
|
||||||
# so that you can share your model easily on huggingface.co/models =)
|
# so that you can share your model easily on huggingface.co/models =)
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
|
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||||
tokenizer.save_pretrained(training_args.output_dir)
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
|
|||||||
@@ -4,11 +4,10 @@ import tempfile
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import slow
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import TrainerState, set_seed
|
||||||
|
|
||||||
from .finetune_trainer import main
|
from .finetune_trainer import main
|
||||||
from .test_seq2seq_examples import MBART_TINY
|
from .test_seq2seq_examples import MBART_TINY
|
||||||
from .utils import load_json
|
|
||||||
|
|
||||||
|
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
@@ -17,7 +16,7 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
|||||||
|
|
||||||
def test_finetune_trainer():
|
def test_finetune_trainer():
|
||||||
output_dir = run_trainer(1, "12", MBART_TINY, 1)
|
output_dir = run_trainer(1, "12", MBART_TINY, 1)
|
||||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||||
first_step_stats = eval_metrics[0]
|
first_step_stats = eval_metrics[0]
|
||||||
assert "eval_bleu" in first_step_stats
|
assert "eval_bleu" in first_step_stats
|
||||||
@@ -30,7 +29,7 @@ def test_finetune_trainer_slow():
|
|||||||
output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
|
output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||||
|
|
||||||
# Check metrics
|
# Check metrics
|
||||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||||
first_step_stats = eval_metrics[0]
|
first_step_stats = eval_metrics[0]
|
||||||
last_step_stats = eval_metrics[-1]
|
last_step_stats = eval_metrics[-1]
|
||||||
|
|||||||
@@ -601,6 +601,7 @@ class Trainer:
|
|||||||
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
||||||
self.save_model(output_dir)
|
self.save_model(output_dir)
|
||||||
if self.is_world_master():
|
if self.is_world_master():
|
||||||
|
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
||||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user