From a43e84cb3b78fcac3d5d9374a8488f74f3f19245 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 1 Oct 2024 18:15:04 +0100 Subject: [PATCH] Make ASR pipeline compliant with Hub spec + add tests (#33769) * Remove max_new_tokens arg * Add ASR pipeline to testing * make fixup * Factor the output test out into a util * Full error reporting * Full error reporting * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Lysandre Debut * Small comment --------- Co-authored-by: Lysandre Debut --- .../pipelines/automatic_speech_recognition.py | 7 +++-- src/transformers/testing_utils.py | 28 +++++++++++++++++++ .../test_pipelines_audio_classification.py | 6 ++-- ..._pipelines_automatic_speech_recognition.py | 5 +++- tests/test_pipeline_mixin.py | 5 ++-- 5 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 9b82b67820..f4ffdf6445 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from collections import defaultdict from typing import TYPE_CHECKING, Dict, Optional, Union @@ -269,8 +270,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a complete overview of generate, check the [following guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). - max_new_tokens (`int`, *optional*): - The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. Return: `Dict`: A dictionary with the following keys: @@ -310,6 +309,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): forward_params = defaultdict(dict) if max_new_tokens is not None: + warnings.warn( + "`max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.", + FutureWarning, + ) forward_params["max_new_tokens"] = max_new_tokens if generate_kwargs is not None: if max_new_tokens is not None and "max_new_tokens" in generate_kwargs: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index a5f257c653..4986de42e0 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -31,6 +31,7 @@ import time import unittest from collections import defaultdict from collections.abc import Mapping +from dataclasses import MISSING, fields from functools import wraps from io import StringIO from pathlib import Path @@ -2610,3 +2611,30 @@ if is_torch_available(): update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN") update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN") update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN") + + +def compare_pipeline_output_to_hub_spec(output, hub_spec): + missing_keys = [] + unexpected_keys = [] + all_field_names = {field.name for field in fields(hub_spec)} + matching_keys = sorted([key for key in output.keys() if key in all_field_names]) + + # Fields with a MISSING default are required and must be in the output + for field in fields(hub_spec): + if field.default is MISSING and field.name not in output: + missing_keys.append(field.name) + + # All output keys must match either a required or optional field in the Hub spec + for output_key in output: + if output_key not in all_field_names: + unexpected_keys.append(output_key) + + if missing_keys or unexpected_keys: + error = ["Pipeline output does not match Hub spec!"] + if matching_keys: + error.append(f"Matching keys: {matching_keys}") + if missing_keys: + error.append(f"Missing required keys in pipeline output: {missing_keys}") + if unexpected_keys: + error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}") + raise KeyError("\n".join(error)) diff --git a/tests/pipelines/test_pipelines_audio_classification.py b/tests/pipelines/test_pipelines_audio_classification.py index a8e36b4a07..37990d0074 100644 --- a/tests/pipelines/test_pipelines_audio_classification.py +++ b/tests/pipelines/test_pipelines_audio_classification.py @@ -13,7 +13,6 @@ # limitations under the License. import unittest -from dataclasses import fields import numpy as np from huggingface_hub import AudioClassificationOutputElement @@ -21,6 +20,7 @@ from huggingface_hub import AudioClassificationOutputElement from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING from transformers.pipelines import AudioClassificationPipeline, pipeline from transformers.testing_utils import ( + compare_pipeline_output_to_hub_spec, is_pipeline_test, nested_simplify, require_tf, @@ -68,10 +68,8 @@ class AudioClassificationPipelineTests(unittest.TestCase): self.run_torchaudio(audio_classifier) - spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)} for single_output in output: - output_keys = set(single_output.keys()) - self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!") + compare_pipeline_output_to_hub_spec(single_output, AudioClassificationOutputElement) @require_torchaudio def run_torchaudio(self, audio_classifier): diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index c12292fc33..aecb96a5ee 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -18,7 +18,7 @@ import unittest import numpy as np import pytest from datasets import Audio, load_dataset -from huggingface_hub import hf_hub_download, snapshot_download +from huggingface_hub import AutomaticSpeechRecognitionOutput, hf_hub_download, snapshot_download from transformers import ( MODEL_FOR_CTC_MAPPING, @@ -36,6 +36,7 @@ from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline from transformers.pipelines.audio_utils import chunk_bytes_iter from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter from transformers.testing_utils import ( + compare_pipeline_output_to_hub_spec, is_pipeline_test, is_torch_available, nested_simplify, @@ -86,6 +87,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): outputs = speech_recognizer(audio) self.assertEqual(outputs, {"text": ANY(str)}) + compare_pipeline_output_to_hub_spec(outputs, AutomaticSpeechRecognitionOutput) + # Striding audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate} if speech_recognizer.type == "ctc": diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py index 74e685fb11..e3c650a0e0 100644 --- a/tests/test_pipeline_mixin.py +++ b/tests/test_pipeline_mixin.py @@ -25,9 +25,9 @@ from pathlib import Path from textwrap import dedent from typing import get_args -from huggingface_hub import AudioClassificationInput +from huggingface_hub import AudioClassificationInput, AutomaticSpeechRecognitionInput -from transformers.pipelines import AudioClassificationPipeline +from transformers.pipelines import AudioClassificationPipeline, AutomaticSpeechRecognitionPipeline from transformers.testing_utils import ( is_pipeline_test, require_decord, @@ -104,6 +104,7 @@ task_to_pipeline_and_spec_mapping = { # Adding a task to this list will cause its pipeline input signature to be checked against the corresponding # task spec in the HF Hub "audio-classification": (AudioClassificationPipeline, AudioClassificationInput), + "automatic-speech-recognition": (AutomaticSpeechRecognitionPipeline, AutomaticSpeechRecognitionInput), } for task, task_info in pipeline_test_mapping.items():