Sync video classification pipeline with huggingface_hub spec (#34288)
* Sync video classification pipeline * Add disclaimer
This commit is contained in:
@@ -14,11 +14,12 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub import VideoClassificationOutputElement, hf_hub_download
|
||||
|
||||
from transformers import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, VideoMAEFeatureExtractor
|
||||
from transformers.pipelines import VideoClassificationPipeline, pipeline
|
||||
from transformers.testing_utils import (
|
||||
compare_pipeline_output_to_hub_spec,
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_av,
|
||||
@@ -76,6 +77,8 @@ class VideoClassificationPipelineTests(unittest.TestCase):
|
||||
{"score": ANY(float), "label": ANY(str)},
|
||||
],
|
||||
)
|
||||
for element in outputs:
|
||||
compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
@@ -93,6 +96,9 @@ class VideoClassificationPipelineTests(unittest.TestCase):
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
|
||||
)
|
||||
for output in outputs:
|
||||
for element in output:
|
||||
compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement)
|
||||
|
||||
outputs = video_classifier(
|
||||
[
|
||||
@@ -108,6 +114,9 @@ class VideoClassificationPipelineTests(unittest.TestCase):
|
||||
[{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
|
||||
],
|
||||
)
|
||||
for output in outputs:
|
||||
for element in output:
|
||||
compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement)
|
||||
|
||||
@require_tf
|
||||
@unittest.skip
|
||||
|
||||
@@ -34,6 +34,7 @@ from huggingface_hub import (
|
||||
ImageToTextInput,
|
||||
ObjectDetectionInput,
|
||||
QuestionAnsweringInput,
|
||||
VideoClassificationInput,
|
||||
ZeroShotImageClassificationInput,
|
||||
)
|
||||
|
||||
@@ -47,6 +48,7 @@ from transformers.pipelines import (
|
||||
ImageToTextPipeline,
|
||||
ObjectDetectionPipeline,
|
||||
QuestionAnsweringPipeline,
|
||||
VideoClassificationPipeline,
|
||||
ZeroShotImageClassificationPipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
@@ -132,6 +134,7 @@ task_to_pipeline_and_spec_mapping = {
|
||||
"image-to-text": (ImageToTextPipeline, ImageToTextInput),
|
||||
"object-detection": (ObjectDetectionPipeline, ObjectDetectionInput),
|
||||
"question-answering": (QuestionAnsweringPipeline, QuestionAnsweringInput),
|
||||
"video-classification": (VideoClassificationPipeline, VideoClassificationInput),
|
||||
"zero-shot-image-classification": (ZeroShotImageClassificationPipeline, ZeroShotImageClassificationInput),
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user