Fix seq2seq example test (#7518)

* Fix seq2seq example test

* Fix bad copy-paste

* Also save the state
This commit is contained in:
Sylvain Gugger
2020-10-01 14:13:29 -04:00
committed by GitHub
parent 29baa8fabe
commit bdcc4b78a2
3 changed files with 5 additions and 4 deletions

View File

@@ -601,6 +601,7 @@ class Trainer:
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir)
if self.is_world_master():
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))