From 2d02178e5c053b2dac67d6ab61c1c4fc6c62ee57 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 4 Nov 2022 18:01:41 +0100 Subject: [PATCH] Allow passing arguments to model testers for CLIP-like models (#20044) * POC * For more CLIP-like models Co-authored-by: ydshieh --- tests/models/clip/test_modeling_clip.py | 12 +++++++--- tests/models/flava/test_modeling_flava.py | 22 +++++++++++++++---- .../models/groupvit/test_modeling_groupvit.py | 12 +++++++--- tests/models/owlvit/test_modeling_owlvit.py | 12 +++++++--- tests/models/x_clip/test_modeling_x_clip.py | 20 ++++++++++++++--- 5 files changed, 62 insertions(+), 16 deletions(-) diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index ab05f9adf1..524f002ad3 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -344,10 +344,16 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase): class CLIPModelTester: - def __init__(self, parent, is_training=True): + def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): + + if text_kwargs is None: + text_kwargs = {} + if vision_kwargs is None: + vision_kwargs = {} + self.parent = parent - self.text_model_tester = CLIPTextModelTester(parent) - self.vision_model_tester = CLIPVisionModelTester(parent) + self.text_model_tester = CLIPTextModelTester(parent, **text_kwargs) + self.vision_model_tester = CLIPVisionModelTester(parent, **vision_kwargs) self.is_training = is_training def prepare_config_and_inputs(self): diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py index 62b89e3977..44aff1025f 100644 --- a/tests/models/flava/test_modeling_flava.py +++ b/tests/models/flava/test_modeling_flava.py @@ -746,17 +746,31 @@ class FlavaModelTester: def __init__( self, parent, + text_kwargs=None, + image_kwargs=None, + multimodal_kwargs=None, + image_codebook_kwargs=None, is_training=True, hidden_size=32, projection_dim=32, initializer_range=0.02, layer_norm_eps=1e-12, ): + + if text_kwargs is None: + text_kwargs = {} + if image_kwargs is None: + image_kwargs = {} + if multimodal_kwargs is None: + multimodal_kwargs = {} + if image_codebook_kwargs is None: + image_codebook_kwargs = {} + self.parent = parent - self.image_model_tester = FlavaImageModelTester(parent) - self.text_model_tester = FlavaTextModelTester(parent) - self.multimodal_model_tester = FlavaMultimodalModelTester(parent) - self.image_codebook_tester = FlavaImageCodebookTester(parent) + self.image_model_tester = FlavaImageModelTester(parent, **image_kwargs) + self.text_model_tester = FlavaTextModelTester(parent, **text_kwargs) + self.multimodal_model_tester = FlavaMultimodalModelTester(parent, **multimodal_kwargs) + self.image_codebook_tester = FlavaImageCodebookTester(parent, **image_codebook_kwargs) self.is_training = is_training self.config_tester = ConfigTester(self, config_class=FlavaConfig, hidden_size=37) self.hidden_size = hidden_size diff --git a/tests/models/groupvit/test_modeling_groupvit.py b/tests/models/groupvit/test_modeling_groupvit.py index 76b9d02a87..3b396daa67 100644 --- a/tests/models/groupvit/test_modeling_groupvit.py +++ b/tests/models/groupvit/test_modeling_groupvit.py @@ -474,10 +474,16 @@ class GroupViTTextModelTest(ModelTesterMixin, unittest.TestCase): class GroupViTModelTester: - def __init__(self, parent, is_training=True): + def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): + + if text_kwargs is None: + text_kwargs = {} + if vision_kwargs is None: + vision_kwargs = {} + self.parent = parent - self.text_model_tester = GroupViTTextModelTester(parent) - self.vision_model_tester = GroupViTVisionModelTester(parent) + self.text_model_tester = GroupViTTextModelTester(parent, **text_kwargs) + self.vision_model_tester = GroupViTVisionModelTester(parent, **vision_kwargs) self.is_training = is_training def prepare_config_and_inputs(self): diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py index e8f615ec8e..bb53d5c7cf 100644 --- a/tests/models/owlvit/test_modeling_owlvit.py +++ b/tests/models/owlvit/test_modeling_owlvit.py @@ -339,10 +339,16 @@ class OwlViTTextModelTest(ModelTesterMixin, unittest.TestCase): class OwlViTModelTester: - def __init__(self, parent, is_training=True): + def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): + + if text_kwargs is None: + text_kwargs = {} + if vision_kwargs is None: + vision_kwargs = {} + self.parent = parent - self.text_model_tester = OwlViTTextModelTester(parent) - self.vision_model_tester = OwlViTVisionModelTester(parent) + self.text_model_tester = OwlViTTextModelTester(parent, **text_kwargs) + self.vision_model_tester = OwlViTVisionModelTester(parent, **vision_kwargs) self.is_training = is_training self.text_config = self.text_model_tester.get_config().to_dict() self.vision_config = self.vision_model_tester.get_config().to_dict() diff --git a/tests/models/x_clip/test_modeling_x_clip.py b/tests/models/x_clip/test_modeling_x_clip.py index 0e9826d781..b4f3252e2f 100644 --- a/tests/models/x_clip/test_modeling_x_clip.py +++ b/tests/models/x_clip/test_modeling_x_clip.py @@ -436,12 +436,26 @@ class XCLIPTextModelTest(ModelTesterMixin, unittest.TestCase): class XCLIPModelTester: - def __init__(self, parent, projection_dim=64, mit_hidden_size=64, is_training=True): + def __init__( + self, + parent, + text_kwargs=None, + vision_kwargs=None, + projection_dim=64, + mit_hidden_size=64, + is_training=True, + ): + + if text_kwargs is None: + text_kwargs = {} + if vision_kwargs is None: + vision_kwargs = {} + self.parent = parent self.projection_dim = projection_dim self.mit_hidden_size = mit_hidden_size - self.text_model_tester = XCLIPTextModelTester(parent) - self.vision_model_tester = XCLIPVisionModelTester(parent) + self.text_model_tester = XCLIPTextModelTester(parent, **text_kwargs) + self.vision_model_tester = XCLIPVisionModelTester(parent, **vision_kwargs) self.is_training = is_training def prepare_config_and_inputs(self):