Make ASR pipeline compliant with Hub spec + add tests (#33769)

* Remove max_new_tokens arg

* Add ASR pipeline to testing

* make fixup

* Factor the output test out into a util

* Full error reporting

* Full error reporting

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Small comment

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Matt
2024-10-01 18:15:04 +01:00
committed by GitHub
parent 0256520794
commit a43e84cb3b
5 changed files with 42 additions and 9 deletions

View File

@@ -18,7 +18,7 @@ import unittest
import numpy as np
import pytest
from datasets import Audio, load_dataset
-from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub import AutomaticSpeechRecognitionOutput, hf_hub_download, snapshot_download
from transformers import (
MODEL_FOR_CTC_MAPPING,
@@ -36,6 +36,7 @@ from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
from transformers.testing_utils import (
+compare_pipeline_output_to_hub_spec,
is_pipeline_test,
is_torch_available,
nested_simplify,
@@ -86,6 +87,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
+compare_pipeline_output_to_hub_spec(outputs, AutomaticSpeechRecognitionOutput)
# Striding
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
if speech_recognizer.type == "ctc":