Add parallelization support for T5EncoderModel (#9082)

* add model parallelism to T5EncoderModel

add model parallelism to T5EncoderModel

* remove decoder from T5EncoderModel parallelize

* uodate T5EncoderModel docs

* Extend T5ModelTest for T5EncoderModel

* fix T5Stask using range for get_device_map

* fix style

Co-authored-by: Ahmed Elnaggar <elnaggar@rostlab.informatik.tu-muenchen.de>
This commit is contained in:
Ahmed Elnaggar
2020-12-14 18:00:45 +01:00
committed by GitHub
parent b00eb4fb02
commit a9c8bff724
3 changed files with 22 additions and 8 deletions

View File

@@ -131,7 +131,7 @@ T5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.T5EncoderModel .. autoclass:: transformers.T5EncoderModel
:members: forward :members: forward, parallelize, deparallelize
TFT5Model TFT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -781,7 +781,7 @@ class T5Stack(T5PreTrainedModel):
def parallelize(self, device_map=None): def parallelize(self, device_map=None):
# Check validity of device_map # Check validity of device_map
self.device_map = ( self.device_map = (
get_device_map(len(self.block), torch.cuda.device_count()) if device_map is None else device_map get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
) )
assert_device_map(self.device_map, len(self.block)) assert_device_map(self.device_map, len(self.block))
self.model_parallel = True self.model_parallel = True
@@ -1579,6 +1579,25 @@ class T5EncoderModel(T5PreTrainedModel):
self.init_weights() self.init_weights()
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.encoder.block))
self.encoder.parallelize(self.device_map)
self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
self.encoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.model_parallel = False
self.device_map = None
torch.cuda.empty_cache()
def get_input_embeddings(self): def get_input_embeddings(self):
return self.shared return self.shared

View File

@@ -485,12 +485,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
all_parallelizable_model_classes = ( all_parallelizable_model_classes = (
( (T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else ()
T5Model,
T5ForConditionalGeneration,
)
if is_torch_available()
else ()
) )
test_pruning = False test_pruning = False
test_torchscript = True test_torchscript = True