From 384f0eb2f9d42e44094dbfd0917ccf4e6ddb462a Mon Sep 17 00:00:00 2001 From: Shaoyen Date: Mon, 18 May 2020 20:16:05 -0700 Subject: [PATCH] Map optimizer to correct device after loading from checkpoint. (#4403) * Map optimizer to correct device after loading from checkpoint. * Make style test pass Co-authored-by: Julien Chaumond --- src/transformers/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9aca17b8fc..1105a6009f 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -389,7 +389,9 @@ class Trainer: and os.path.isfile(os.path.join(model_path, "scheduler.pt")) ): # Load in optimizer and scheduler states - optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt"))) + optimizer.load_state_dict( + torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device) + ) scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) model = self.model