Fix test_model_parallelism (#25359)
* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -353,6 +353,7 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = CLIPTextModelTester(self)
|
||||
|
||||
@@ -308,6 +308,7 @@ class CLIPSegTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = CLIPSegTextModelTester(self)
|
||||
|
||||
@@ -388,6 +388,7 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
model_split_percents = [0.5, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Data2VecTextModelTester(self)
|
||||
|
||||
@@ -192,6 +192,7 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else {}
|
||||
)
|
||||
test_sequence_classification_problem_types = True
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = EsmModelTester(self)
|
||||
|
||||
@@ -323,6 +323,10 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
|
||||
def test_model_parallelism(self):
|
||||
super().test_model_parallelism()
|
||||
|
||||
|
||||
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
||||
|
||||
@@ -395,6 +395,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
else {}
|
||||
)
|
||||
fx_compatible = True
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RobertaModelTester(self)
|
||||
|
||||
@@ -395,6 +395,7 @@ class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, Pipe
|
||||
else {}
|
||||
)
|
||||
fx_compatible = False
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RobertaPreLayerNormModelTester(self)
|
||||
|
||||
@@ -235,6 +235,7 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
# ViltForMaskedLM, ViltForQuestionAnswering and ViltForImagesAndTextClassification require special treatment
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -163,6 +163,7 @@ class ViTHybridModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
model_split_percents = [0.5, 0.9]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ViTHybridModelTester(self)
|
||||
|
||||
@@ -347,6 +347,10 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
model = XGLMModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
|
||||
def test_model_parallelism(self):
|
||||
super().test_model_parallelism()
|
||||
|
||||
|
||||
@require_torch
|
||||
class XGLMModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user