From 0f443436fb1f8556c57c9113d0530988b26fb486 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 25 Jan 2021 17:12:07 +0100 Subject: [PATCH] Actual fix (#9787) --- src/transformers/models/gpt2/modeling_gpt2.py | 11 +++++++++++ src/transformers/models/t5/modeling_t5.py | 4 ++++ 2 files changed, 15 insertions(+) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index cebf705d71..eb4088b5e3 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -541,6 +541,7 @@ class GPT2Model(GPT2PreTrainedModel): self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.init_weights() + # Model parallel self.model_parallel = False self.device_map = None @@ -805,7 +806,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): self.init_weights() + # Model parallel self.model_parallel = False + self.device_map = None @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -971,6 +974,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + def get_output_embeddings(self): return self.lm_head @@ -1153,6 +1160,10 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel): self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 1eb35a2875..bd05cf00d1 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1651,6 +1651,10 @@ class T5EncoderModel(T5PreTrainedModel): self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): self.device_map = (