From 9a30753485653697c7db79e12b0cb2b8872c94c6 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Thu, 21 Sep 2023 17:52:47 +0200 Subject: [PATCH] Porting the torchaudio kaldi fbank implementation to audio_utils (#26182) * add kaldi fbank * make style * add herz_to_mel_kaldi tests * add mel to hertz kaldi test * integration tests * correct test and remove comment * make style * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change parameter name * Apply suggestions from Arthur review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update remove_dc_offset description * fix bug + make style * fix error in using np.exp instead of np.power * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/audio_utils.py | 46 ++++++++++---- tests/utils/test_audio_utils.py | 105 ++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 11 deletions(-) diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index a34892af41..5819f0723f 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -30,17 +30,19 @@ def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio freq (`float` or `np.ndarray`): The frequency, or multiple frequencies, in hertz (Hz). mel_scale (`str`, *optional*, defaults to `"htk"`): - The mel frequency scale to use, `"htk"` or `"slaney"`. + The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`. Returns: `float` or `np.ndarray`: The frequencies on the mel scale. """ - if mel_scale not in ["slaney", "htk"]: - raise ValueError('mel_scale should be one of "htk" or "slaney".') + if mel_scale not in ["slaney", "htk", "kaldi"]: + raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".') if mel_scale == "htk": return 2595.0 * np.log10(1.0 + (freq / 700.0)) + elif mel_scale == "kaldi": + return 1127.0 * np.log(1.0 + (freq / 700.0)) min_log_hertz = 1000.0 min_log_mel = 15.0 @@ -64,17 +66,19 @@ def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio mels (`float` or `np.ndarray`): The frequency, or multiple frequencies, in mels. mel_scale (`str`, *optional*, `"htk"`): - The mel frequency scale to use, `"htk"` or `"slaney"`. + The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`. Returns: `float` or `np.ndarray`: The frequencies in hertz. """ - if mel_scale not in ["slaney", "htk"]: - raise ValueError('mel_scale should be one of "htk" or "slaney".') + if mel_scale not in ["slaney", "htk", "kaldi"]: + raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".') if mel_scale == "htk": - return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) + return 700.0 * (np.power(10, mels / 2595.0) - 1.0) + elif mel_scale == "kaldi": + return 700.0 * (np.exp(mels / 1127.0) - 1.0) min_log_hertz = 1000.0 min_log_mel = 15.0 @@ -120,6 +124,7 @@ def mel_filter_bank( sampling_rate: int, norm: Optional[str] = None, mel_scale: str = "htk", + triangularize_in_mel_space: bool = False, ) -> np.ndarray: """ Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and @@ -155,7 +160,10 @@ def mel_filter_bank( norm (`str`, *optional*): If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization). mel_scale (`str`, *optional*, defaults to `"htk"`): - The mel frequency scale to use, `"htk"` or `"slaney"`. + The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`. + triangularize_in_mel_space (`bool`, *optional*, defaults to `False`): + If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This + should be set to `true` in order to get the same results as `torchaudio` when computing mel filters. Returns: `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a @@ -164,15 +172,21 @@ def mel_filter_bank( if norm is not None and norm != "slaney": raise ValueError('norm must be one of None or "slaney"') - # frequencies of FFT bins in Hz - fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins) - # center points of the triangular mel filters mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale) mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale) mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2) filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale) + if triangularize_in_mel_space: + # frequencies of FFT bins in Hz, but filters triangularized in mel space + fft_bin_width = sampling_rate / (num_frequency_bins * 2) + fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale) + filter_freqs = mel_freqs + else: + # frequencies of FFT bins in Hz + fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins) + mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs) if norm is not None and norm == "slaney": @@ -218,6 +232,7 @@ def window_function( - `"boxcar"`: a rectangular window - `"hamming"`: the Hamming window - `"hann"`: the Hann window + - `"povey"`: the Povey window Args: window_length (`int`): @@ -243,6 +258,8 @@ def window_function( window = np.hamming(length) elif name in ["hann", "hann_window"]: window = np.hanning(length) + elif name in ["povey"]: + window = np.power(np.hanning(length), 0.85) else: raise ValueError(f"Unknown window function '{name}'") @@ -281,6 +298,7 @@ def spectrogram( reference: float = 1.0, min_value: float = 1e-10, db_range: Optional[float] = None, + remove_dc_offset: Optional[bool] = None, dtype: np.dtype = np.float32, ) -> np.ndarray: """ @@ -363,6 +381,9 @@ def spectrogram( db_range (`float`, *optional*): Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the peak value and the smallest value will never be more than 80 dB. Must be greater than zero. + remove_dc_offset (`bool`, *optional*): + Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in + order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters. dtype (`np.dtype`, *optional*, defaults to `np.float32`): Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be `np.complex64`. @@ -414,6 +435,9 @@ def spectrogram( for frame_idx in range(num_frames): buffer[:frame_length] = waveform[timestep : timestep + frame_length] + if remove_dc_offset: + buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean() + if preemphasis is not None: buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1] buffer[0] *= 1 - preemphasis diff --git a/tests/utils/test_audio_utils.py b/tests/utils/test_audio_utils.py index f0333113ea..12d00929a9 100644 --- a/tests/utils/test_audio_utils.py +++ b/tests/utils/test_audio_utils.py @@ -45,6 +45,10 @@ class AudioUtilsFunctionTester(unittest.TestCase): expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016]) self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected)) + inputs = np.array([60, 100, 200, 1000, 1001, 2000]) + expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674]) + self.assertTrue(np.allclose(hertz_to_mel(inputs, "kaldi"), expected)) + with pytest.raises(ValueError): hertz_to_mel(100, mel_scale=None) @@ -63,6 +67,10 @@ class AudioUtilsFunctionTester(unittest.TestCase): expected = np.array([60, 100, 200, 1000, 1001, 2000]) self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected)) + inputs = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674]) + expected = np.array([60, 100, 200, 1000, 1001, 2000]) + self.assertTrue(np.allclose(mel_to_hertz(inputs, "kaldi"), expected)) + with pytest.raises(ValueError): mel_to_hertz(100, mel_scale=None) @@ -89,6 +97,18 @@ class AudioUtilsFunctionTester(unittest.TestCase): ) self.assertEqual(mel_filters.shape, (513, 13)) + mel_filters = mel_filter_bank( + num_frequency_bins=513, + num_mel_filters=13, + min_frequency=100, + max_frequency=4000, + sampling_rate=16000, + norm="slaney", + mel_scale="slaney", + triangularize_in_mel_space=True, + ) + self.assertEqual(mel_filters.shape, (513, 13)) + def test_mel_filter_bank_htk(self): mel_filters = mel_filter_bank( num_frequency_bins=16, @@ -153,6 +173,39 @@ class AudioUtilsFunctionTester(unittest.TestCase): # fmt: on self.assertTrue(np.allclose(mel_filters, expected)) + def test_mel_filter_bank_kaldi(self): + mel_filters = mel_filter_bank( + num_frequency_bins=16, + num_mel_filters=4, + min_frequency=0, + max_frequency=2000, + sampling_rate=4000, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + # fmt: off + expected = np.array( + [[0.0000, 0.0000, 0.0000, 0.0000], + [0.6086, 0.0000, 0.0000, 0.0000], + [0.8689, 0.1311, 0.0000, 0.0000], + [0.4110, 0.5890, 0.0000, 0.0000], + [0.0036, 0.9964, 0.0000, 0.0000], + [0.0000, 0.6366, 0.3634, 0.0000], + [0.0000, 0.3027, 0.6973, 0.0000], + [0.0000, 0.0000, 0.9964, 0.0036], + [0.0000, 0.0000, 0.7135, 0.2865], + [0.0000, 0.0000, 0.4507, 0.5493], + [0.0000, 0.0000, 0.2053, 0.7947], + [0.0000, 0.0000, 0.0000, 0.9752], + [0.0000, 0.0000, 0.0000, 0.7585], + [0.0000, 0.0000, 0.0000, 0.5539], + [0.0000, 0.0000, 0.0000, 0.3599], + [0.0000, 0.0000, 0.0000, 0.1756]] + ) + # fmt: on + self.assertTrue(np.allclose(mel_filters, expected, atol=5e-5)) + def test_mel_filter_bank_slaney_norm(self): mel_filters = mel_filter_bank( num_frequency_bins=16, @@ -271,6 +324,58 @@ class AudioUtilsFunctionTester(unittest.TestCase): self.assertEqual(spec.shape, (257, 732)) self.assertTrue(np.allclose(spec[:64, 400], expected)) + mel_filters = mel_filter_bank( + num_frequency_bins=256, + num_mel_filters=400, + min_frequency=20, + max_frequency=8000, + sampling_rate=16000, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + + mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) + + spec = spectrogram( + waveform, + window_function(400, "povey", periodic=False), + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + pad_mode="reflect", + onesided=True, + preemphasis=0.97, + mel_filters=mel_filters, + log_mel="log", + mel_floor=1.1920928955078125e-07, + remove_dc_offset=True, + ) + self.assertEqual(spec.shape, (400, 584)) + + # fmt: off + expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515, + -15.94238515, -15.94238515, -15.94238515, -15.94238515, + -6.52463769, -7.73677889, -15.94238515, -15.94238515, + -15.94238515, -15.94238515, -4.18650018, -3.37195286, + -15.94238515, -15.94238515, -15.94238515, -15.94238515, + -4.70190154, -2.4217066 , -15.94238515, -15.94238515, + -15.94238515, -15.94238515, -5.62755239, -3.53385194, + -15.94238515, -15.94238515, -15.94238515, -15.94238515, + -9.43303023, -8.77480925, -15.94238515, -15.94238515, + -15.94238515, -15.94238515, -4.2951092 , -5.51585994, + -15.94238515, -15.94238515, -15.94238515, -4.40151721, + -3.95228878, -15.94238515, -15.94238515, -15.94238515, + -6.10365415, -4.59494697, -15.94238515, -15.94238515, + -15.94238515, -8.10727767, -6.2585298 , -15.94238515, + -15.94238515, -15.94238515, -5.60161702, -4.47217004, + -15.94238515, -15.94238515, -15.94238515, -5.91641988] + ) + # fmt: on + self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5)) + def test_spectrogram_center_padding(self): waveform = self._load_datasamples(1)[0]