Fix and document Zero Shot Image Classification (#16079)

This commit is contained in:
Omar Sanseviero
2022-03-14 08:50:36 +01:00
committed by GitHub
parent 6e1e88fd38
commit 802984ad42
3 changed files with 4 additions and 2 deletions

View File

@@ -39,6 +39,7 @@ There are two categories of pipeline abstractions to be aware about:
- [`TokenClassificationPipeline`] - [`TokenClassificationPipeline`]
- [`TranslationPipeline`] - [`TranslationPipeline`]
- [`ZeroShotClassificationPipeline`] - [`ZeroShotClassificationPipeline`]
- [`ZeroShotImageClassificationPipeline`]
## The pipeline abstraction ## The pipeline abstraction

View File

@@ -245,7 +245,7 @@ SUPPORTED_TASKS = {
"impl": ZeroShotImageClassificationPipeline, "impl": ZeroShotImageClassificationPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (), "tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (), "pt": (AutoModel,) if is_torch_available() else (),
"default": {"pt": "openai/clip-vit-base-patch32", "tf": "openai/clip-vit-base-patch32"}, "default": {"model": {"pt": "openai/clip-vit-base-patch32", "tf": "openai/clip-vit-base-patch32"}},
"type": "multimodal", "type": "multimodal",
}, },
"conversational": { "conversational": {
@@ -346,6 +346,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
- `"translation_xx_to_yy"` - `"translation_xx_to_yy"`
- `"summarization"` - `"summarization"`
- `"zero-shot-classification"` - `"zero-shot-classification"`
- `"zero-shot-image-classification"`
Returns: Returns:
(task_defaults`dict`, task_options: (`tuple`, None)) The actual dictionary required to initialize the pipeline (task_defaults`dict`, task_options: (`tuple`, None)) The actual dictionary required to initialize the pipeline

View File

@@ -35,7 +35,7 @@ class ZeroShotImageClassificationPipeline(ChunkPipeline):
`"zero-shot-image-classification"`. `"zero-shot-image-classification"`.
See the list of available models on See the list of available models on
[huggingface.co/models](https://huggingface.co/models?filter=zer-shot-image-classification). [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-image-classification).
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):