Map optimizer to correct device after loading from checkpoint. (#4403)

* Map optimizer to correct device after loading from checkpoint.

* Make style test pass

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Shaoyen
2020-05-18 20:16:05 -07:00
committed by GitHub
parent bf14ef75f1
commit 384f0eb2f9

View File

@@ -389,7 +389,9 @@ class Trainer:
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt")))
optimizer.load_state_dict(
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
model = self.model