Allow passing arguments to model testers for CLIP-like models (#20044)
* POC * For more CLIP-like models Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -344,10 +344,16 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class CLIPModelTester:
|
class CLIPModelTester:
|
||||||
def __init__(self, parent, is_training=True):
|
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||||
|
|
||||||
|
if text_kwargs is None:
|
||||||
|
text_kwargs = {}
|
||||||
|
if vision_kwargs is None:
|
||||||
|
vision_kwargs = {}
|
||||||
|
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.text_model_tester = CLIPTextModelTester(parent)
|
self.text_model_tester = CLIPTextModelTester(parent, **text_kwargs)
|
||||||
self.vision_model_tester = CLIPVisionModelTester(parent)
|
self.vision_model_tester = CLIPVisionModelTester(parent, **vision_kwargs)
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
|
|||||||
@@ -746,17 +746,31 @@ class FlavaModelTester:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
|
text_kwargs=None,
|
||||||
|
image_kwargs=None,
|
||||||
|
multimodal_kwargs=None,
|
||||||
|
image_codebook_kwargs=None,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
hidden_size=32,
|
hidden_size=32,
|
||||||
projection_dim=32,
|
projection_dim=32,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
layer_norm_eps=1e-12,
|
layer_norm_eps=1e-12,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
if text_kwargs is None:
|
||||||
|
text_kwargs = {}
|
||||||
|
if image_kwargs is None:
|
||||||
|
image_kwargs = {}
|
||||||
|
if multimodal_kwargs is None:
|
||||||
|
multimodal_kwargs = {}
|
||||||
|
if image_codebook_kwargs is None:
|
||||||
|
image_codebook_kwargs = {}
|
||||||
|
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.image_model_tester = FlavaImageModelTester(parent)
|
self.image_model_tester = FlavaImageModelTester(parent, **image_kwargs)
|
||||||
self.text_model_tester = FlavaTextModelTester(parent)
|
self.text_model_tester = FlavaTextModelTester(parent, **text_kwargs)
|
||||||
self.multimodal_model_tester = FlavaMultimodalModelTester(parent)
|
self.multimodal_model_tester = FlavaMultimodalModelTester(parent, **multimodal_kwargs)
|
||||||
self.image_codebook_tester = FlavaImageCodebookTester(parent)
|
self.image_codebook_tester = FlavaImageCodebookTester(parent, **image_codebook_kwargs)
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.config_tester = ConfigTester(self, config_class=FlavaConfig, hidden_size=37)
|
self.config_tester = ConfigTester(self, config_class=FlavaConfig, hidden_size=37)
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
|
|||||||
@@ -474,10 +474,16 @@ class GroupViTTextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class GroupViTModelTester:
|
class GroupViTModelTester:
|
||||||
def __init__(self, parent, is_training=True):
|
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||||
|
|
||||||
|
if text_kwargs is None:
|
||||||
|
text_kwargs = {}
|
||||||
|
if vision_kwargs is None:
|
||||||
|
vision_kwargs = {}
|
||||||
|
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.text_model_tester = GroupViTTextModelTester(parent)
|
self.text_model_tester = GroupViTTextModelTester(parent, **text_kwargs)
|
||||||
self.vision_model_tester = GroupViTVisionModelTester(parent)
|
self.vision_model_tester = GroupViTVisionModelTester(parent, **vision_kwargs)
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
|
|||||||
@@ -339,10 +339,16 @@ class OwlViTTextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class OwlViTModelTester:
|
class OwlViTModelTester:
|
||||||
def __init__(self, parent, is_training=True):
|
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||||
|
|
||||||
|
if text_kwargs is None:
|
||||||
|
text_kwargs = {}
|
||||||
|
if vision_kwargs is None:
|
||||||
|
vision_kwargs = {}
|
||||||
|
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.text_model_tester = OwlViTTextModelTester(parent)
|
self.text_model_tester = OwlViTTextModelTester(parent, **text_kwargs)
|
||||||
self.vision_model_tester = OwlViTVisionModelTester(parent)
|
self.vision_model_tester = OwlViTVisionModelTester(parent, **vision_kwargs)
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.text_config = self.text_model_tester.get_config().to_dict()
|
self.text_config = self.text_model_tester.get_config().to_dict()
|
||||||
self.vision_config = self.vision_model_tester.get_config().to_dict()
|
self.vision_config = self.vision_model_tester.get_config().to_dict()
|
||||||
|
|||||||
@@ -436,12 +436,26 @@ class XCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class XCLIPModelTester:
|
class XCLIPModelTester:
|
||||||
def __init__(self, parent, projection_dim=64, mit_hidden_size=64, is_training=True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
text_kwargs=None,
|
||||||
|
vision_kwargs=None,
|
||||||
|
projection_dim=64,
|
||||||
|
mit_hidden_size=64,
|
||||||
|
is_training=True,
|
||||||
|
):
|
||||||
|
|
||||||
|
if text_kwargs is None:
|
||||||
|
text_kwargs = {}
|
||||||
|
if vision_kwargs is None:
|
||||||
|
vision_kwargs = {}
|
||||||
|
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.projection_dim = projection_dim
|
self.projection_dim = projection_dim
|
||||||
self.mit_hidden_size = mit_hidden_size
|
self.mit_hidden_size = mit_hidden_size
|
||||||
self.text_model_tester = XCLIPTextModelTester(parent)
|
self.text_model_tester = XCLIPTextModelTester(parent, **text_kwargs)
|
||||||
self.vision_model_tester = XCLIPVisionModelTester(parent)
|
self.vision_model_tester = XCLIPVisionModelTester(parent, **vision_kwargs)
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user