[Whisper] Use torch for stft if available (#26119)
* [Whisper] Use torch for stft if available * update docstring * mock patch decorator * fit on one line
This commit is contained in:
@@ -19,12 +19,16 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from ... import is_torch_available
|
||||||
from ...audio_utils import mel_filter_bank, spectrogram, window_function
|
from ...audio_utils import mel_filter_bank, spectrogram, window_function
|
||||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...utils import TensorType, logging
|
from ...utils import TensorType, logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -109,6 +113,24 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
log_spec = (log_spec + 4.0) / 4.0
|
log_spec = (log_spec + 4.0) / 4.0
|
||||||
return log_spec
|
return log_spec
|
||||||
|
|
||||||
|
def _torch_extract_fbank_features(self, waveform: np.ndarray) -> np.ndarray:
    """
    Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation.

    Reads `self.n_fft`, `self.hop_length` and `self.mel_filters` from the feature extractor.

    Args:
        waveform (`np.ndarray`):
            Audio samples to featurize. They are converted to a `float32` tensor before the STFT.

    Returns:
        `np.ndarray`: The rescaled log-mel spectrogram as a `float32` array.
    """
    waveform = torch.from_numpy(waveform).to(torch.float32)

    # Short-time Fourier transform with a Hann window of length `n_fft`.
    window = torch.hann_window(self.n_fft)
    stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
    # Squared magnitude of every frame except the last one along the final axis.
    magnitudes = stft[..., :-1].abs() ** 2

    # Project onto the mel filter bank (the stored filters are transposed before the product).
    mel_filters = torch.from_numpy(self.mel_filters).to(torch.float32)
    mel_spec = mel_filters.T @ magnitudes

    # Clamp before the logarithm so that zero-energy bins do not produce -inf.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Keep every value within 8.0 (in log10 units) of the global maximum.
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    # Final affine rescaling of the log-mel values.
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec.numpy()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
|
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
|
||||||
def zero_mean_unit_var_norm(
|
def zero_mean_unit_var_norm(
|
||||||
@@ -146,7 +168,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""
|
"""
|
||||||
Main method to featurize and prepare for the model one or several sequence(s).
|
Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
|
||||||
|
the STFT computation if available, otherwise a slower NumPy based one.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
|
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
|
||||||
@@ -246,7 +269,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
# make sure list is in array format
|
# make sure list is in array format
|
||||||
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
|
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
|
||||||
|
|
||||||
input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]
|
extract_fbank_features = (
|
||||||
|
self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
|
||||||
|
)
|
||||||
|
input_features = [extract_fbank_features(waveform) for waveform in input_features[0]]
|
||||||
|
|
||||||
if isinstance(input_features[0], List):
|
if isinstance(input_features[0], List):
|
||||||
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
||||||
|
|||||||
@@ -23,16 +23,13 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers import is_speech_available
|
from transformers import WhisperFeatureExtractor
|
||||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
from transformers.testing_utils import check_json_file_has_correct_format, require_torch
|
||||||
from transformers.utils.import_utils import is_torch_available
|
from transformers.utils.import_utils import is_torch_available
|
||||||
|
|
||||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||||
|
|
||||||
|
|
||||||
if is_speech_available():
|
|
||||||
from transformers import WhisperFeatureExtractor
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -53,8 +50,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
@require_torchaudio
|
|
||||||
class WhisperFeatureExtractionTester(unittest.TestCase):
|
class WhisperFeatureExtractionTester(unittest.TestCase):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -111,10 +106,8 @@ class WhisperFeatureExtractionTester(unittest.TestCase):
|
|||||||
return speech_inputs
|
return speech_inputs
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
@require_torchaudio
|
|
||||||
class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||||
feature_extraction_class = WhisperFeatureExtractor if is_speech_available() else None
|
feature_extraction_class = WhisperFeatureExtractor
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.feat_extract_tester = WhisperFeatureExtractionTester(self)
|
self.feat_extract_tester = WhisperFeatureExtractionTester(self)
|
||||||
@@ -193,6 +186,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
|||||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
def test_double_precision_pad(self):
|
def test_double_precision_pad(self):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -213,7 +207,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
|||||||
|
|
||||||
return [x["array"] for x in speech_samples]
|
return [x["array"] for x in speech_samples]
|
||||||
|
|
||||||
def test_integration(self):
|
@require_torch
|
||||||
|
def test_torch_integration(self):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_INPUT_FEATURES = torch.tensor(
|
EXPECTED_INPUT_FEATURES = torch.tensor(
|
||||||
[
|
[
|
||||||
@@ -231,6 +226,25 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
|||||||
self.assertEqual(input_features.shape, (1, 80, 3000))
|
self.assertEqual(input_features.shape, (1, 80, 3000))
|
||||||
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||||
|
|
||||||
|
@unittest.mock.patch("transformers.models.whisper.feature_extraction_whisper.is_torch_available", lambda: False)
def test_numpy_integration(self):
    """
    Check the extracted features against reference values while `is_torch_available` is patched to
    return `False` in the feature extraction module.
    """
    # fmt: off
    EXPECTED_INPUT_FEATURES = np.array(
        [
            0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
            0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
            0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
            -0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
        ]
    )
    # fmt: on

    speech_samples = self._load_datasamples(1)
    extractor = WhisperFeatureExtractor()
    encoded = extractor(speech_samples, return_tensors="np")
    input_features = encoded.input_features

    self.assertEqual(input_features.shape, (1, 80, 3000))

    # Only the first 30 values of the first feature row are compared to the reference.
    leading_values = input_features[0, 0, :30]
    self.assertTrue(np.allclose(leading_values, EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||||
|
|
||||||
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
|
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
|
||||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
audio = self._load_datasamples(1)[0]
|
audio = self._load_datasamples(1)[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user