From 6ccea0486f3016d7ace7bd94da291840b64e235b Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 15 Dec 2020 09:51:12 -0500 Subject: [PATCH] Fix T5 model parallel tes (#9107) k --- tests/test_modeling_t5.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 9e229ced69..6188a8ffb5 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -484,9 +484,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () - all_parallelizable_model_classes = ( - (T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else () - ) + all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () test_pruning = False test_torchscript = True test_resize_embeddings = True @@ -689,6 +687,8 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = False test_torchscript = True test_resize_embeddings = False + test_model_parallel = True + all_parallelizable_model_classes = (T5EncoderModel,) if is_torch_available() else () def setUp(self): self.model_tester = T5EncoderOnlyModelTester(self)