[CLIP] allow loading projection layer in vision and text model (#18962)
* allow loading projection in text and vision model * begin tests * finish test for CLIPTextModelTest * style * add slow tests * add new classes for projection heads * remove with_projection * add in init * add in doc * fix tests * fix some more tests * fix copies * fix docs * remove leftover from fix-copies * add the head models in IGNORE_NON_AUTO_CONFIGURED * fix docstr * fix tests * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * add docstr for models Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -49,7 +49,13 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import CLIPModel, CLIPTextModel, CLIPVisionModel
|
||||
from transformers import (
|
||||
CLIPModel,
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPVisionModel,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
from transformers.models.clip.modeling_clip import CLIP_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@@ -77,6 +83,7 @@ class CLIPVisionModelTester:
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
projection_dim=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
@@ -92,6 +99,7 @@ class CLIPVisionModelTester:
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.hidden_size = hidden_size
|
||||
self.projection_dim = projection_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
@@ -116,6 +124,7 @@ class CLIPVisionModelTester:
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
projection_dim=self.projection_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
@@ -137,6 +146,19 @@ class CLIPVisionModelTester:
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_model_with_projection(self, config, pixel_values):
|
||||
model = CLIPVisionModelWithProjection(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(pixel_values)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = (self.image_size, self.image_size)
|
||||
patch_size = (self.patch_size, self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
self.parent.assertEqual(result.image_embeds.shape, (self.batch_size, self.projection_dim))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
@@ -151,7 +173,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
|
||||
all_model_classes = (CLIPVisionModel, CLIPVisionModelWithProjection) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
@@ -193,6 +215,10 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_with_projection(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_projection(*config_and_inputs)
|
||||
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@@ -213,6 +239,13 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = CLIPVisionModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_model_with_projection_from_pretrained(self):
|
||||
for model_name in CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = CLIPVisionModelWithProjection.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertTrue(hasattr(model, "visual_projection"))
|
||||
|
||||
|
||||
class CLIPTextModelTester:
|
||||
def __init__(
|
||||
@@ -225,6 +258,7 @@ class CLIPTextModelTester:
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
projection_dim=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
@@ -242,6 +276,7 @@ class CLIPTextModelTester:
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.projection_dim = projection_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
@@ -273,6 +308,7 @@ class CLIPTextModelTester:
|
||||
return CLIPTextConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
projection_dim=self.projection_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
@@ -292,6 +328,16 @@ class CLIPTextModelTester:
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_model_with_projection(self, config, input_ids, input_mask):
|
||||
model = CLIPTextModelWithProjection(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result.text_embeds.shape, (self.batch_size, self.projection_dim))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, input_mask = config_and_inputs
|
||||
@@ -302,7 +348,7 @@ class CLIPTextModelTester:
|
||||
@require_torch
|
||||
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (CLIPTextModel,) if is_torch_available() else ()
|
||||
all_model_classes = (CLIPTextModel, CLIPTextModelWithProjection) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
@@ -318,6 +364,10 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_with_projection(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_projection(*config_and_inputs)
|
||||
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@@ -342,6 +392,13 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = CLIPTextModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_model_with_projection_from_pretrained(self):
|
||||
for model_name in CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = CLIPTextModelWithProjection.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertTrue(hasattr(model, "text_projection"))
|
||||
|
||||
|
||||
class CLIPModelTester:
|
||||
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||
|
||||
Reference in New Issue
Block a user