From 814619f54f677df79a337396794325f13f96251f Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Thu, 21 Dec 2023 11:04:05 +0000 Subject: [PATCH] [Whisper] Use torch for stft if available (#26119) * [Whisper] Use torch for stft if available * update docstring * mock patch decorator * fit on one line --- .../whisper/feature_extraction_whisper.py | 30 ++++++++++++++-- .../test_feature_extraction_whisper.py | 36 +++++++++++++------ 2 files changed, 53 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index b6c171ce93..42104c3293 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -19,12 +19,16 @@ from typing import List, Optional, Union import numpy as np +from ... import is_torch_available from ...audio_utils import mel_filter_bank, spectrogram, window_function from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_utils import BatchFeature from ...utils import TensorType, logging +if is_torch_available(): + import torch + logger = logging.get_logger(__name__) @@ -109,6 +113,24 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): log_spec = (log_spec + 4.0) / 4.0 return log_spec + def _torch_extract_fbank_features(self, waveform: np.array) -> np.ndarray: + """ + Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation. + """ + waveform = torch.from_numpy(waveform).type(torch.float32) + + window = torch.hann_window(self.n_fft) + stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32) + mel_spec = mel_filters.T @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec.numpy() + @staticmethod # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm def zero_mean_unit_var_norm( @@ -146,7 +168,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): **kwargs, ) -> BatchFeature: """ - Main method to featurize and prepare for the model one or several sequence(s). + Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for + the STFT computation if available, otherwise a slower NumPy based one. Args: raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): @@ -246,7 +269,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): # make sure list is in array format input_features = padded_inputs.get("input_features").transpose(2, 0, 1) - input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]] + extract_fbank_features = ( + self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features + ) + input_features = [extract_fbank_features(waveform) for waveform in input_features[0]] if isinstance(input_features[0], List): padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py index 90cbfc21c0..77c7a9be3d 100644 --- a/tests/models/whisper/test_feature_extraction_whisper.py +++ b/tests/models/whisper/test_feature_extraction_whisper.py @@ -23,16 +23,13 @@ import unittest import numpy as np from datasets import load_dataset -from transformers import is_speech_available -from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio +from transformers import WhisperFeatureExtractor +from transformers.testing_utils import check_json_file_has_correct_format, require_torch from transformers.utils.import_utils import is_torch_available from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin -if is_speech_available(): - from transformers import WhisperFeatureExtractor - if is_torch_available(): import torch @@ -53,8 +50,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None): return values -@require_torch -@require_torchaudio class WhisperFeatureExtractionTester(unittest.TestCase): def __init__( self, @@ -111,10 +106,8 @@ class WhisperFeatureExtractionTester(unittest.TestCase): return speech_inputs -@require_torch -@require_torchaudio class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): - feature_extraction_class = WhisperFeatureExtractor if is_speech_available() else None + feature_extraction_class = WhisperFeatureExtractor def setUp(self): self.feat_extract_tester = WhisperFeatureExtractionTester(self) @@ -193,6 +186,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + @require_torch def test_double_precision_pad(self): import torch @@ -213,7 +207,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. return [x["array"] for x in speech_samples] - def test_integration(self): + @require_torch + def test_torch_integration(self): # fmt: off EXPECTED_INPUT_FEATURES = torch.tensor( [ @@ -231,6 +226,25 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. self.assertEqual(input_features.shape, (1, 80, 3000)) self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4)) + @unittest.mock.patch("transformers.models.whisper.feature_extraction_whisper.is_torch_available", lambda: False) + def test_numpy_integration(self): + # fmt: off + EXPECTED_INPUT_FEATURES = np.array( + [ + 0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951, + 0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678, + 0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554, + -0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854 + ] + ) + # fmt: on + + input_speech = self._load_datasamples(1) + feature_extractor = WhisperFeatureExtractor() + input_features = feature_extractor(input_speech, return_tensors="np").input_features + self.assertEqual(input_features.shape, (1, 80, 3000)) + self.assertTrue(np.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4)) + def test_zero_mean_unit_variance_normalization_trunc_np_longest(self): feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) audio = self._load_datasamples(1)[0]