diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 685e8e16e5..8b06009a4c 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -874,6 +874,9 @@ def pipeline( if feature_extractor is not None: kwargs["feature_extractor"] = feature_extractor + if torch_dtype is not None: + kwargs["torch_dtype"] = torch_dtype + if device is not None: kwargs["device"] = device diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 7163f89420..9e4f45fd79 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -52,13 +52,15 @@ def rescale_stride(stride, ratio): return new_strides -def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right): +def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None): inputs_len = inputs.shape[0] step = chunk_len - stride_left - stride_right for i in range(0, inputs_len, step): # add start and end paddings to the chunk chunk = inputs[i : i + chunk_len] processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") + if dtype is not None: + processed = processed.to(dtype=dtype) _stride_left = 0 if i == 0 else stride_left is_last = i + step + stride_left >= inputs_len _stride_right = 0 if is_last else stride_right @@ -240,6 +242,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): preprocess_params["stride_length_s"] = kwargs["stride_length_s"] if "ignore_warning" in kwargs: preprocess_params["ignore_warning"] = kwargs["ignore_warning"] + if "torch_dtype" in kwargs: + preprocess_params["dtype"] = kwargs["torch_dtype"] postprocess_params = {} if "decoder_kwargs" in kwargs: @@ -249,7 +253,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): return preprocess_params, {}, postprocess_params - def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False): + def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None): if isinstance(inputs, str): if inputs.startswith("http://") or inputs.startswith("https://"): # We need to actually check for a real protocol, otherwise it's impossible to use a local file @@ -332,12 +336,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): raise ValueError("Chunk length must be superior to stride length") # make sure that - for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right): + for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, dtype): yield item else: processed = self.feature_extractor( inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" ) + if dtype is not None: + processed = processed.to(dtype=dtype) if stride is not None: if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values(): raise ValueError("Stride is only usable with CTC models, try removing it") @@ -366,6 +372,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): # `generate` magic to create the mask automatically won't work, we basically need to help # it here. attention_mask = model_inputs.pop("attention_mask", None) + tokens = self.model.generate( encoder_outputs=encoder(inputs, attention_mask=attention_mask), attention_mask=attention_mask, diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 5dd1c1748e..88a9a088f0 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -145,6 +145,19 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"): _ = speech_recognizer(waveform, return_timestamps="char") + @slow + @require_torch + def test_whisper_fp16(self): + if not torch.cuda.is_available(): + self.skipTest("Cuda is necessary for this test") + speech_recognizer = pipeline( + model="openai/whisper-base", + device=0, + torch_dtype=torch.float16, + ) + waveform = np.tile(np.arange(1000, dtype=np.float32), 34) + speech_recognizer(waveform) + @require_torch def test_small_model_pt_seq2seq(self): speech_recognizer = pipeline(