Add AutoModelForZeroShotImageClassification (#22087)

Adds AutoModelForZeroShotImageClassification to transformers
This commit is contained in:
Alara Dirik
2023-03-13 12:46:14 +03:00
committed by GitHub
parent b90fbc7e0b
commit 32e3466d38
12 changed files with 98 additions and 9 deletions

View File

@@ -103,6 +103,7 @@ if is_tf_available():
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelForZeroShotImageClassification,
)
if is_torch_available():
@@ -135,6 +136,7 @@ if is_torch_available():
AutoModelForVideoClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
AutoModelForZeroShotImageClassification,
AutoModelForZeroShotObjectDetection,
)
if TYPE_CHECKING:
@@ -290,8 +292,8 @@ SUPPORTED_TASKS = {
},
"zero-shot-image-classification": {
"impl": ZeroShotImageClassificationPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (),
"tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
"pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
"default": {
"model": {
"pt": ("openai/clip-vit-base-patch32", "f4881ba"),

View File

@@ -18,9 +18,10 @@ if is_vision_available():
from ..image_utils import load_image
if is_torch_available():
pass
from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
from ..tf_utils import stable_softmax
logger = logging.get_logger(__name__)
@@ -64,8 +65,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
super().__init__(**kwargs)
requires_backends(self, "vision")
# No specific FOR_XXX available yet
# self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
self.check_model_type(
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
if self.framework == "tf"
else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
)
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
"""
@@ -137,9 +141,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
if self.framework == "pt":
probs = logits.softmax(dim=-1).squeeze(-1)
scores = probs.tolist()
else:
elif self.framework == "tf":
probs = stable_softmax(logits, axis=-1)
scores = probs.numpy().tolist()
else:
raise ValueError(f"Unsupported framework: {self.framework}")
result = [
{"score": score, "label": candidate_label}