Use model.from_pretrained for DataParallel also (#8795)

* Use model.from_pretrained for DataParallel also

When training on multiple GPUs, the code wraps a model with torch.nn.DataParallel. However if the model has custom from_pretrained logic, it does not get applied during load_best_model_at_end.

This commit uses the underlying model during load_best_model_at_end, and re-wraps the loaded model with DataParallel.

If you choose to reject this change, then could you please move the this logic to a function, e.g. def load_best_model_checkpoint(best_model_checkpoint) or something, so that it can be overridden?

* Fix silly bug

* Address review comments

Thanks for the feedback. I made the change that you proposed, but I also think we should update L811 to check if `self.mode` is an instance of `PreTrained`, otherwise we would still not get into that `if` section, right?
This commit is contained in:
Shai Erera
2020-11-30 18:11:10 +02:00
committed by GitHub
parent 4062c75e44
commit 773849415a

View File

@@ -808,8 +808,8 @@ class Trainer:
logger.info( logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
) )
if isinstance(model, PreTrainedModel): if isinstance(self.model, PreTrainedModel):
self.model = model.from_pretrained(self.state.best_model_checkpoint) self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
if not self.args.model_parallel: if not self.args.model_parallel:
self.model = self.model.to(self.args.device) self.model = self.model.to(self.args.device)
else: else: