Add parallelization support for T5EncoderModel (#9082)

* add model parallelism to T5EncoderModel

add model parallelism to T5EncoderModel

* remove decoder from T5EncoderModel parallelize

* uodate T5EncoderModel docs

* Extend T5ModelTest for T5EncoderModel

* fix T5Stask using range for get_device_map

* fix style

Co-authored-by: Ahmed Elnaggar <elnaggar@rostlab.informatik.tu-muenchen.de>
This commit is contained in:
Ahmed Elnaggar
2020-12-14 18:00:45 +01:00
committed by GitHub
parent b00eb4fb02
commit a9c8bff724
3 changed files with 22 additions and 8 deletions

View File

@@ -485,12 +485,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
all_parallelizable_model_classes = (
(
T5Model,
T5ForConditionalGeneration,
)
if is_torch_available()
else ()
(T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else ()
)
test_pruning = False
test_torchscript = True