Make ASR pipeline compliant with Hub spec + add tests (#33769)

* Remove max_new_tokens arg

* Add ASR pipeline to testing

* make fixup

* Factor the output test out into a util

* Full error reporting

* Full error reporting

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Small comment

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Matt
2024-10-01 18:15:04 +01:00
committed by GitHub
parent 0256520794
commit a43e84cb3b
5 changed files with 42 additions and 9 deletions

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Union from typing import TYPE_CHECKING, Dict, Optional, Union
@@ -269,8 +270,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
complete overview of generate, check the [following complete overview of generate, check the [following
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
max_new_tokens (`int`, *optional*):
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
Return: Return:
`Dict`: A dictionary with the following keys: `Dict`: A dictionary with the following keys:
@@ -310,6 +309,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
forward_params = defaultdict(dict) forward_params = defaultdict(dict)
if max_new_tokens is not None: if max_new_tokens is not None:
warnings.warn(
"`max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.",
FutureWarning,
)
forward_params["max_new_tokens"] = max_new_tokens forward_params["max_new_tokens"] = max_new_tokens
if generate_kwargs is not None: if generate_kwargs is not None:
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs: if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:

View File

@@ -31,6 +31,7 @@ import time
import unittest import unittest
from collections import defaultdict from collections import defaultdict
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import MISSING, fields
from functools import wraps from functools import wraps
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
@@ -2610,3 +2611,30 @@ if is_torch_available():
update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN") update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN") update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN") update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
def compare_pipeline_output_to_hub_spec(output, hub_spec):
missing_keys = []
unexpected_keys = []
all_field_names = {field.name for field in fields(hub_spec)}
matching_keys = sorted([key for key in output.keys() if key in all_field_names])
# Fields with a MISSING default are required and must be in the output
for field in fields(hub_spec):
if field.default is MISSING and field.name not in output:
missing_keys.append(field.name)
# All output keys must match either a required or optional field in the Hub spec
for output_key in output:
if output_key not in all_field_names:
unexpected_keys.append(output_key)
if missing_keys or unexpected_keys:
error = ["Pipeline output does not match Hub spec!"]
if matching_keys:
error.append(f"Matching keys: {matching_keys}")
if missing_keys:
error.append(f"Missing required keys in pipeline output: {missing_keys}")
if unexpected_keys:
error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}")
raise KeyError("\n".join(error))

View File

@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
from dataclasses import fields
import numpy as np import numpy as np
from huggingface_hub import AudioClassificationOutputElement from huggingface_hub import AudioClassificationOutputElement
@@ -21,6 +20,7 @@ from huggingface_hub import AudioClassificationOutputElement
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
from transformers.pipelines import AudioClassificationPipeline, pipeline from transformers.pipelines import AudioClassificationPipeline, pipeline
from transformers.testing_utils import ( from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test, is_pipeline_test,
nested_simplify, nested_simplify,
require_tf, require_tf,
@@ -68,10 +68,8 @@ class AudioClassificationPipelineTests(unittest.TestCase):
self.run_torchaudio(audio_classifier) self.run_torchaudio(audio_classifier)
spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)}
for single_output in output: for single_output in output:
output_keys = set(single_output.keys()) compare_pipeline_output_to_hub_spec(single_output, AudioClassificationOutputElement)
self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!")
@require_torchaudio @require_torchaudio
def run_torchaudio(self, audio_classifier): def run_torchaudio(self, audio_classifier):

View File

@@ -18,7 +18,7 @@ import unittest
import numpy as np import numpy as np
import pytest import pytest
from datasets import Audio, load_dataset from datasets import Audio, load_dataset
from huggingface_hub import hf_hub_download, snapshot_download from huggingface_hub import AutomaticSpeechRecognitionOutput, hf_hub_download, snapshot_download
from transformers import ( from transformers import (
MODEL_FOR_CTC_MAPPING, MODEL_FOR_CTC_MAPPING,
@@ -36,6 +36,7 @@ from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter from transformers.pipelines.audio_utils import chunk_bytes_iter
from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
from transformers.testing_utils import ( from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test, is_pipeline_test,
is_torch_available, is_torch_available,
nested_simplify, nested_simplify,
@@ -86,6 +87,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
outputs = speech_recognizer(audio) outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)}) self.assertEqual(outputs, {"text": ANY(str)})
compare_pipeline_output_to_hub_spec(outputs, AutomaticSpeechRecognitionOutput)
# Striding # Striding
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate} audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
if speech_recognizer.type == "ctc": if speech_recognizer.type == "ctc":

View File

@@ -25,9 +25,9 @@ from pathlib import Path
from textwrap import dedent from textwrap import dedent
from typing import get_args from typing import get_args
from huggingface_hub import AudioClassificationInput from huggingface_hub import AudioClassificationInput, AutomaticSpeechRecognitionInput
from transformers.pipelines import AudioClassificationPipeline from transformers.pipelines import AudioClassificationPipeline, AutomaticSpeechRecognitionPipeline
from transformers.testing_utils import ( from transformers.testing_utils import (
is_pipeline_test, is_pipeline_test,
require_decord, require_decord,
@@ -104,6 +104,7 @@ task_to_pipeline_and_spec_mapping = {
# Adding a task to this list will cause its pipeline input signature to be checked against the corresponding # Adding a task to this list will cause its pipeline input signature to be checked against the corresponding
# task spec in the HF Hub # task spec in the HF Hub
"audio-classification": (AudioClassificationPipeline, AudioClassificationInput), "audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
"automatic-speech-recognition": (AutomaticSpeechRecognitionPipeline, AutomaticSpeechRecognitionInput),
} }
for task, task_info in pipeline_test_mapping.items(): for task, task_info in pipeline_test_mapping.items():