[Trainer] move model to device before setting optimizer (#4450)
This commit is contained in:
@@ -188,7 +188,7 @@ class Trainer:
|
|||||||
prediction_loss_only:
|
prediction_loss_only:
|
||||||
(Optional) in evaluation and prediction, only return the loss
|
(Optional) in evaluation and prediction, only return the loss
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model.to(args.device)
|
||||||
self.args = args
|
self.args = args
|
||||||
if data_collator is not None:
|
if data_collator is not None:
|
||||||
self.data_collator = data_collator
|
self.data_collator = data_collator
|
||||||
@@ -393,7 +393,6 @@ class Trainer:
|
|||||||
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||||
|
|
||||||
model = self.model
|
model = self.model
|
||||||
model.to(self.args.device)
|
|
||||||
if self.args.fp16:
|
if self.args.fp16:
|
||||||
if not is_apex_available():
|
if not is_apex_available():
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||||
@@ -726,7 +725,6 @@ class Trainer:
|
|||||||
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
|
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
|
||||||
|
|
||||||
model = self.model
|
model = self.model
|
||||||
model.to(self.args.device)
|
|
||||||
# multi-gpu eval
|
# multi-gpu eval
|
||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|||||||
Reference in New Issue
Block a user