Make ASR pipeline compliant with Hub spec + add tests (#33769)
* Remove max_new_tokens arg * Add ASR pipeline to testing * make fixup * Factor the output test out into a util * Full error reporting * Full error reporting * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Lysandre Debut <hi@lysand.re> * Small comment --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from dataclasses import fields
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import AudioClassificationOutputElement
|
||||
@@ -21,6 +20,7 @@ from huggingface_hub import AudioClassificationOutputElement
|
||||
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||
from transformers.pipelines import AudioClassificationPipeline, pipeline
|
||||
from transformers.testing_utils import (
|
||||
compare_pipeline_output_to_hub_spec,
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
@@ -68,10 +68,8 @@ class AudioClassificationPipelineTests(unittest.TestCase):
|
||||
|
||||
self.run_torchaudio(audio_classifier)
|
||||
|
||||
spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)}
|
||||
for single_output in output:
|
||||
output_keys = set(single_output.keys())
|
||||
self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!")
|
||||
compare_pipeline_output_to_hub_spec(single_output, AudioClassificationOutputElement)
|
||||
|
||||
@require_torchaudio
|
||||
def run_torchaudio(self, audio_classifier):
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datasets import Audio, load_dataset
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from huggingface_hub import AutomaticSpeechRecognitionOutput, hf_hub_download, snapshot_download
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_CTC_MAPPING,
|
||||
@@ -36,6 +36,7 @@ from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
||||
from transformers.pipelines.audio_utils import chunk_bytes_iter
|
||||
from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
|
||||
from transformers.testing_utils import (
|
||||
compare_pipeline_output_to_hub_spec,
|
||||
is_pipeline_test,
|
||||
is_torch_available,
|
||||
nested_simplify,
|
||||
@@ -86,6 +87,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
outputs = speech_recognizer(audio)
|
||||
self.assertEqual(outputs, {"text": ANY(str)})
|
||||
|
||||
compare_pipeline_output_to_hub_spec(outputs, AutomaticSpeechRecognitionOutput)
|
||||
|
||||
# Striding
|
||||
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
|
||||
if speech_recognizer.type == "ctc":
|
||||
|
||||
@@ -25,9 +25,9 @@ from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import get_args
|
||||
|
||||
from huggingface_hub import AudioClassificationInput
|
||||
from huggingface_hub import AudioClassificationInput, AutomaticSpeechRecognitionInput
|
||||
|
||||
from transformers.pipelines import AudioClassificationPipeline
|
||||
from transformers.pipelines import AudioClassificationPipeline, AutomaticSpeechRecognitionPipeline
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
require_decord,
|
||||
@@ -104,6 +104,7 @@ task_to_pipeline_and_spec_mapping = {
|
||||
# Adding a task to this list will cause its pipeline input signature to be checked against the corresponding
|
||||
# task spec in the HF Hub
|
||||
"audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
|
||||
"automatic-speech-recognition": (AutomaticSpeechRecognitionPipeline, AutomaticSpeechRecognitionInput),
|
||||
}
|
||||
|
||||
for task, task_info in pipeline_test_mapping.items():
|
||||
|
||||
Reference in New Issue
Block a user