Make ASR pipeline compliant with Hub spec + add tests (#33769)

* Remove max_new_tokens arg

* Add ASR pipeline to testing

* make fixup

* Factor the output test out into a util

* Full error reporting

* Full error reporting

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Small comment

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Matt
2024-10-01 18:15:04 +01:00
committed by GitHub
parent 0256520794
commit a43e84cb3b
5 changed files with 42 additions and 9 deletions

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Union
@@ -269,8 +270,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
complete overview of generate, check the [following
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
max_new_tokens (`int`, *optional*):
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
Return:
`Dict`: A dictionary with the following keys:
@@ -310,6 +309,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
forward_params = defaultdict(dict)
if max_new_tokens is not None:
warnings.warn(
"`max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.",
FutureWarning,
)
forward_params["max_new_tokens"] = max_new_tokens
if generate_kwargs is not None:
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:

View File

@@ -31,6 +31,7 @@ import time
import unittest
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import MISSING, fields
from functools import wraps
from io import StringIO
from pathlib import Path
@@ -2610,3 +2611,30 @@ if is_torch_available():
update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
def compare_pipeline_output_to_hub_spec(output, hub_spec):
missing_keys = []
unexpected_keys = []
all_field_names = {field.name for field in fields(hub_spec)}
matching_keys = sorted([key for key in output.keys() if key in all_field_names])
# Fields with a MISSING default are required and must be in the output
for field in fields(hub_spec):
if field.default is MISSING and field.name not in output:
missing_keys.append(field.name)
# All output keys must match either a required or optional field in the Hub spec
for output_key in output:
if output_key not in all_field_names:
unexpected_keys.append(output_key)
if missing_keys or unexpected_keys:
error = ["Pipeline output does not match Hub spec!"]
if matching_keys:
error.append(f"Matching keys: {matching_keys}")
if missing_keys:
error.append(f"Missing required keys in pipeline output: {missing_keys}")
if unexpected_keys:
error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}")
raise KeyError("\n".join(error))

View File

@@ -13,7 +13,6 @@
# limitations under the License.
import unittest
from dataclasses import fields
import numpy as np
from huggingface_hub import AudioClassificationOutputElement
@@ -21,6 +20,7 @@ from huggingface_hub import AudioClassificationOutputElement
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
from transformers.pipelines import AudioClassificationPipeline, pipeline
from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test,
nested_simplify,
require_tf,
@@ -68,10 +68,8 @@ class AudioClassificationPipelineTests(unittest.TestCase):
self.run_torchaudio(audio_classifier)
spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)}
for single_output in output:
output_keys = set(single_output.keys())
self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!")
compare_pipeline_output_to_hub_spec(single_output, AudioClassificationOutputElement)
@require_torchaudio
def run_torchaudio(self, audio_classifier):

View File

@@ -18,7 +18,7 @@ import unittest
import numpy as np
import pytest
from datasets import Audio, load_dataset
from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub import AutomaticSpeechRecognitionOutput, hf_hub_download, snapshot_download
from transformers import (
MODEL_FOR_CTC_MAPPING,
@@ -36,6 +36,7 @@ from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test,
is_torch_available,
nested_simplify,
@@ -86,6 +87,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
compare_pipeline_output_to_hub_spec(outputs, AutomaticSpeechRecognitionOutput)
# Striding
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
if speech_recognizer.type == "ctc":

View File

@@ -25,9 +25,9 @@ from pathlib import Path
from textwrap import dedent
from typing import get_args
from huggingface_hub import AudioClassificationInput
from huggingface_hub import AudioClassificationInput, AutomaticSpeechRecognitionInput
from transformers.pipelines import AudioClassificationPipeline
from transformers.pipelines import AudioClassificationPipeline, AutomaticSpeechRecognitionPipeline
from transformers.testing_utils import (
is_pipeline_test,
require_decord,
@@ -104,6 +104,7 @@ task_to_pipeline_and_spec_mapping = {
# Adding a task to this list will cause its pipeline input signature to be checked against the corresponding
# task spec in the HF Hub
"audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
"automatic-speech-recognition": (AutomaticSpeechRecognitionPipeline, AutomaticSpeechRecognitionInput),
}
for task, task_info in pipeline_test_mapping.items():