Use model.from_pretrained for DataParallel also (#8795)
* Use model.from_pretrained for DataParallel also When training on multiple GPUs, the code wraps a model with torch.nn.DataParallel. However if the model has custom from_pretrained logic, it does not get applied during load_best_model_at_end. This commit uses the underlying model during load_best_model_at_end, and re-wraps the loaded model with DataParallel. If you choose to reject this change, then could you please move the this logic to a function, e.g. def load_best_model_checkpoint(best_model_checkpoint) or something, so that it can be overridden? * Fix silly bug * Address review comments Thanks for the feedback. I made the change that you proposed, but I also think we should update L811 to check if `self.mode` is an instance of `PreTrained`, otherwise we would still not get into that `if` section, right?
This commit is contained in:
@@ -808,8 +808,8 @@ class Trainer:
|
||||
logger.info(
|
||||
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
|
||||
)
|
||||
if isinstance(model, PreTrainedModel):
|
||||
self.model = model.from_pretrained(self.state.best_model_checkpoint)
|
||||
if isinstance(self.model, PreTrainedModel):
|
||||
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
|
||||
if not self.args.model_parallel:
|
||||
self.model = self.model.to(self.args.device)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user