From 773849415aff693d528a28617441284c0e541a6f Mon Sep 17 00:00:00 2001 From: Shai Erera Date: Mon, 30 Nov 2020 18:11:10 +0200 Subject: [PATCH] Use model.from_pretrained for DataParallel also (#8795) * Use model.from_pretrained for DataParallel also When training on multiple GPUs, the code wraps a model with torch.nn.DataParallel. However if the model has custom from_pretrained logic, it does not get applied during load_best_model_at_end. This commit uses the underlying model during load_best_model_at_end, and re-wraps the loaded model with DataParallel. If you choose to reject this change, then could you please move the this logic to a function, e.g. def load_best_model_checkpoint(best_model_checkpoint) or something, so that it can be overridden? * Fix silly bug * Address review comments Thanks for the feedback. I made the change that you proposed, but I also think we should update L811 to check if `self.mode` is an instance of `PreTrained`, otherwise we would still not get into that `if` section, right? --- src/transformers/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 27372a9b50..f50b96e455 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -808,8 +808,8 @@ class Trainer: logger.info( f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." ) - if isinstance(model, PreTrainedModel): - self.model = model.from_pretrained(self.state.best_model_checkpoint) + if isinstance(self.model, PreTrainedModel): + self.model = self.model.from_pretrained(self.state.best_model_checkpoint) if not self.args.model_parallel: self.model = self.model.to(self.args.device) else: