Add parallelization support for T5EncoderModel (#9082)
* add model parallelism to T5EncoderModel add model parallelism to T5EncoderModel * remove decoder from T5EncoderModel parallelize * update T5EncoderModel docs * Extend T5ModelTest for T5EncoderModel * fix T5Stack using range for get_device_map * fix style Co-authored-by: Ahmed Elnaggar <elnaggar@rostlab.informatik.tu-muenchen.de>
This commit is contained in:
@@ -131,7 +131,7 @@ T5EncoderModel
|
|||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.T5EncoderModel
|
.. autoclass:: transformers.T5EncoderModel
|
||||||
:members: forward
|
:members: forward, parallelize, deparallelize
|
||||||
|
|
||||||
TFT5Model
|
TFT5Model
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|||||||
@@ -781,7 +781,7 @@ class T5Stack(T5PreTrainedModel):
|
|||||||
def parallelize(self, device_map=None):
|
def parallelize(self, device_map=None):
|
||||||
# Check validity of device_map
|
# Check validity of device_map
|
||||||
self.device_map = (
|
self.device_map = (
|
||||||
get_device_map(len(self.block), torch.cuda.device_count()) if device_map is None else device_map
|
get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
|
||||||
)
|
)
|
||||||
assert_device_map(self.device_map, len(self.block))
|
assert_device_map(self.device_map, len(self.block))
|
||||||
self.model_parallel = True
|
self.model_parallel = True
|
||||||
@@ -1579,6 +1579,25 @@ class T5EncoderModel(T5PreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
||||||
|
def parallelize(self, device_map=None):
|
||||||
|
self.device_map = (
|
||||||
|
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
|
||||||
|
if device_map is None
|
||||||
|
else device_map
|
||||||
|
)
|
||||||
|
assert_device_map(self.device_map, len(self.encoder.block))
|
||||||
|
self.encoder.parallelize(self.device_map)
|
||||||
|
self.model_parallel = True
|
||||||
|
|
||||||
|
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
||||||
|
def deparallelize(self):
|
||||||
|
self.encoder.deparallelize()
|
||||||
|
self.encoder = self.encoder.to("cpu")
|
||||||
|
self.model_parallel = False
|
||||||
|
self.device_map = None
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.shared
|
return self.shared
|
||||||
|
|
||||||
|
|||||||
@@ -485,12 +485,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||||
all_parallelizable_model_classes = (
|
all_parallelizable_model_classes = (
|
||||||
(
|
(T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else ()
|
||||||
T5Model,
|
|
||||||
T5ForConditionalGeneration,
|
|
||||||
)
|
|
||||||
if is_torch_available()
|
|
||||||
else ()
|
|
||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = True
|
test_torchscript = True
|
||||||
|
|||||||
Reference in New Issue
Block a user