From 505494a86ff5daa0ced517e7d5aef56cac49b3e5 Mon Sep 17 00:00:00 2001 From: Igor Shalyminov Date: Mon, 15 Mar 2021 13:10:44 +0000 Subject: [PATCH] GPT2DoubleHeadsModel made parallelizable (#10658) * GPT2DoubleHeadsModel made parallelizeable * GPT2DoubleHeadsModel added as parallelizeable onto the GPT2 test suite --- src/transformers/models/gpt2/modeling_gpt2.py | 27 +++++++++++++++++++ tests/test_modeling_gpt2.py | 2 +- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 4dd2c07509..4518964052 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -983,6 +983,28 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): self.model_parallel = False self.device_map = None + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + def get_output_embeddings(self): return self.lm_head @@ -1096,6 +1118,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): hidden_states = transformer_outputs[0] + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + lm_logits = self.lm_head(hidden_states) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 8385f9a2da..10c456d877 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -398,7 +398,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () - all_parallelizable_model_classes = (GPT2LMHeadModel,) if is_torch_available() else () + all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () test_missing_keys = False test_model_parallel = True