[CLIP] allow loading projection layer in vision and text model (#18962)

* allow loading projection in text and vision model

* begin tests

* finish test for CLIPTextModelTest

* style

* add slow tests

* add new classes for projection heads

* remove with_projection

* add in init

* add in doc

* fix tests

* fix some more tests

* fix copies

* fix docs

* remove leftover from fix-copies

* add the head models in IGNORE_NON_AUTO_CONFIGURED

* fix docstr

* fix tests

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add docstr for models

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil
2022-11-15 17:50:07 +01:00
committed by GitHub
parent 9643ecf8ca
commit 7f74433814
10 changed files with 347 additions and 7 deletions

View File

@@ -49,7 +49,13 @@ if is_torch_available():
import torch
from torch import nn
from transformers import CLIPModel, CLIPTextModel, CLIPVisionModel
from transformers import (
CLIPModel,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPVisionModel,
CLIPVisionModelWithProjection,
)
from transformers.models.clip.modeling_clip import CLIP_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -77,6 +83,7 @@ class CLIPVisionModelTester:
num_channels=3,
is_training=True,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
@@ -92,6 +99,7 @@ class CLIPVisionModelTester:
self.num_channels = num_channels
self.is_training = is_training
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
@@ -116,6 +124,7 @@ class CLIPVisionModelTester:
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
@@ -137,6 +146,19 @@ class CLIPVisionModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_model_with_projection(self, config, pixel_values):
model = CLIPVisionModelWithProjection(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(result.image_embeds.shape, (self.batch_size, self.projection_dim))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
@@ -151,7 +173,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
attention_mask and seq_length.
"""
all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
all_model_classes = (CLIPVisionModel, CLIPVisionModelWithProjection) if is_torch_available() else ()
fx_compatible = True
test_pruning = False
test_resize_embeddings = False
@@ -193,6 +215,10 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_projection(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_projection(*config_and_inputs)
def test_training(self):
pass
@@ -213,6 +239,13 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
model = CLIPVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@slow
def test_model_with_projection_from_pretrained(self):
for model_name in CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = CLIPVisionModelWithProjection.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertTrue(hasattr(model, "visual_projection"))
class CLIPTextModelTester:
def __init__(
@@ -225,6 +258,7 @@ class CLIPTextModelTester:
use_labels=True,
vocab_size=99,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
@@ -242,6 +276,7 @@ class CLIPTextModelTester:
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
@@ -273,6 +308,7 @@ class CLIPTextModelTester:
return CLIPTextConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
@@ -292,6 +328,16 @@ class CLIPTextModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_model_with_projection(self, config, input_ids, input_mask):
model = CLIPTextModelWithProjection(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.text_embeds.shape, (self.batch_size, self.projection_dim))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs
@@ -302,7 +348,7 @@ class CLIPTextModelTester:
@require_torch
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (CLIPTextModel,) if is_torch_available() else ()
all_model_classes = (CLIPTextModel, CLIPTextModelWithProjection) if is_torch_available() else ()
fx_compatible = True
test_pruning = False
test_head_masking = False
@@ -318,6 +364,10 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_projection(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_projection(*config_and_inputs)
def test_training(self):
pass
@@ -342,6 +392,13 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
model = CLIPTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@slow
def test_model_with_projection_from_pretrained(self):
for model_name in CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = CLIPTextModelWithProjection.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertTrue(hasattr(model, "text_projection"))
class CLIPModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):