Sync video classification pipeline with huggingface_hub spec (#34288)

* Sync video classification pipeline

* Add disclaimer
This commit is contained in:
Matt
2024-10-22 13:33:49 +01:00
committed by GitHub
parent 93352e81f5
commit 681fc43713
3 changed files with 61 additions and 7 deletions

View File

@@ -14,11 +14,12 @@
import unittest
from huggingface_hub import hf_hub_download
from huggingface_hub import VideoClassificationOutputElement, hf_hub_download
from transformers import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, VideoMAEFeatureExtractor
from transformers.pipelines import VideoClassificationPipeline, pipeline
from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test,
nested_simplify,
require_av,
@@ -76,6 +77,8 @@ class VideoClassificationPipelineTests(unittest.TestCase):
{"score": ANY(float), "label": ANY(str)},
],
)
for element in outputs:
compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement)
@require_torch
def test_small_model_pt(self):
@@ -93,6 +96,9 @@ class VideoClassificationPipelineTests(unittest.TestCase):
nested_simplify(outputs, decimals=4),
[{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
)
for output in outputs:
for element in output:
compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement)
outputs = video_classifier(
[
@@ -108,6 +114,9 @@ class VideoClassificationPipelineTests(unittest.TestCase):
[{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
],
)
for output in outputs:
for element in output:
compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement)
@require_tf
@unittest.skip

View File

@@ -34,6 +34,7 @@ from huggingface_hub import (
ImageToTextInput,
ObjectDetectionInput,
QuestionAnsweringInput,
VideoClassificationInput,
ZeroShotImageClassificationInput,
)
@@ -47,6 +48,7 @@ from transformers.pipelines import (
ImageToTextPipeline,
ObjectDetectionPipeline,
QuestionAnsweringPipeline,
VideoClassificationPipeline,
ZeroShotImageClassificationPipeline,
)
from transformers.testing_utils import (
@@ -132,6 +134,7 @@ task_to_pipeline_and_spec_mapping = {
"image-to-text": (ImageToTextPipeline, ImageToTextInput),
"object-detection": (ObjectDetectionPipeline, ObjectDetectionInput),
"question-answering": (QuestionAnsweringPipeline, QuestionAnsweringInput),
"video-classification": (VideoClassificationPipeline, VideoClassificationInput),
"zero-shot-image-classification": (ZeroShotImageClassificationPipeline, ZeroShotImageClassificationInput),
}