Add AutoModelForZeroShotImageClassification (#22087)
Adds AutoModelForZeroShotImageClassification to transformers
This commit is contained in:
@@ -258,6 +258,14 @@ The following auto classes are available for the following computer vision tasks
|
|||||||
|
|
||||||
[[autodoc]] AutoModelForUniversalSegmentation
|
[[autodoc]] AutoModelForUniversalSegmentation
|
||||||
|
|
||||||
|
### AutoModelForZeroShotImageClassification
|
||||||
|
|
||||||
|
[[autodoc]] AutoModelForZeroShotImageClassification
|
||||||
|
|
||||||
|
### TFAutoModelForZeroShotImageClassification
|
||||||
|
|
||||||
|
[[autodoc]] TFAutoModelForZeroShotImageClassification
|
||||||
|
|
||||||
### AutoModelForZeroShotObjectDetection
|
### AutoModelForZeroShotObjectDetection
|
||||||
|
|
||||||
[[autodoc]] AutoModelForZeroShotObjectDetection
|
[[autodoc]] AutoModelForZeroShotObjectDetection
|
||||||
|
|||||||
@@ -1001,6 +1001,7 @@ else:
|
|||||||
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
|
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
|
||||||
"MODEL_FOR_VISION_2_SEQ_MAPPING",
|
"MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||||
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
||||||
|
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
|
||||||
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
|
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
|
||||||
"MODEL_MAPPING",
|
"MODEL_MAPPING",
|
||||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||||
@@ -1033,6 +1034,7 @@ else:
|
|||||||
"AutoModelForVideoClassification",
|
"AutoModelForVideoClassification",
|
||||||
"AutoModelForVision2Seq",
|
"AutoModelForVision2Seq",
|
||||||
"AutoModelForVisualQuestionAnswering",
|
"AutoModelForVisualQuestionAnswering",
|
||||||
|
"AutoModelForZeroShotImageClassification",
|
||||||
"AutoModelForZeroShotObjectDetection",
|
"AutoModelForZeroShotObjectDetection",
|
||||||
"AutoModelWithLMHead",
|
"AutoModelWithLMHead",
|
||||||
]
|
]
|
||||||
@@ -2785,6 +2787,7 @@ else:
|
|||||||
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||||
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||||
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||||
|
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
|
||||||
"TF_MODEL_MAPPING",
|
"TF_MODEL_MAPPING",
|
||||||
"TF_MODEL_WITH_LM_HEAD_MAPPING",
|
"TF_MODEL_WITH_LM_HEAD_MAPPING",
|
||||||
"TFAutoModel",
|
"TFAutoModel",
|
||||||
@@ -2803,6 +2806,7 @@ else:
|
|||||||
"TFAutoModelForTableQuestionAnswering",
|
"TFAutoModelForTableQuestionAnswering",
|
||||||
"TFAutoModelForTokenClassification",
|
"TFAutoModelForTokenClassification",
|
||||||
"TFAutoModelForVision2Seq",
|
"TFAutoModelForVision2Seq",
|
||||||
|
"TFAutoModelForZeroShotImageClassification",
|
||||||
"TFAutoModelWithLMHead",
|
"TFAutoModelWithLMHead",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -4514,6 +4518,7 @@ if TYPE_CHECKING:
|
|||||||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
||||||
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
@@ -4546,6 +4551,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForVideoClassification,
|
AutoModelForVideoClassification,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
AutoModelForVisualQuestionAnswering,
|
AutoModelForVisualQuestionAnswering,
|
||||||
|
AutoModelForZeroShotImageClassification,
|
||||||
AutoModelForZeroShotObjectDetection,
|
AutoModelForZeroShotObjectDetection,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
@@ -5971,6 +5977,7 @@ if TYPE_CHECKING:
|
|||||||
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
|
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_MAPPING,
|
TF_MODEL_MAPPING,
|
||||||
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
@@ -5989,6 +5996,7 @@ if TYPE_CHECKING:
|
|||||||
TFAutoModelForTableQuestionAnswering,
|
TFAutoModelForTableQuestionAnswering,
|
||||||
TFAutoModelForTokenClassification,
|
TFAutoModelForTokenClassification,
|
||||||
TFAutoModelForVision2Seq,
|
TFAutoModelForVision2Seq,
|
||||||
|
TFAutoModelForZeroShotImageClassification,
|
||||||
TFAutoModelWithLMHead,
|
TFAutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
from .models.bart import (
|
from .models.bart import (
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from .models.auto.modeling_auto import (
|
|||||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
||||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
)
|
)
|
||||||
from .training_args import ParallelMode
|
from .training_args import ParallelMode
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@@ -70,6 +71,7 @@ TASK_MAPPING = {
|
|||||||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||||
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||||
"automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
|
"automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
|
||||||
|
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ else:
|
|||||||
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
||||||
"MODEL_MAPPING",
|
"MODEL_MAPPING",
|
||||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||||
|
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
|
||||||
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
|
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
"AutoBackbone",
|
"AutoBackbone",
|
||||||
@@ -100,6 +101,7 @@ else:
|
|||||||
"AutoModelForVisualQuestionAnswering",
|
"AutoModelForVisualQuestionAnswering",
|
||||||
"AutoModelForDocumentQuestionAnswering",
|
"AutoModelForDocumentQuestionAnswering",
|
||||||
"AutoModelWithLMHead",
|
"AutoModelWithLMHead",
|
||||||
|
"AutoModelForZeroShotImageClassification",
|
||||||
"AutoModelForZeroShotObjectDetection",
|
"AutoModelForZeroShotObjectDetection",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -126,6 +128,7 @@ else:
|
|||||||
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||||
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||||
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||||
|
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
|
||||||
"TF_MODEL_MAPPING",
|
"TF_MODEL_MAPPING",
|
||||||
"TF_MODEL_WITH_LM_HEAD_MAPPING",
|
"TF_MODEL_WITH_LM_HEAD_MAPPING",
|
||||||
"TFAutoModel",
|
"TFAutoModel",
|
||||||
@@ -144,6 +147,7 @@ else:
|
|||||||
"TFAutoModelForTableQuestionAnswering",
|
"TFAutoModelForTableQuestionAnswering",
|
||||||
"TFAutoModelForTokenClassification",
|
"TFAutoModelForTokenClassification",
|
||||||
"TFAutoModelForVision2Seq",
|
"TFAutoModelForVision2Seq",
|
||||||
|
"TFAutoModelForZeroShotImageClassification",
|
||||||
"TFAutoModelWithLMHead",
|
"TFAutoModelWithLMHead",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -226,6 +230,7 @@ if TYPE_CHECKING:
|
|||||||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
||||||
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
@@ -258,6 +263,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForVideoClassification,
|
AutoModelForVideoClassification,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
AutoModelForVisualQuestionAnswering,
|
AutoModelForVisualQuestionAnswering,
|
||||||
|
AutoModelForZeroShotImageClassification,
|
||||||
AutoModelForZeroShotObjectDetection,
|
AutoModelForZeroShotObjectDetection,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
@@ -285,6 +291,7 @@ if TYPE_CHECKING:
|
|||||||
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
|
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_MAPPING,
|
TF_MODEL_MAPPING,
|
||||||
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
@@ -303,6 +310,7 @@ if TYPE_CHECKING:
|
|||||||
TFAutoModelForTableQuestionAnswering,
|
TFAutoModelForTableQuestionAnswering,
|
||||||
TFAutoModelForTokenClassification,
|
TFAutoModelForTokenClassification,
|
||||||
TFAutoModelForVision2Seq,
|
TFAutoModelForVision2Seq,
|
||||||
|
TFAutoModelForZeroShotImageClassification,
|
||||||
TFAutoModelWithLMHead,
|
TFAutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -920,7 +920,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Zero Shot Image Classification mapping
|
# Model for Zero Shot Image Classification mapping
|
||||||
("align", "AlignModel"),
|
("align", "AlignModel"),
|
||||||
@@ -955,6 +955,9 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
|
|||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||||
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||||
|
CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||||
|
)
|
||||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
||||||
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
@@ -1142,6 +1145,15 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
|
|||||||
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
|
||||||
|
_model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForZeroShotImageClassification = auto_class_update(
|
||||||
|
AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForImageSegmentation(_BaseAutoModelClass):
|
class AutoModelForImageSegmentation(_BaseAutoModelClass):
|
||||||
_model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
|
_model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
|
||||||
|
|
||||||
|
|||||||
@@ -209,6 +209,15 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||||
|
[
|
||||||
|
# Model for Zero Shot Image Classification mapping
|
||||||
|
("clip", "TFCLIPModel"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
|
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Semantic Segmentation mapping
|
# Model for Semantic Segmentation mapping
|
||||||
@@ -424,6 +433,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
|
|||||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
|
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||||
|
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||||
|
)
|
||||||
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
||||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
@@ -505,6 +517,15 @@ TFAutoModelForImageClassification = auto_class_update(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
|
||||||
|
_model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
TFAutoModelForZeroShotImageClassification = auto_class_update(
|
||||||
|
TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
|
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
|
||||||
_model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
_model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
||||||
|
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ if is_tf_available():
|
|||||||
TFAutoModelForTableQuestionAnswering,
|
TFAutoModelForTableQuestionAnswering,
|
||||||
TFAutoModelForTokenClassification,
|
TFAutoModelForTokenClassification,
|
||||||
TFAutoModelForVision2Seq,
|
TFAutoModelForVision2Seq,
|
||||||
|
TFAutoModelForZeroShotImageClassification,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -135,6 +136,7 @@ if is_torch_available():
|
|||||||
AutoModelForVideoClassification,
|
AutoModelForVideoClassification,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
AutoModelForVisualQuestionAnswering,
|
AutoModelForVisualQuestionAnswering,
|
||||||
|
AutoModelForZeroShotImageClassification,
|
||||||
AutoModelForZeroShotObjectDetection,
|
AutoModelForZeroShotObjectDetection,
|
||||||
)
|
)
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -290,8 +292,8 @@ SUPPORTED_TASKS = {
|
|||||||
},
|
},
|
||||||
"zero-shot-image-classification": {
|
"zero-shot-image-classification": {
|
||||||
"impl": ZeroShotImageClassificationPipeline,
|
"impl": ZeroShotImageClassificationPipeline,
|
||||||
"tf": (TFAutoModel,) if is_tf_available() else (),
|
"tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
|
||||||
"pt": (AutoModel,) if is_torch_available() else (),
|
"pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
|
||||||
"default": {
|
"default": {
|
||||||
"model": {
|
"model": {
|
||||||
"pt": ("openai/clip-vit-base-patch32", "f4881ba"),
|
"pt": ("openai/clip-vit-base-patch32", "f4881ba"),
|
||||||
|
|||||||
@@ -18,9 +18,10 @@ if is_vision_available():
|
|||||||
from ..image_utils import load_image
|
from ..image_utils import load_image
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
pass
|
from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
|
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||||
from ..tf_utils import stable_softmax
|
from ..tf_utils import stable_softmax
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -64,8 +65,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
requires_backends(self, "vision")
|
requires_backends(self, "vision")
|
||||||
# No specific FOR_XXX available yet
|
self.check_model_type(
|
||||||
# self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
|
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
if self.framework == "tf"
|
||||||
|
else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
|
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -137,9 +141,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
|||||||
if self.framework == "pt":
|
if self.framework == "pt":
|
||||||
probs = logits.softmax(dim=-1).squeeze(-1)
|
probs = logits.softmax(dim=-1).squeeze(-1)
|
||||||
scores = probs.tolist()
|
scores = probs.tolist()
|
||||||
else:
|
elif self.framework == "tf":
|
||||||
probs = stable_softmax(logits, axis=-1)
|
probs = stable_softmax(logits, axis=-1)
|
||||||
scores = probs.numpy().tolist()
|
scores = probs.numpy().tolist()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported framework: {self.framework}")
|
||||||
|
|
||||||
result = [
|
result = [
|
||||||
{"score": score, "label": candidate_label}
|
{"score": score, "label": candidate_label}
|
||||||
|
|||||||
@@ -526,6 +526,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None
|
|||||||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
@@ -738,6 +741,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForZeroShotImageClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForZeroShotObjectDetection(metaclass=DummyObject):
|
class AutoModelForZeroShotObjectDetection(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -316,6 +316,9 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
|||||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = None
|
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
TF_MODEL_MAPPING = None
|
TF_MODEL_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
@@ -434,6 +437,13 @@ class TFAutoModelForVision2Seq(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFAutoModelForZeroShotImageClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFAutoModelWithLMHead(metaclass=DummyObject):
|
class TFAutoModelWithLMHead(metaclass=DummyObject):
|
||||||
_backends = ["tf"]
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ from ..models.auto.modeling_auto import (
|
|||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
MODEL_MAPPING_NAMES,
|
MODEL_MAPPING_NAMES,
|
||||||
)
|
)
|
||||||
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
|
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
|
||||||
@@ -79,6 +80,7 @@ def _generate_supported_model_class_names(
|
|||||||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||||
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
|
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
|
||||||
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||||
"semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
|
"semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
|
||||||
|
|||||||
@@ -93,8 +93,8 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
|||||||
("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"),
|
("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"),
|
||||||
(
|
(
|
||||||
"zero-shot-image-classification",
|
"zero-shot-image-classification",
|
||||||
"_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
|
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
|
||||||
"AutoModel",
|
"AutoModelForZeroShotImageClassification",
|
||||||
),
|
),
|
||||||
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
|
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
|
||||||
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
|
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
|
||||||
|
|||||||
Reference in New Issue
Block a user