From 32e3466d38205d6c3b2264238bec371445e3d684 Mon Sep 17 00:00:00 2001 From: Alara Dirik <8944735+alaradirik@users.noreply.github.com> Date: Mon, 13 Mar 2023 12:46:14 +0300 Subject: [PATCH] Add AutoModelForZeroShotImageClassification (#22087) Adds AutoModelForZeroShotImageClassification to transformers --- docs/source/en/model_doc/auto.mdx | 8 +++++++ src/transformers/__init__.py | 8 +++++++ src/transformers/modelcard.py | 2 ++ src/transformers/models/auto/__init__.py | 8 +++++++ src/transformers/models/auto/modeling_auto.py | 14 ++++++++++++- .../models/auto/modeling_tf_auto.py | 21 +++++++++++++++++++ src/transformers/pipelines/__init__.py | 6 ++++-- .../zero_shot_image_classification.py | 14 +++++++++---- src/transformers/utils/dummy_pt_objects.py | 10 +++++++++ src/transformers/utils/dummy_tf_objects.py | 10 +++++++++ src/transformers/utils/fx.py | 2 ++ utils/update_metadata.py | 4 ++-- 12 files changed, 98 insertions(+), 9 deletions(-) diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx index 9df4fa9c99..39b0645eb5 100644 --- a/docs/source/en/model_doc/auto.mdx +++ b/docs/source/en/model_doc/auto.mdx @@ -258,6 +258,14 @@ The following auto classes are available for the following computer vision tasks [[autodoc]] AutoModelForUniversalSegmentation +### AutoModelForZeroShotImageClassification + +[[autodoc]] AutoModelForZeroShotImageClassification + +### TFAutoModelForZeroShotImageClassification + +[[autodoc]] TFAutoModelForZeroShotImageClassification + ### AutoModelForZeroShotObjectDetection [[autodoc]] AutoModelForZeroShotObjectDetection diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index fbd1182647..14963990e4 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1001,6 +1001,7 @@ else: "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", "MODEL_FOR_VISION_2_SEQ_MAPPING", "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", "MODEL_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING", @@ -1033,6 +1034,7 @@ else: "AutoModelForVideoClassification", "AutoModelForVision2Seq", "AutoModelForVisualQuestionAnswering", + "AutoModelForZeroShotImageClassification", "AutoModelForZeroShotObjectDetection", "AutoModelWithLMHead", ] @@ -2785,6 +2787,7 @@ else: "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_VISION_2_SEQ_MAPPING", + "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", "TF_MODEL_MAPPING", "TF_MODEL_WITH_LM_HEAD_MAPPING", "TFAutoModel", @@ -2803,6 +2806,7 @@ else: "TFAutoModelForTableQuestionAnswering", "TFAutoModelForTokenClassification", "TFAutoModelForVision2Seq", + "TFAutoModelForZeroShotImageClassification", "TFAutoModelWithLMHead", ] ) @@ -4514,6 +4518,7 @@ if TYPE_CHECKING: MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, @@ -4546,6 +4551,7 @@ if TYPE_CHECKING: AutoModelForVideoClassification, AutoModelForVision2Seq, AutoModelForVisualQuestionAnswering, + AutoModelForZeroShotImageClassification, AutoModelForZeroShotObjectDetection, AutoModelWithLMHead, ) @@ -5971,6 +5977,7 @@ if TYPE_CHECKING: TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, + TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING, TFAutoModel, @@ -5989,6 +5996,7 @@ if TYPE_CHECKING: TFAutoModelForTableQuestionAnswering, TFAutoModelForTokenClassification, TFAutoModelForVision2Seq, + TFAutoModelForZeroShotImageClassification, TFAutoModelWithLMHead, ) from .models.bart import ( diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index ac954272cd..e89216b0d8 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -43,6 +43,7 @@ from .models.auto.modeling_auto import ( MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, ) from .training_args import ParallelMode from .utils import ( @@ -70,6 +71,7 @@ TASK_MAPPING = { "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, "automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES}, + "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, } logger = logging.get_logger(__name__) diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 73965b657f..4eccfded5b 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -69,6 +69,7 @@ else: "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", "MODEL_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING", + "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", "AutoModel", "AutoBackbone", @@ -100,6 +101,7 @@ else: "AutoModelForVisualQuestionAnswering", "AutoModelForDocumentQuestionAnswering", "AutoModelWithLMHead", + "AutoModelForZeroShotImageClassification", "AutoModelForZeroShotObjectDetection", ] @@ -126,6 +128,7 @@ else: "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_VISION_2_SEQ_MAPPING", + "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", "TF_MODEL_MAPPING", "TF_MODEL_WITH_LM_HEAD_MAPPING", "TFAutoModel", @@ -144,6 +147,7 @@ else: "TFAutoModelForTableQuestionAnswering", "TFAutoModelForTokenClassification", "TFAutoModelForVision2Seq", + "TFAutoModelForZeroShotImageClassification", "TFAutoModelWithLMHead", ] @@ -226,6 +230,7 @@ if TYPE_CHECKING: MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, @@ -258,6 +263,7 @@ if TYPE_CHECKING: AutoModelForVideoClassification, AutoModelForVision2Seq, AutoModelForVisualQuestionAnswering, + AutoModelForZeroShotImageClassification, AutoModelForZeroShotObjectDetection, AutoModelWithLMHead, ) @@ -285,6 +291,7 @@ if TYPE_CHECKING: TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, + TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING, TFAutoModel, @@ -303,6 +310,7 @@ if TYPE_CHECKING: TFAutoModelForTableQuestionAnswering, TFAutoModelForTokenClassification, TFAutoModelForVision2Seq, + TFAutoModelForZeroShotImageClassification, TFAutoModelWithLMHead, ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 446ab8ec57..08fbd47645 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -920,7 +920,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict( ] ) -_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( +MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Zero Shot Image Classification mapping ("align", "AlignModel"), @@ -955,6 +955,9 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping( MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES ) +MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES +) MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES ) @@ -1142,6 +1145,15 @@ class AutoModelForImageClassification(_BaseAutoModelClass): AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification") +class AutoModelForZeroShotImageClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING + + +AutoModelForZeroShotImageClassification = auto_class_update( + AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification" +) + + class AutoModelForImageSegmentation(_BaseAutoModelClass): _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 4d48e6181e..caf5ba71dc 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -209,6 +209,15 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ] ) + +TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Zero Shot Image Classification mapping + ("clip", "TFCLIPModel"), + ] +) + + TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict( [ # Model for Semantic Segmentation mapping @@ -424,6 +433,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES ) +TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES +) TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES ) @@ -505,6 +517,15 @@ TFAutoModelForImageClassification = auto_class_update( ) +class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING + + +TFAutoModelForZeroShotImageClassification = auto_class_update( + TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification" +) + + class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass): _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 0f83cb0dea..c8c0549a46 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -103,6 +103,7 @@ if is_tf_available(): TFAutoModelForTableQuestionAnswering, TFAutoModelForTokenClassification, TFAutoModelForVision2Seq, + TFAutoModelForZeroShotImageClassification, ) if is_torch_available(): @@ -135,6 +136,7 @@ if is_torch_available(): AutoModelForVideoClassification, AutoModelForVision2Seq, AutoModelForVisualQuestionAnswering, + AutoModelForZeroShotImageClassification, AutoModelForZeroShotObjectDetection, ) if TYPE_CHECKING: @@ -290,8 +292,8 @@ SUPPORTED_TASKS = { }, "zero-shot-image-classification": { "impl": ZeroShotImageClassificationPipeline, - "tf": (TFAutoModel,) if is_tf_available() else (), - "pt": (AutoModel,) if is_torch_available() else (), + "tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (), + "pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (), "default": { "model": { "pt": ("openai/clip-vit-base-patch32", "f4881ba"), diff --git a/src/transformers/pipelines/zero_shot_image_classification.py b/src/transformers/pipelines/zero_shot_image_classification.py index f19a548c85..8ba07eb018 100644 --- a/src/transformers/pipelines/zero_shot_image_classification.py +++ b/src/transformers/pipelines/zero_shot_image_classification.py @@ -18,9 +18,10 @@ if is_vision_available(): from ..image_utils import load_image if is_torch_available(): - pass + from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING if is_tf_available(): + from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING from ..tf_utils import stable_softmax logger = logging.get_logger(__name__) @@ -64,8 +65,11 @@ class ZeroShotImageClassificationPipeline(Pipeline): super().__init__(**kwargs) requires_backends(self, "vision") - # No specific FOR_XXX available yet - # self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING) + self.check_model_type( + TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING + if self.framework == "tf" + else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING + ) def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs): """ @@ -137,9 +141,11 @@ class ZeroShotImageClassificationPipeline(Pipeline): if self.framework == "pt": probs = logits.softmax(dim=-1).squeeze(-1) scores = probs.tolist() - else: + elif self.framework == "tf": probs = stable_softmax(logits, axis=-1) scores = probs.numpy().tolist() + else: + raise ValueError(f"Unsupported framework: {self.framework}") result = [ {"score": score, "label": candidate_label} diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 85b4010f38..62623b4066 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -526,6 +526,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None +MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = None + + MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None @@ -738,6 +741,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject): requires_backends(self, ["torch"]) +class AutoModelForZeroShotImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class AutoModelForZeroShotObjectDetection(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 16d2e6820c..55eb6599f1 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -316,6 +316,9 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None TF_MODEL_FOR_VISION_2_SEQ_MAPPING = None +TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = None + + TF_MODEL_MAPPING = None @@ -434,6 +437,13 @@ class TFAutoModelForVision2Seq(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFAutoModelForZeroShotImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFAutoModelWithLMHead(metaclass=DummyObject): _backends = ["tf"] diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index da9c43b171..9da544cb45 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -50,6 +50,7 @@ from ..models.auto.modeling_auto import ( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_MAPPING_NAMES, ) from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available @@ -79,6 +80,7 @@ def _generate_supported_model_class_names( "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES, "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, + "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, "ctc": MODEL_FOR_CTC_MAPPING_NAMES, "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, "semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, diff --git a/utils/update_metadata.py b/utils/update_metadata.py index f95a4575d1..8c34bba5d6 100644 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -93,8 +93,8 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [ ("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"), ( "zero-shot-image-classification", - "_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES", - "AutoModel", + "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES", + "AutoModelForZeroShotImageClassification", ), ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"), ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),