Add SiglipForImageClassification and CLIPForImageClassification (#28952)
* First draft * Add CLIPForImageClassification * Remove scripts * Fix doctests
This commit is contained in:
@@ -51,6 +51,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
CLIPForImageClassification,
|
||||
CLIPModel,
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
@@ -744,6 +745,65 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class CLIPForImageClassificationModelTester(CLIPModelTester):
    """Model tester that feeds `CLIPForImageClassification` with image inputs only.

    Initialises like the parent tester, then copies a handful of size
    attributes from the vision sub-tester onto this object.
    """

    # Attributes copied verbatim from `self.vision_model_tester` in `__init__`.
    _MIRRORED_VISION_ATTRIBUTES = ("batch_size", "num_hidden_layers", "hidden_size", "seq_length")

    def __init__(self, parent):
        super().__init__(parent)
        for attribute in self._MIRRORED_VISION_ATTRIBUTES:
            setattr(self, attribute, getattr(self.vision_model_tester, attribute))

    def prepare_config_and_inputs(self):
        """Return `(config, pixel_values)`.

        The pixel values come from the vision sub-tester (its own config is
        discarded); the config comes from `self.get_config()`.
        """
        # Build the vision inputs first, then the config, as two separate steps.
        vision_outputs = self.vision_model_tester.prepare_config_and_inputs()
        _, pixel_values = vision_outputs
        return self.get_config(), pixel_values

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where the dict holds only `"pixel_values"`."""
        config, pixel_values = self.prepare_config_and_inputs()
        return config, {"pixel_values": pixel_values}
|
||||
|
||||
|
||||
# Test case for `CLIPForImageClassification`. No test bodies are defined here: the
# class only sets configuration attributes and skips selected tests, so whatever
# runs is inherited from ModelTesterMixin / PipelineTesterMixin.
@require_torch
class CLIPForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    # The conditionals keep the class body importable when torch is missing: the
    # model class name is only referenced when `is_torch_available()` is true.
    all_model_classes = (CLIPForImageClassification,) if is_torch_available() else ()
    pipeline_model_mapping = {"image-classification": CLIPForImageClassification} if is_torch_available() else {}
    # Feature flags, all turned off for this model. Presumably read by the shared
    # mixin tests to disable the matching checks — confirm against ModelTesterMixin.
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
        # Tester object that supplies the config and `pixel_values` inputs.
        self.model_tester = CLIPForImageClassificationModelTester(self)

    # The stubs below override same-named tests (presumably inherited from the
    # mixins) purely to skip them; the bodies are intentionally empty.

    @unittest.skip(reason="CLIPForImageClassification does not support inputs_embeds")
    def test_inputs_embeds(self):
        pass

    # NOTE(review): this reason string repeats the inputs_embeds wording used for
    # `test_inputs_embeds` above — confirm it describes why this test is skipped.
    @unittest.skip(reason="CLIPForImageClassification does not support inputs_embeds")
    def test_model_common_attributes(self):
        pass

    @unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="CLIP uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Siglip model. """
|
||||
""" Testing suite for the PyTorch SigLIP model. """
|
||||
|
||||
|
||||
import inspect
|
||||
@@ -47,7 +47,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import SiglipModel, SiglipTextModel, SiglipVisionModel
|
||||
from transformers import SiglipForImageClassification, SiglipModel, SiglipTextModel, SiglipVisionModel
|
||||
from transformers.models.siglip.modeling_siglip import SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@@ -584,6 +584,65 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class SiglipForImageClassificationModelTester(SiglipModelTester):
    """Model tester that feeds `SiglipForImageClassification` with image inputs only.

    Initialises like the parent tester, then copies a handful of size
    attributes from the vision sub-tester onto this object.
    """

    # Attributes copied verbatim from `self.vision_model_tester` in `__init__`.
    _MIRRORED_VISION_ATTRIBUTES = ("batch_size", "num_hidden_layers", "hidden_size", "seq_length")

    def __init__(self, parent):
        super().__init__(parent)
        for attribute in self._MIRRORED_VISION_ATTRIBUTES:
            setattr(self, attribute, getattr(self.vision_model_tester, attribute))

    def prepare_config_and_inputs(self):
        """Return `(config, pixel_values)`.

        The pixel values come from the vision sub-tester (its own config is
        discarded); the config comes from `self.get_config()`.
        """
        # Build the vision inputs first, then the config, as two separate steps.
        vision_outputs = self.vision_model_tester.prepare_config_and_inputs()
        _, pixel_values = vision_outputs
        return self.get_config(), pixel_values

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where the dict holds only `"pixel_values"`."""
        config, pixel_values = self.prepare_config_and_inputs()
        return config, {"pixel_values": pixel_values}
|
||||
|
||||
|
||||
# Test case for `SiglipForImageClassification`. No test bodies are defined here:
# the class only sets configuration attributes and skips selected tests, so
# whatever runs is inherited from ModelTesterMixin / PipelineTesterMixin.
@require_torch
class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    # The conditionals keep the class body importable when torch is missing: the
    # model class name is only referenced when `is_torch_available()` is true.
    all_model_classes = (SiglipForImageClassification,) if is_torch_available() else ()
    pipeline_model_mapping = {"image-classification": SiglipForImageClassification} if is_torch_available() else {}
    # Feature flags, all turned off for this model. Presumably read by the shared
    # mixin tests to disable the matching checks — confirm against ModelTesterMixin.
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
        # Tester object that supplies the config and `pixel_values` inputs.
        self.model_tester = SiglipForImageClassificationModelTester(self)

    # The stubs below override same-named tests (presumably inherited from the
    # mixins) purely to skip them; the bodies are intentionally empty.

    @unittest.skip(reason="SiglipForImageClassification does not support inputs_embeds")
    def test_inputs_embeds(self):
        pass

    # NOTE(review): this reason string repeats the inputs_embeds wording used for
    # `test_inputs_embeds` above — confirm it describes why this test is skipped.
    @unittest.skip(reason="SiglipForImageClassification does not support inputs_embeds")
    def test_model_common_attributes(self):
        pass

    @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="SiglipForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
|
||||
Reference in New Issue
Block a user