Map optimizer to correct device after loading from checkpoint. (#4403)
* Map optimizer to correct device after loading from checkpoint. * Make style test pass Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -389,7 +389,9 @@ class Trainer:
|
||||
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt")))
|
||||
optimizer.load_state_dict(
|
||||
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
||||
)
|
||||
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
|
||||
model = self.model
|
||||
|
||||
Reference in New Issue
Block a user