[Trainer] move model to device before setting optimizer (#4450)
This commit is contained in:
@@ -188,7 +188,7 @@ class Trainer:
|
||||
prediction_loss_only:
|
||||
(Optional) in evaluation and prediction, only return the loss
|
||||
"""
|
||||
-self.model = model
|
||||
+self.model = model.to(args.device)
|
||||
self.args = args
|
||||
if data_collator is not None:
|
||||
self.data_collator = data_collator
|
||||
@@ -393,7 +393,6 @@ class Trainer:
|
||||
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
|
||||
model = self.model
|
||||
-model.to(self.args.device)
|
||||
if self.args.fp16:
|
||||
if not is_apex_available():
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
@@ -726,7 +725,6 @@ class Trainer:
|
||||
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
|
||||
|
||||
model = self.model
|
||||
-model.to(self.args.device)
|
||||
# multi-gpu eval
|
||||
if self.args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
Reference in New Issue
Block a user