Porting the torchaudio kaldi fbank implementation to audio_utils (#26182)

* add kaldi fbank

* make style

* add herz_to_mel_kaldi tests

* add mel to hertz kaldi test

* integration tests

* correct test and remove comment

* make style

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* change parameter name

* Apply suggestions from Arthur review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update remove_dc_offset description

* fix bug  + make style

* fix error in using np.exp instead of np.power

* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe
2023-09-21 17:52:47 +02:00
committed by GitHub
parent b132c1703e
commit 9a30753485
2 changed files with 140 additions and 11 deletions

View File

@@ -30,17 +30,19 @@ def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio
freq (`float` or `np.ndarray`): freq (`float` or `np.ndarray`):
The frequency, or multiple frequencies, in hertz (Hz). The frequency, or multiple frequencies, in hertz (Hz).
mel_scale (`str`, *optional*, defaults to `"htk"`): mel_scale (`str`, *optional*, defaults to `"htk"`):
The mel frequency scale to use, `"htk"` or `"slaney"`. The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
Returns: Returns:
`float` or `np.ndarray`: The frequencies on the mel scale. `float` or `np.ndarray`: The frequencies on the mel scale.
""" """
if mel_scale not in ["slaney", "htk"]: if mel_scale not in ["slaney", "htk", "kaldi"]:
raise ValueError('mel_scale should be one of "htk" or "slaney".') raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
if mel_scale == "htk": if mel_scale == "htk":
return 2595.0 * np.log10(1.0 + (freq / 700.0)) return 2595.0 * np.log10(1.0 + (freq / 700.0))
elif mel_scale == "kaldi":
return 1127.0 * np.log(1.0 + (freq / 700.0))
min_log_hertz = 1000.0 min_log_hertz = 1000.0
min_log_mel = 15.0 min_log_mel = 15.0
@@ -64,17 +66,19 @@ def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio
mels (`float` or `np.ndarray`): mels (`float` or `np.ndarray`):
The frequency, or multiple frequencies, in mels. The frequency, or multiple frequencies, in mels.
mel_scale (`str`, *optional*, `"htk"`): mel_scale (`str`, *optional*, `"htk"`):
The mel frequency scale to use, `"htk"` or `"slaney"`. The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
Returns: Returns:
`float` or `np.ndarray`: The frequencies in hertz. `float` or `np.ndarray`: The frequencies in hertz.
""" """
if mel_scale not in ["slaney", "htk"]: if mel_scale not in ["slaney", "htk", "kaldi"]:
raise ValueError('mel_scale should be one of "htk" or "slaney".') raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
if mel_scale == "htk": if mel_scale == "htk":
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
elif mel_scale == "kaldi":
return 700.0 * (np.exp(mels / 1127.0) - 1.0)
min_log_hertz = 1000.0 min_log_hertz = 1000.0
min_log_mel = 15.0 min_log_mel = 15.0
@@ -120,6 +124,7 @@ def mel_filter_bank(
sampling_rate: int, sampling_rate: int,
norm: Optional[str] = None, norm: Optional[str] = None,
mel_scale: str = "htk", mel_scale: str = "htk",
triangularize_in_mel_space: bool = False,
) -> np.ndarray: ) -> np.ndarray:
""" """
Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
@@ -155,7 +160,10 @@ def mel_filter_bank(
norm (`str`, *optional*): norm (`str`, *optional*):
If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization). If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
mel_scale (`str`, *optional*, defaults to `"htk"`): mel_scale (`str`, *optional*, defaults to `"htk"`):
The mel frequency scale to use, `"htk"` or `"slaney"`. The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
Returns: Returns:
`np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
@@ -164,15 +172,21 @@ def mel_filter_bank(
if norm is not None and norm != "slaney": if norm is not None and norm != "slaney":
raise ValueError('norm must be one of None or "slaney"') raise ValueError('norm must be one of None or "slaney"')
# frequencies of FFT bins in Hz
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
# center points of the triangular mel filters # center points of the triangular mel filters
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale) mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale) mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2) mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale) filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
if triangularize_in_mel_space:
# frequencies of FFT bins in Hz, but filters triangularized in mel space
fft_bin_width = sampling_rate / (num_frequency_bins * 2)
fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
filter_freqs = mel_freqs
else:
# frequencies of FFT bins in Hz
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs) mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
if norm is not None and norm == "slaney": if norm is not None and norm == "slaney":
@@ -218,6 +232,7 @@ def window_function(
- `"boxcar"`: a rectangular window - `"boxcar"`: a rectangular window
- `"hamming"`: the Hamming window - `"hamming"`: the Hamming window
- `"hann"`: the Hann window - `"hann"`: the Hann window
- `"povey"`: the Povey window
Args: Args:
window_length (`int`): window_length (`int`):
@@ -243,6 +258,8 @@ def window_function(
window = np.hamming(length) window = np.hamming(length)
elif name in ["hann", "hann_window"]: elif name in ["hann", "hann_window"]:
window = np.hanning(length) window = np.hanning(length)
elif name in ["povey"]:
window = np.power(np.hanning(length), 0.85)
else: else:
raise ValueError(f"Unknown window function '{name}'") raise ValueError(f"Unknown window function '{name}'")
@@ -281,6 +298,7 @@ def spectrogram(
reference: float = 1.0, reference: float = 1.0,
min_value: float = 1e-10, min_value: float = 1e-10,
db_range: Optional[float] = None, db_range: Optional[float] = None,
remove_dc_offset: Optional[bool] = None,
dtype: np.dtype = np.float32, dtype: np.dtype = np.float32,
) -> np.ndarray: ) -> np.ndarray:
""" """
@@ -363,6 +381,9 @@ def spectrogram(
db_range (`float`, *optional*): db_range (`float`, *optional*):
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
peak value and the smallest value will never be more than 80 dB. Must be greater than zero. peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
remove_dc_offset (`bool`, *optional*):
Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
dtype (`np.dtype`, *optional*, defaults to `np.float32`): dtype (`np.dtype`, *optional*, defaults to `np.float32`):
Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
`np.complex64`. `np.complex64`.
@@ -414,6 +435,9 @@ def spectrogram(
for frame_idx in range(num_frames): for frame_idx in range(num_frames):
buffer[:frame_length] = waveform[timestep : timestep + frame_length] buffer[:frame_length] = waveform[timestep : timestep + frame_length]
if remove_dc_offset:
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
if preemphasis is not None: if preemphasis is not None:
buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1] buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
buffer[0] *= 1 - preemphasis buffer[0] *= 1 - preemphasis

View File

@@ -45,6 +45,10 @@ class AudioUtilsFunctionTester(unittest.TestCase):
expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016]) expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected)) self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected))
inputs = np.array([60, 100, 200, 1000, 1001, 2000])
expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
self.assertTrue(np.allclose(hertz_to_mel(inputs, "kaldi"), expected))
with pytest.raises(ValueError): with pytest.raises(ValueError):
hertz_to_mel(100, mel_scale=None) hertz_to_mel(100, mel_scale=None)
@@ -63,6 +67,10 @@ class AudioUtilsFunctionTester(unittest.TestCase):
expected = np.array([60, 100, 200, 1000, 1001, 2000]) expected = np.array([60, 100, 200, 1000, 1001, 2000])
self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected)) self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected))
inputs = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
expected = np.array([60, 100, 200, 1000, 1001, 2000])
self.assertTrue(np.allclose(mel_to_hertz(inputs, "kaldi"), expected))
with pytest.raises(ValueError): with pytest.raises(ValueError):
mel_to_hertz(100, mel_scale=None) mel_to_hertz(100, mel_scale=None)
@@ -89,6 +97,18 @@ class AudioUtilsFunctionTester(unittest.TestCase):
) )
self.assertEqual(mel_filters.shape, (513, 13)) self.assertEqual(mel_filters.shape, (513, 13))
mel_filters = mel_filter_bank(
num_frequency_bins=513,
num_mel_filters=13,
min_frequency=100,
max_frequency=4000,
sampling_rate=16000,
norm="slaney",
mel_scale="slaney",
triangularize_in_mel_space=True,
)
self.assertEqual(mel_filters.shape, (513, 13))
def test_mel_filter_bank_htk(self): def test_mel_filter_bank_htk(self):
mel_filters = mel_filter_bank( mel_filters = mel_filter_bank(
num_frequency_bins=16, num_frequency_bins=16,
@@ -153,6 +173,39 @@ class AudioUtilsFunctionTester(unittest.TestCase):
# fmt: on # fmt: on
self.assertTrue(np.allclose(mel_filters, expected)) self.assertTrue(np.allclose(mel_filters, expected))
def test_mel_filter_bank_kaldi(self):
mel_filters = mel_filter_bank(
num_frequency_bins=16,
num_mel_filters=4,
min_frequency=0,
max_frequency=2000,
sampling_rate=4000,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)
# fmt: off
expected = np.array(
[[0.0000, 0.0000, 0.0000, 0.0000],
[0.6086, 0.0000, 0.0000, 0.0000],
[0.8689, 0.1311, 0.0000, 0.0000],
[0.4110, 0.5890, 0.0000, 0.0000],
[0.0036, 0.9964, 0.0000, 0.0000],
[0.0000, 0.6366, 0.3634, 0.0000],
[0.0000, 0.3027, 0.6973, 0.0000],
[0.0000, 0.0000, 0.9964, 0.0036],
[0.0000, 0.0000, 0.7135, 0.2865],
[0.0000, 0.0000, 0.4507, 0.5493],
[0.0000, 0.0000, 0.2053, 0.7947],
[0.0000, 0.0000, 0.0000, 0.9752],
[0.0000, 0.0000, 0.0000, 0.7585],
[0.0000, 0.0000, 0.0000, 0.5539],
[0.0000, 0.0000, 0.0000, 0.3599],
[0.0000, 0.0000, 0.0000, 0.1756]]
)
# fmt: on
self.assertTrue(np.allclose(mel_filters, expected, atol=5e-5))
def test_mel_filter_bank_slaney_norm(self): def test_mel_filter_bank_slaney_norm(self):
mel_filters = mel_filter_bank( mel_filters = mel_filter_bank(
num_frequency_bins=16, num_frequency_bins=16,
@@ -271,6 +324,58 @@ class AudioUtilsFunctionTester(unittest.TestCase):
self.assertEqual(spec.shape, (257, 732)) self.assertEqual(spec.shape, (257, 732))
self.assertTrue(np.allclose(spec[:64, 400], expected)) self.assertTrue(np.allclose(spec[:64, 400], expected))
mel_filters = mel_filter_bank(
num_frequency_bins=256,
num_mel_filters=400,
min_frequency=20,
max_frequency=8000,
sampling_rate=16000,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)
mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
spec = spectrogram(
waveform,
window_function(400, "povey", periodic=False),
frame_length=400,
hop_length=160,
fft_length=512,
power=2.0,
center=False,
pad_mode="reflect",
onesided=True,
preemphasis=0.97,
mel_filters=mel_filters,
log_mel="log",
mel_floor=1.1920928955078125e-07,
remove_dc_offset=True,
)
self.assertEqual(spec.shape, (400, 584))
# fmt: off
expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
-6.52463769, -7.73677889, -15.94238515, -15.94238515,
-15.94238515, -15.94238515, -4.18650018, -3.37195286,
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
-4.70190154, -2.4217066 , -15.94238515, -15.94238515,
-15.94238515, -15.94238515, -5.62755239, -3.53385194,
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
-9.43303023, -8.77480925, -15.94238515, -15.94238515,
-15.94238515, -15.94238515, -4.2951092 , -5.51585994,
-15.94238515, -15.94238515, -15.94238515, -4.40151721,
-3.95228878, -15.94238515, -15.94238515, -15.94238515,
-6.10365415, -4.59494697, -15.94238515, -15.94238515,
-15.94238515, -8.10727767, -6.2585298 , -15.94238515,
-15.94238515, -15.94238515, -5.60161702, -4.47217004,
-15.94238515, -15.94238515, -15.94238515, -5.91641988]
)
# fmt: on
self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))
def test_spectrogram_center_padding(self): def test_spectrogram_center_padding(self):
waveform = self._load_datasamples(1)[0] waveform = self._load_datasamples(1)[0]