From 802984ad42cfb368080904c3a751f62c92aab8eb Mon Sep 17 00:00:00 2001 From: Omar Sanseviero Date: Mon, 14 Mar 2022 08:50:36 +0100 Subject: [PATCH] Fix and document Zero Shot Image Classification (#16079) --- docs/source/main_classes/pipelines.mdx | 1 + src/transformers/pipelines/__init__.py | 3 ++- src/transformers/pipelines/zero_shot_image_classification.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/main_classes/pipelines.mdx b/docs/source/main_classes/pipelines.mdx index b5c51229ca..af82d16750 100644 --- a/docs/source/main_classes/pipelines.mdx +++ b/docs/source/main_classes/pipelines.mdx @@ -39,6 +39,7 @@ There are two categories of pipeline abstractions to be aware about: - [`TokenClassificationPipeline`] - [`TranslationPipeline`] - [`ZeroShotClassificationPipeline`] + - [`ZeroShotImageClassificationPipeline`] ## The pipeline abstraction diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index c43627e3ac..94d422f3ab 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -245,7 +245,7 @@ SUPPORTED_TASKS = { "impl": ZeroShotImageClassificationPipeline, "tf": (TFAutoModel,) if is_tf_available() else (), "pt": (AutoModel,) if is_torch_available() else (), - "default": {"pt": "openai/clip-vit-base-patch32", "tf": "openai/clip-vit-base-patch32"}, + "default": {"model": {"pt": "openai/clip-vit-base-patch32", "tf": "openai/clip-vit-base-patch32"}}, "type": "multimodal", }, "conversational": { @@ -346,6 +346,7 @@ def check_task(task: str) -> Tuple[Dict, Any]: - `"translation_xx_to_yy"` - `"summarization"` - `"zero-shot-classification"` + - `"zero-shot-image-classification"` Returns: (task_defaults`dict`, task_options: (`tuple`, None)) The actual dictionary required to initialize the pipeline diff --git a/src/transformers/pipelines/zero_shot_image_classification.py b/src/transformers/pipelines/zero_shot_image_classification.py index fb4036a9fa..859d942b23 100644 --- a/src/transformers/pipelines/zero_shot_image_classification.py +++ b/src/transformers/pipelines/zero_shot_image_classification.py @@ -35,7 +35,7 @@ class ZeroShotImageClassificationPipeline(ChunkPipeline): `"zero-shot-image-classification"`. See the list of available models on - [huggingface.co/models](https://huggingface.co/models?filter=zer-shot-image-classification). + [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-image-classification). """ def __init__(self, **kwargs):