From fb8e6c50e4a3a4bb1011e286d9bd26d79cd6334d Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 27 Mar 2025 15:20:02 +0100 Subject: [PATCH] [audio utils] fix fft_bin_width computation (#36603) * fix fft_bin_width computation * update docstring + enforce correct params * update test with correct value * udpate test * update feature extractors for concerned models * update * make * udpate docstring * udpate docstring --- src/transformers/audio_utils.py | 10 +++- ...xtraction_audio_spectrogram_transformer.py | 4 +- .../feature_extraction_seamless_m4t.py | 4 +- .../feature_extraction_speech_to_text.py | 4 +- tests/utils/test_audio_utils.py | 54 +++++++++++-------- 5 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index 2894bc800e..8420a84e08 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -293,7 +293,7 @@ def mel_filter_bank( Args: num_frequency_bins (`int`): - Number of frequencies used to compute the spectrogram (should be the same as in `stft`). + Number of frequency bins (should be the same as `n_fft // 2 + 1` where `n_fft` is the size of the Fourier Transform used to compute the spectrogram). num_mel_filters (`int`): Number of mel filters to generate. min_frequency (`float`): @@ -317,6 +317,12 @@ def mel_filter_bank( if norm is not None and norm != "slaney": raise ValueError('norm must be one of None or "slaney"') + if num_frequency_bins < 2: + raise ValueError(f"Require num_frequency_bins: {num_frequency_bins} >= 2") + + if min_frequency > max_frequency: + raise ValueError(f"Require min_frequency: {min_frequency} <= max_frequency: {max_frequency}") + # center points of the triangular mel filters mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale) mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale) @@ -325,7 +331,7 @@ def mel_filter_bank( if triangularize_in_mel_space: # frequencies of FFT bins in Hz, but filters triangularized in mel space - fft_bin_width = sampling_rate / (num_frequency_bins * 2) + fft_bin_width = sampling_rate / ((num_frequency_bins - 1) * 2) fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale) filter_freqs = mel_freqs else: diff --git a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py index 7da6d94bf8..888d38e187 100644 --- a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py @@ -91,7 +91,7 @@ class ASTFeatureExtractor(SequenceFeatureExtractor): if not is_speech_available(): mel_filters = mel_filter_bank( - num_frequency_bins=256, + num_frequency_bins=257, num_mel_filters=self.num_mel_bins, min_frequency=20, max_frequency=sampling_rate // 2, @@ -101,7 +101,7 @@ class ASTFeatureExtractor(SequenceFeatureExtractor): triangularize_in_mel_space=True, ) - self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) + self.mel_filters = mel_filters self.window = window_function(400, "hann", periodic=False) def _extract_fbank_features( diff --git a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py index 84b47cc998..b17dcf792e 100644 --- a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py @@ -74,7 +74,7 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): self.stride = stride mel_filters = mel_filter_bank( - num_frequency_bins=256, + num_frequency_bins=257, num_mel_filters=self.num_mel_bins, min_frequency=20, max_frequency=sampling_rate // 2, @@ -84,7 +84,7 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): triangularize_in_mel_space=True, ) - self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) + self.mel_filters = mel_filters self.window = window_function(400, "povey", periodic=False) super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py index 5473c16681..3abbfacb8d 100644 --- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py +++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py @@ -91,7 +91,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): if not is_speech_available(): mel_filters = mel_filter_bank( - num_frequency_bins=256, + num_frequency_bins=257, num_mel_filters=self.num_mel_bins, min_frequency=20, max_frequency=sampling_rate // 2, @@ -101,7 +101,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): triangularize_in_mel_space=True, ) - self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) + self.mel_filters = mel_filters self.window = window_function(400, "povey", periodic=False) def _extract_fbank_features( diff --git a/tests/utils/test_audio_utils.py b/tests/utils/test_audio_utils.py index 3e417bf7e3..9ece033d06 100644 --- a/tests/utils/test_audio_utils.py +++ b/tests/utils/test_audio_utils.py @@ -194,26 +194,38 @@ class AudioUtilsFunctionTester(unittest.TestCase): triangularize_in_mel_space=True, ) # fmt: off + # here the expected values from torchaudio.compliance.kaldi.get_mel_banks + # note that we compute values in float64 while they do it in float32 expected = np.array( - [[0.0000, 0.0000, 0.0000, 0.0000], - [0.6086, 0.0000, 0.0000, 0.0000], - [0.8689, 0.1311, 0.0000, 0.0000], - [0.4110, 0.5890, 0.0000, 0.0000], - [0.0036, 0.9964, 0.0000, 0.0000], - [0.0000, 0.6366, 0.3634, 0.0000], - [0.0000, 0.3027, 0.6973, 0.0000], - [0.0000, 0.0000, 0.9964, 0.0036], - [0.0000, 0.0000, 0.7135, 0.2865], - [0.0000, 0.0000, 0.4507, 0.5493], - [0.0000, 0.0000, 0.2053, 0.7947], - [0.0000, 0.0000, 0.0000, 0.9752], - [0.0000, 0.0000, 0.0000, 0.7585], - [0.0000, 0.0000, 0.0000, 0.5539], - [0.0000, 0.0000, 0.0000, 0.3599], - [0.0000, 0.0000, 0.0000, 0.1756]] + [ + [0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.0000000000000000], + [0.6457883715629578, 0.0000000000000000, 0.0000000000000000, 0.0000000000000000], + [0.8044781088829041, 0.1955219060182571, 0.0000000000000000, 0.0000000000000000], + [0.3258901536464691, 0.6741098165512085, 0.0000000000000000, 0.0000000000000000], + [0.0000000000000000, 0.9021250009536743, 0.0978749766945839, 0.0000000000000000], + [0.0000000000000000, 0.5219038724899292, 0.4780961275100708, 0.0000000000000000], + [0.0000000000000000, 0.1771058291196823, 0.8228941559791565, 0.0000000000000000], + [0.0000000000000000, 0.0000000000000000, 0.8616894483566284, 0.1383105516433716], + [0.0000000000000000, 0.0000000000000000, 0.5710380673408508, 0.4289619624614716], + [0.0000000000000000, 0.0000000000000000, 0.3015440106391907, 0.6984559893608093], + [0.0000000000000000, 0.0000000000000000, 0.0503356307744980, 0.9496643543243408], + [0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.8150880336761475], + [0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.5938932299613953], + [0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.3851676583290100], + [0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.1875794380903244], + ], + dtype=np.float64, ) # fmt: on - self.assertTrue(np.allclose(mel_filters, expected, atol=5e-5)) + + # kaldi implementation does not compute values for last fft bin + # indeed, they enforce max_frequency <= sampling_rate / 2 and + # therefore they know that last fft bin filter bank values will be all 0 + # and pad after with zeros + # to comply with our API for `mel_filter_bank`, we need to also pad here + expected = np.pad(expected, ((0, 1), (0, 0))) + + self.assertTrue(np.allclose(mel_filters, expected)) def test_mel_filter_bank_slaney_norm(self): mel_filters = mel_filter_bank( @@ -369,7 +381,7 @@ class AudioUtilsFunctionTester(unittest.TestCase): self.assertTrue(np.allclose(spec[:64, 400], expected)) mel_filters = mel_filter_bank( - num_frequency_bins=256, + num_frequency_bins=257, num_mel_filters=400, min_frequency=20, max_frequency=8000, @@ -379,8 +391,6 @@ class AudioUtilsFunctionTester(unittest.TestCase): triangularize_in_mel_space=True, ) - mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) - spec = spectrogram( waveform, window_function(400, "povey", periodic=False), @@ -510,7 +520,7 @@ class AudioUtilsFunctionTester(unittest.TestCase): self.assertTrue(np.allclose(spec_list[2][:64, 400], expected3)) mel_filters = mel_filter_bank( - num_frequency_bins=256, + num_frequency_bins=257, num_mel_filters=400, min_frequency=20, max_frequency=8000, @@ -520,8 +530,6 @@ class AudioUtilsFunctionTester(unittest.TestCase): triangularize_in_mel_space=True, ) - mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) - spec_list = spectrogram_batch( waveform_list, window_function(400, "povey", periodic=False),