From 36f183ebab1261b388739d628aaa0b4150068df0 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Wed, 16 Aug 2023 18:04:19 +0100 Subject: [PATCH] [ASR Pipeline] Fix init with timestamps (#25438) * [ASR Pipeline] Fix init * refactor test * change default kwarg setting * only perform checks if we have to * override init * move pre/forward/post checks to sanitize --- .../pipelines/automatic_speech_recognition.py | 110 ++++++++++++++---- ..._pipelines_automatic_speech_recognition.py | 52 +++++++++ 2 files changed, 139 insertions(+), 23 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 1f2c202ac2..fc2a9b3057 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -17,19 +17,24 @@ from typing import TYPE_CHECKING, Dict, Optional, Union import numpy as np import requests +from ..modelcard import ModelCard +from ..tokenization_utils import PreTrainedTokenizer from ..utils import is_torch_available, is_torchaudio_available, logging from .audio_utils import ffmpeg_read -from .base import ChunkPipeline +from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model if TYPE_CHECKING: from pyctcdecode import BeamSearchDecoderCTC from ..feature_extraction_sequence_utils import SequenceFeatureExtractor + from ..modeling_utils import PreTrainedModel logger = logging.get_logger(__name__) if is_torch_available(): + import torch + from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES @@ -194,14 +199,78 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): def __init__( self, - feature_extractor: Union["SequenceFeatureExtractor", str], - *, + model: "PreTrainedModel", + feature_extractor: Union["SequenceFeatureExtractor", str] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None, + modelcard: Optional[ModelCard] = None, + framework: Optional[str] = None, + task: str = "", + args_parser: ArgumentHandler = None, + device: Union[int, "torch.device"] = None, + torch_dtype: Optional[Union[str, "torch.dtype"]] = None, + binary_output: bool = False, **kwargs, ): - super().__init__(**kwargs) - self.feature_extractor = feature_extractor + if framework is None: + framework, model = infer_framework_load_model(model, config=model.config) + self.task = task + self.model = model + self.tokenizer = tokenizer + self.feature_extractor = feature_extractor + self.modelcard = modelcard + self.framework = framework + + # `accelerate` device map + hf_device_map = getattr(self.model, "hf_device_map", None) + + if hf_device_map is not None and device is not None: + raise ValueError( + "The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please " + "discard the `device` argument when creating your pipeline object." + ) + + if self.framework == "tf": + raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.") + + # We shouldn't call `model.to()` for models loaded with accelerate + if device is not None and not (isinstance(device, int) and device < 0): + self.model.to(device) + + if device is None: + if hf_device_map is not None: + # Take the first device used by `accelerate`. + device = next(iter(hf_device_map.values())) + else: + device = -1 + + if is_torch_available() and self.framework == "pt": + if isinstance(device, torch.device): + self.device = device + elif isinstance(device, str): + self.device = torch.device(device) + elif device < 0: + self.device = torch.device("cpu") + else: + self.device = torch.device(f"cuda:{device}") + else: + self.device = device if device is not None else -1 + self.torch_dtype = torch_dtype + self.binary_output = binary_output + + # Update config and generation_config with task specific parameters + task_specific_params = self.model.config.task_specific_params + if task_specific_params is not None and task in task_specific_params: + self.model.config.update(task_specific_params.get(task)) + if self.model.can_generate(): + self.model.generation_config.update(**task_specific_params.get(task)) + + self.call_count = 0 + self._batch_size = kwargs.pop("batch_size", None) + self._num_workers = kwargs.pop("num_workers", None) + + # set the model type so we can check we have the right pre- and post-processing parameters if self.model.config.model_type == "whisper": self.type = "seq2seq_whisper" elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values(): @@ -216,8 +285,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): else: self.type = "ctc" - if self.framework == "tf": - raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.") + self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy() mapping.update(MODEL_FOR_CTC_MAPPING_NAMES) @@ -301,11 +369,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): # No parameters on this pipeline right now preprocess_params = {} if chunk_length_s is not None: + if self.type == "seq2seq" and not ignore_warning: + logger.warning( + "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily" + " be entirely accurate and will have caveats. More information:" + " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...," + " ignore_warning=True)" + ) preprocess_params["chunk_length_s"] = chunk_length_s if stride_length_s is not None: preprocess_params["stride_length_s"] = stride_length_s - if ignore_warning is not None: - preprocess_params["ignore_warning"] = ignore_warning forward_params = defaultdict(dict) if max_new_tokens is not None: @@ -322,6 +395,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): if decoder_kwargs is not None: postprocess_params["decoder_kwargs"] = decoder_kwargs if return_timestamps is not None: + # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass if self.type == "seq2seq" and return_timestamps: raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!") if self.type == "ctc_with_lm" and return_timestamps != "word": @@ -339,11 +413,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): forward_params["return_timestamps"] = return_timestamps postprocess_params["return_timestamps"] = return_timestamps if return_language is not None: + if self.type != "seq2seq_whisper": + raise ValueError("Only Whisper can return language for now.") postprocess_params["return_language"] = return_language return preprocess_params, forward_params, postprocess_params - def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False): + def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None): if isinstance(inputs, str): if inputs.startswith("http://") or inputs.startswith("https://"): # We need to actually check for a real protocol, otherwise it's impossible to use a local file @@ -378,8 +454,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): extra = inputs inputs = _inputs if in_sampling_rate != self.feature_extractor.sampling_rate: - import torch - if is_torchaudio_available(): from torchaudio import functional as F else: @@ -409,14 +483,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline") if chunk_length_s: - if self.type == "seq2seq" and not ignore_warning: - logger.warning( - "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily" - " be entirely accurate and will have caveats. More information:" - " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...," - " ignore_warning=True)" - ) - self._preprocess_params["ignore_warning"] = True if stride_length_s is None: stride_length_s = chunk_length_s / 6 @@ -456,6 +522,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None): if generate_kwargs is None: generate_kwargs = {} + if return_timestamps and self.type == "seq2seq_whisper": generate_kwargs["return_timestamps"] = return_timestamps if return_timestamps == "word": @@ -525,9 +592,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): # Optional return types optional = {} - if return_language is not None and self.type != "seq2seq_whisper": - raise ValueError("Only whisper can return language for now.") - final_items = [] key = "logits" if self.type == "ctc_with_lm" else "tokens" stride = None diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 2d43cdbc81..7c6a950c3b 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -343,6 +343,58 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): ) # fmt: on + @require_torch + def test_return_timestamps_in_init(self): + # segment-level timestamps are accepted + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny") + feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny") + + dummy_speech = np.ones(100) + + pipe = pipeline( + task="automatic-speech-recognition", + model=model, + feature_extractor=feature_extractor, + tokenizer=tokenizer, + chunk_length_s=8, + stride_length_s=1, + return_timestamps=True, + ) + + _ = pipe(dummy_speech) + + # word-level timestamps are accepted + pipe = pipeline( + task="automatic-speech-recognition", + model=model, + feature_extractor=feature_extractor, + tokenizer=tokenizer, + chunk_length_s=8, + stride_length_s=1, + return_timestamps="word", + ) + + _ = pipe(dummy_speech) + + # char-level timestamps are not accepted + with self.assertRaisesRegex( + ValueError, + "^Whisper cannot return `char` timestamps, only word level or segment level timestamps. " + "Use `return_timestamps='word'` or `return_timestamps=True` respectively.$", + ): + pipe = pipeline( + task="automatic-speech-recognition", + model=model, + feature_extractor=feature_extractor, + tokenizer=tokenizer, + chunk_length_s=8, + stride_length_s=1, + return_timestamps="char", + ) + + _ = pipe(dummy_speech) + @require_torch @slow def test_torch_whisper(self):