From dee876cefffa769491a008670b35b5ac3192e929 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 17 Feb 2021 15:52:36 -0800 Subject: [PATCH] [trainer] refactor place_model_on_device logic, add deepspeed (#10243) * refactor place_model_on_device logic, add deepspeed * doc * style --- src/transformers/trainer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4054bc81c0..e1febc0b03 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -214,6 +214,10 @@ class Trainer: inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``. - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from data parallelism, this means some of the model layers are split on different GPUs). + - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set + to :obj:`False` if model parallel or deepspeed is used, or if the default + ``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False` . + """ def __init__( @@ -262,6 +266,11 @@ class Trainer: else: self.is_model_parallel = False + # one place to sort out whether to place the model on device or not + self.place_model_on_device = args.place_model_on_device + if self.is_model_parallel or (args.deepspeed and args.do_train): + self.place_model_on_device = False + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset @@ -272,7 +281,7 @@ class Trainer: # 1. MP - since we are trying to fit a much bigger than 1 gpu model # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, # and we only use deepspeed for training at the moment - if not (self.is_model_parallel or (args.deepspeed and args.do_train)) and self.args.place_model_on_device: + if self.place_model_on_device: model = model.to(args.device) # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs @@ -780,7 +789,7 @@ class Trainer: # If model was re-initialized, put it on the right device and update self.model_wrapped if model_reloaded: - if not self.is_model_parallel and self.args.place_model_on_device: + if self.place_model_on_device: self.model = self.model.to(self.args.device) self.model_wrapped = self.model @@ -1033,7 +1042,7 @@ class Trainer: ) if isinstance(self.model, PreTrainedModel): self.model = self.model.from_pretrained(self.state.best_model_checkpoint) - if not self.is_model_parallel and self.args.place_model_on_device: + if self.place_model_on_device: self.model = self.model.to(self.args.device) else: state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))