Patch summary (text before the first "diff --git" header is skipped by git-apply and patch):
  * docs: list parallelize and deparallelize among the documented T5EncoderModel members.
  * modeling_t5.py: T5Stack.parallelize now passes range(torch.cuda.device_count()) to
    get_device_map instead of the bare device count; T5EncoderModel gains parallelize()
    and deparallelize(), which delegate to its encoder stack.
  * tests: T5EncoderModel is added to all_parallelizable_model_classes.

diff --git a/docs/source/model_doc/t5.rst b/docs/source/model_doc/t5.rst
index 60a7b58492..07592ff347 100644
--- a/docs/source/model_doc/t5.rst
+++ b/docs/source/model_doc/t5.rst
@@ -131,7 +131,7 @@ T5EncoderModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.T5EncoderModel
-    :members: forward
+    :members: forward, parallelize, deparallelize
 
 TFT5Model
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index f9b77aed32..491ece4d99 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -781,7 +781,7 @@ class T5Stack(T5PreTrainedModel):
     def parallelize(self, device_map=None):
         # Check validity of device_map
         self.device_map = (
-            get_device_map(len(self.block), torch.cuda.device_count()) if device_map is None else device_map
+            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
         )
         assert_device_map(self.device_map, len(self.block))
         self.model_parallel = True
@@ -1579,6 +1579,25 @@ class T5EncoderModel(T5PreTrainedModel):
 
         self.init_weights()
 
+    @add_start_docstrings(PARALLELIZE_DOCSTRING)
+    def parallelize(self, device_map=None):
+        self.device_map = (
+            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
+            if device_map is None
+            else device_map
+        )
+        assert_device_map(self.device_map, len(self.encoder.block))
+        self.encoder.parallelize(self.device_map)
+        self.model_parallel = True
+
+    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+    def deparallelize(self):
+        self.encoder.deparallelize()
+        self.encoder = self.encoder.to("cpu")
+        self.model_parallel = False
+        self.device_map = None
+        torch.cuda.empty_cache()
+
     def get_input_embeddings(self):
         return self.shared
 
diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py
index 15d5866c91..9e229ced69 100644
--- a/tests/test_modeling_t5.py
+++ b/tests/test_modeling_t5.py
@@ -485,12 +485,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
     all_parallelizable_model_classes = (
-        (
-            T5Model,
-            T5ForConditionalGeneration,
-        )
-        if is_torch_available()
-        else ()
+        (T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else ()
     )
     test_pruning = False
     test_torchscript = True