[LongT5] disable model parallel test (#17702)

This commit is contained in:
Suraj Patil
2022-06-14 17:27:39 +02:00
committed by GitHub
parent 7ec9128e5a
commit 120649bf3a

View File

@@ -505,11 +505,10 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
all_model_classes = (LongT5Model, LongT5ForConditionalGeneration) if is_torch_available() else () all_model_classes = (LongT5Model, LongT5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (LongT5ForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (LongT5ForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False fx_compatible = False
all_parallelizable_model_classes = (LongT5Model, LongT5ForConditionalGeneration) if is_torch_available() else ()
test_pruning = False test_pruning = False
test_torchscript = True test_torchscript = True
test_resize_embeddings = True test_resize_embeddings = True
test_model_parallel = True test_model_parallel = False
is_encoder_decoder = True is_encoder_decoder = True
def setUp(self): def setUp(self):
@@ -1013,8 +1012,7 @@ class LongT5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False test_pruning = False
test_torchscript = True test_torchscript = True
test_resize_embeddings = False test_resize_embeddings = False
test_model_parallel = True test_model_parallel = False
all_parallelizable_model_classes = (LongT5EncoderModel,) if is_torch_available() else ()
def setUp(self): def setUp(self):
self.model_tester = LongT5EncoderOnlyModelTester(self) self.model_tester = LongT5EncoderOnlyModelTester(self)