From 1ed93be48a8f0fc77fce4dacced1976fa5d55713 Mon Sep 17 00:00:00 2001 From: vaibhavagg303 <89418214+vaibhavagg303@users.noreply.github.com> Date: Mon, 8 Apr 2024 14:06:25 +0530 Subject: [PATCH] [Whisper] Computing features on GPU in batch mode for whisper feature extractor. (#29900) * add _torch_extract_fbank_features_batch function in feature_extractor_whisper * reformat feature_extraction_whisper.py file * handle batching in single function * add gpu test & doc * add batch test & device in each __call__ * add device arg in doc string --------- Co-authored-by: vaibhav.aggarwal --- .../whisper/feature_extraction_whisper.py | 62 +++++++++++++------ .../test_feature_extraction_whisper.py | 38 +++++++++++- 2 files changed, 81 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 42104c3293..508e85b91f 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -94,41 +94,63 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): mel_scale="slaney", ) - def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray: + def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> np.ndarray: """ Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch implementation with 1e-5 tolerance. """ - log_spec = spectrogram( - waveform, - window_function(self.n_fft, "hann"), - frame_length=self.n_fft, - hop_length=self.hop_length, - power=2.0, - mel_filters=self.mel_filters, - log_mel="log10", - ) - log_spec = log_spec[:, :-1] - log_spec = np.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - return log_spec + if device != "cpu": + raise ValueError( + f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator " + "devices requires torch, which is not installed. Either set `device='cpu'`, or " + "install torch according to the official instructions: https://pytorch.org/get-started/locally/" + ) + log_spec_batch = [] + for waveform in waveform_batch: + log_spec = spectrogram( + waveform, + window_function(self.n_fft, "hann"), + frame_length=self.n_fft, + hop_length=self.hop_length, + power=2.0, + mel_filters=self.mel_filters, + log_mel="log10", + ) + log_spec = log_spec[:, :-1] + log_spec = np.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + log_spec_batch.append(log_spec) + log_spec_batch = np.array(log_spec_batch) + return log_spec_batch - def _torch_extract_fbank_features(self, waveform: np.array) -> np.ndarray: + def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") -> np.ndarray: """ - Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation. + Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching, + yielding results similar to cpu computing with 1e-5 tolerance. """ waveform = torch.from_numpy(waveform).type(torch.float32) window = torch.hann_window(self.n_fft) + if device != "cpu": + waveform = waveform.to(device) + window = window.to(device) stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) magnitudes = stft[..., :-1].abs() ** 2 mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32) + if device != "cpu": + mel_filters = mel_filters.to(device) mel_spec = mel_filters.T @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + if waveform.dim() == 2: + max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] + log_spec = torch.maximum(log_spec, max_val - 8.0) + else: + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 + if device != "cpu": + log_spec = log_spec.detach().cpu() return log_spec.numpy() @staticmethod @@ -165,6 +187,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): max_length: Optional[int] = None, sampling_rate: Optional[int] = None, do_normalize: Optional[bool] = None, + device: Optional[str] = "cpu", **kwargs, ) -> BatchFeature: """ @@ -211,6 +234,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): do_normalize (`bool`, *optional*, defaults to `False`): Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly improve the performance of the model. + device (`str`, *optional*, defaults to `'cpu'`): + Specifies the device for computation of the log-mel spectrogram of audio signals in the + `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda") """ if sampling_rate is not None: @@ -272,7 +298,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): extract_fbank_features = ( self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features ) - input_features = [extract_fbank_features(waveform) for waveform in input_features[0]] + input_features = extract_fbank_features(input_features[0], device) if isinstance(input_features[0], List): padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py index 77c7a9be3d..8b1e25927e 100644 --- a/tests/models/whisper/test_feature_extraction_whisper.py +++ b/tests/models/whisper/test_feature_extraction_whisper.py @@ -24,7 +24,7 @@ import numpy as np from datasets import load_dataset from transformers import WhisperFeatureExtractor -from transformers.testing_utils import check_json_file_has_correct_format, require_torch +from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torch_gpu from transformers.utils.import_utils import is_torch_available from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin @@ -207,6 +207,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. return [x["array"] for x in speech_samples] + @require_torch_gpu @require_torch def test_torch_integration(self): # fmt: off @@ -223,6 +224,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. input_speech = self._load_datasamples(1) feature_extractor = WhisperFeatureExtractor() input_features = feature_extractor(input_speech, return_tensors="pt").input_features + self.assertEqual(input_features.shape, (1, 80, 3000)) self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4)) @@ -253,3 +255,37 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. self.assertTrue(np.all(np.mean(audio) < 1e-3)) self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3)) + + @require_torch_gpu + @require_torch + def test_torch_integration_batch(self): + # fmt: off + EXPECTED_INPUT_FEATURES = torch.tensor( + [ + [ + 0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951, + 0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678, + 0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554, + -0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854 + ], + [ + -0.4696, -0.0751, 0.0276, -0.0312, -0.0540, -0.0383, 0.1295, 0.0568, + -0.2071, -0.0548, 0.0389, -0.0316, -0.2346, -0.1068, -0.0322, 0.0475, + -0.1709, -0.0041, 0.0872, 0.0537, 0.0075, -0.0392, 0.0371, 0.0189, + -0.1522, -0.0270, 0.0744, 0.0738, -0.0245, -0.0667 + ], + [ + -0.2337, -0.0060, -0.0063, -0.2353, -0.0431, 0.1102, -0.1492, -0.0292, + 0.0787, -0.0608, 0.0143, 0.0582, 0.0072, 0.0101, -0.0444, -0.1701, + -0.0064, -0.0027, -0.0826, -0.0730, -0.0099, -0.0762, -0.0170, 0.0446, + -0.1153, 0.0960, -0.0361, 0.0652, 0.1207, 0.0277 + ] + ] + ) + # fmt: on + + input_speech = self._load_datasamples(3) + feature_extractor = WhisperFeatureExtractor() + input_features = feature_extractor(input_speech, return_tensors="pt").input_features + self.assertEqual(input_features.shape, (3, 80, 3000)) + self.assertTrue(torch.allclose(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))