Make ASR pipeline compliant with Hub spec + add tests (#33769)
* Remove max_new_tokens arg * Add ASR pipeline to testing * make fixup * Factor the output test out into a util * Full error reporting * Full error reporting * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Lysandre Debut <hi@lysand.re> * Small comment --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Union
|
from typing import TYPE_CHECKING, Dict, Optional, Union
|
||||||
|
|
||||||
@@ -269,8 +270,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
||||||
complete overview of generate, check the [following
|
complete overview of generate, check the [following
|
||||||
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
|
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
|
||||||
max_new_tokens (`int`, *optional*):
|
|
||||||
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
`Dict`: A dictionary with the following keys:
|
`Dict`: A dictionary with the following keys:
|
||||||
@@ -310,6 +309,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
|
|
||||||
forward_params = defaultdict(dict)
|
forward_params = defaultdict(dict)
|
||||||
if max_new_tokens is not None:
|
if max_new_tokens is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"`max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
forward_params["max_new_tokens"] = max_new_tokens
|
forward_params["max_new_tokens"] = max_new_tokens
|
||||||
if generate_kwargs is not None:
|
if generate_kwargs is not None:
|
||||||
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
|
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import time
|
|||||||
import unittest
|
import unittest
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from dataclasses import MISSING, fields
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -2610,3 +2611,30 @@ if is_torch_available():
|
|||||||
update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
|
update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
|
||||||
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
|
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
|
||||||
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
|
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
|
||||||
|
|
||||||
|
|
||||||
|
def compare_pipeline_output_to_hub_spec(output, hub_spec):
|
||||||
|
missing_keys = []
|
||||||
|
unexpected_keys = []
|
||||||
|
all_field_names = {field.name for field in fields(hub_spec)}
|
||||||
|
matching_keys = sorted([key for key in output.keys() if key in all_field_names])
|
||||||
|
|
||||||
|
# Fields with a MISSING default are required and must be in the output
|
||||||
|
for field in fields(hub_spec):
|
||||||
|
if field.default is MISSING and field.name not in output:
|
||||||
|
missing_keys.append(field.name)
|
||||||
|
|
||||||
|
# All output keys must match either a required or optional field in the Hub spec
|
||||||
|
for output_key in output:
|
||||||
|
if output_key not in all_field_names:
|
||||||
|
unexpected_keys.append(output_key)
|
||||||
|
|
||||||
|
if missing_keys or unexpected_keys:
|
||||||
|
error = ["Pipeline output does not match Hub spec!"]
|
||||||
|
if matching_keys:
|
||||||
|
error.append(f"Matching keys: {matching_keys}")
|
||||||
|
if missing_keys:
|
||||||
|
error.append(f"Missing required keys in pipeline output: {missing_keys}")
|
||||||
|
if unexpected_keys:
|
||||||
|
error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}")
|
||||||
|
raise KeyError("\n".join(error))
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from dataclasses import fields
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import AudioClassificationOutputElement
|
from huggingface_hub import AudioClassificationOutputElement
|
||||||
@@ -21,6 +20,7 @@ from huggingface_hub import AudioClassificationOutputElement
|
|||||||
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||||
from transformers.pipelines import AudioClassificationPipeline, pipeline
|
from transformers.pipelines import AudioClassificationPipeline, pipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
compare_pipeline_output_to_hub_spec,
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_tf,
|
require_tf,
|
||||||
@@ -68,10 +68,8 @@ class AudioClassificationPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.run_torchaudio(audio_classifier)
|
self.run_torchaudio(audio_classifier)
|
||||||
|
|
||||||
spec_output_keys = {field.name for field in fields(AudioClassificationOutputElement)}
|
|
||||||
for single_output in output:
|
for single_output in output:
|
||||||
output_keys = set(single_output.keys())
|
compare_pipeline_output_to_hub_spec(single_output, AudioClassificationOutputElement)
|
||||||
self.assertEqual(spec_output_keys, output_keys, msg="Pipeline output keys do not match HF Hub spec!")
|
|
||||||
|
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
def run_torchaudio(self, audio_classifier):
|
def run_torchaudio(self, audio_classifier):
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from datasets import Audio, load_dataset
|
from datasets import Audio, load_dataset
|
||||||
from huggingface_hub import hf_hub_download, snapshot_download
|
from huggingface_hub import AutomaticSpeechRecognitionOutput, hf_hub_download, snapshot_download
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
MODEL_FOR_CTC_MAPPING,
|
MODEL_FOR_CTC_MAPPING,
|
||||||
@@ -36,6 +36,7 @@ from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
|||||||
from transformers.pipelines.audio_utils import chunk_bytes_iter
|
from transformers.pipelines.audio_utils import chunk_bytes_iter
|
||||||
from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
|
from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
compare_pipeline_output_to_hub_spec,
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
@@ -86,6 +87,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
outputs = speech_recognizer(audio)
|
outputs = speech_recognizer(audio)
|
||||||
self.assertEqual(outputs, {"text": ANY(str)})
|
self.assertEqual(outputs, {"text": ANY(str)})
|
||||||
|
|
||||||
|
compare_pipeline_output_to_hub_spec(outputs, AutomaticSpeechRecognitionOutput)
|
||||||
|
|
||||||
# Striding
|
# Striding
|
||||||
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
|
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
|
||||||
if speech_recognizer.type == "ctc":
|
if speech_recognizer.type == "ctc":
|
||||||
|
|||||||
@@ -25,9 +25,9 @@ from pathlib import Path
|
|||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import get_args
|
from typing import get_args
|
||||||
|
|
||||||
from huggingface_hub import AudioClassificationInput
|
from huggingface_hub import AudioClassificationInput, AutomaticSpeechRecognitionInput
|
||||||
|
|
||||||
from transformers.pipelines import AudioClassificationPipeline
|
from transformers.pipelines import AudioClassificationPipeline, AutomaticSpeechRecognitionPipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
require_decord,
|
require_decord,
|
||||||
@@ -104,6 +104,7 @@ task_to_pipeline_and_spec_mapping = {
|
|||||||
# Adding a task to this list will cause its pipeline input signature to be checked against the corresponding
|
# Adding a task to this list will cause its pipeline input signature to be checked against the corresponding
|
||||||
# task spec in the HF Hub
|
# task spec in the HF Hub
|
||||||
"audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
|
"audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
|
||||||
|
"automatic-speech-recognition": (AutomaticSpeechRecognitionPipeline, AutomaticSpeechRecognitionInput),
|
||||||
}
|
}
|
||||||
|
|
||||||
for task, task_info in pipeline_test_mapping.items():
|
for task, task_info in pipeline_test_mapping.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user