Add AutoModelForZeroShotImageClassification (#22087)
Adds AutoModelForZeroShotImageClassification to transformers
This commit is contained in:
@@ -103,6 +103,7 @@ if is_tf_available():
|
||||
TFAutoModelForTableQuestionAnswering,
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelForVision2Seq,
|
||||
TFAutoModelForZeroShotImageClassification,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
@@ -135,6 +136,7 @@ if is_torch_available():
|
||||
AutoModelForVideoClassification,
|
||||
AutoModelForVision2Seq,
|
||||
AutoModelForVisualQuestionAnswering,
|
||||
AutoModelForZeroShotImageClassification,
|
||||
AutoModelForZeroShotObjectDetection,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
@@ -290,8 +292,8 @@ SUPPORTED_TASKS = {
|
||||
},
|
||||
"zero-shot-image-classification": {
|
||||
"impl": ZeroShotImageClassificationPipeline,
|
||||
"tf": (TFAutoModel,) if is_tf_available() else (),
|
||||
"pt": (AutoModel,) if is_torch_available() else (),
|
||||
"tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": ("openai/clip-vit-base-patch32", "f4881ba"),
|
||||
|
||||
@@ -18,9 +18,10 @@ if is_vision_available():
|
||||
from ..image_utils import load_image
|
||||
|
||||
if is_torch_available():
|
||||
pass
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
if is_tf_available():
|
||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||
from ..tf_utils import stable_softmax
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -64,8 +65,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
requires_backends(self, "vision")
|
||||
# No specific FOR_XXX available yet
|
||||
# self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
|
||||
self.check_model_type(
|
||||
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||
if self.framework == "tf"
|
||||
else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
|
||||
)
|
||||
|
||||
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
|
||||
"""
|
||||
@@ -137,9 +141,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
||||
if self.framework == "pt":
|
||||
probs = logits.softmax(dim=-1).squeeze(-1)
|
||||
scores = probs.tolist()
|
||||
else:
|
||||
elif self.framework == "tf":
|
||||
probs = stable_softmax(logits, axis=-1)
|
||||
scores = probs.numpy().tolist()
|
||||
else:
|
||||
raise ValueError(f"Unsupported framework: {self.framework}")
|
||||
|
||||
result = [
|
||||
{"score": score, "label": candidate_label}
|
||||
|
||||
Reference in New Issue
Block a user