Porting the torchaudio kaldi fbank implementation to audio_utils (#26182)
* add kaldi fbank * make style * add herz_to_mel_kaldi tests * add mel to hertz kaldi test * integration tests * correct test and remove comment * make style * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change parameter name * Apply suggestions from Arthur review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update remove_dc_offset description * fix bug + make style * fix error in using np.exp instead of np.power * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -30,17 +30,19 @@ def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio
|
||||
freq (`float` or `np.ndarray`):
|
||||
The frequency, or multiple frequencies, in hertz (Hz).
|
||||
mel_scale (`str`, *optional*, defaults to `"htk"`):
|
||||
The mel frequency scale to use, `"htk"` or `"slaney"`.
|
||||
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
|
||||
|
||||
Returns:
|
||||
`float` or `np.ndarray`: The frequencies on the mel scale.
|
||||
"""
|
||||
|
||||
if mel_scale not in ["slaney", "htk"]:
|
||||
raise ValueError('mel_scale should be one of "htk" or "slaney".')
|
||||
if mel_scale not in ["slaney", "htk", "kaldi"]:
|
||||
raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
|
||||
|
||||
if mel_scale == "htk":
|
||||
return 2595.0 * np.log10(1.0 + (freq / 700.0))
|
||||
elif mel_scale == "kaldi":
|
||||
return 1127.0 * np.log(1.0 + (freq / 700.0))
|
||||
|
||||
min_log_hertz = 1000.0
|
||||
min_log_mel = 15.0
|
||||
@@ -64,17 +66,19 @@ def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio
|
||||
mels (`float` or `np.ndarray`):
|
||||
The frequency, or multiple frequencies, in mels.
|
||||
mel_scale (`str`, *optional*, `"htk"`):
|
||||
The mel frequency scale to use, `"htk"` or `"slaney"`.
|
||||
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
|
||||
|
||||
Returns:
|
||||
`float` or `np.ndarray`: The frequencies in hertz.
|
||||
"""
|
||||
|
||||
if mel_scale not in ["slaney", "htk"]:
|
||||
raise ValueError('mel_scale should be one of "htk" or "slaney".')
|
||||
if mel_scale not in ["slaney", "htk", "kaldi"]:
|
||||
raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
|
||||
|
||||
if mel_scale == "htk":
|
||||
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
|
||||
return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
|
||||
elif mel_scale == "kaldi":
|
||||
return 700.0 * (np.exp(mels / 1127.0) - 1.0)
|
||||
|
||||
min_log_hertz = 1000.0
|
||||
min_log_mel = 15.0
|
||||
@@ -120,6 +124,7 @@ def mel_filter_bank(
|
||||
sampling_rate: int,
|
||||
norm: Optional[str] = None,
|
||||
mel_scale: str = "htk",
|
||||
triangularize_in_mel_space: bool = False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
|
||||
@@ -155,7 +160,10 @@ def mel_filter_bank(
|
||||
norm (`str`, *optional*):
|
||||
If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
|
||||
mel_scale (`str`, *optional*, defaults to `"htk"`):
|
||||
The mel frequency scale to use, `"htk"` or `"slaney"`.
|
||||
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
|
||||
triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
|
||||
If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
|
||||
should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
|
||||
|
||||
Returns:
|
||||
`np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
|
||||
@@ -164,15 +172,21 @@ def mel_filter_bank(
|
||||
if norm is not None and norm != "slaney":
|
||||
raise ValueError('norm must be one of None or "slaney"')
|
||||
|
||||
# frequencies of FFT bins in Hz
|
||||
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
|
||||
|
||||
# center points of the triangular mel filters
|
||||
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
|
||||
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
|
||||
mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
|
||||
filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
|
||||
|
||||
if triangularize_in_mel_space:
|
||||
# frequencies of FFT bins in Hz, but filters triangularized in mel space
|
||||
fft_bin_width = sampling_rate / (num_frequency_bins * 2)
|
||||
fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
|
||||
filter_freqs = mel_freqs
|
||||
else:
|
||||
# frequencies of FFT bins in Hz
|
||||
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
|
||||
|
||||
mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
|
||||
|
||||
if norm is not None and norm == "slaney":
|
||||
@@ -218,6 +232,7 @@ def window_function(
|
||||
- `"boxcar"`: a rectangular window
|
||||
- `"hamming"`: the Hamming window
|
||||
- `"hann"`: the Hann window
|
||||
- `"povey"`: the Povey window
|
||||
|
||||
Args:
|
||||
window_length (`int`):
|
||||
@@ -243,6 +258,8 @@ def window_function(
|
||||
window = np.hamming(length)
|
||||
elif name in ["hann", "hann_window"]:
|
||||
window = np.hanning(length)
|
||||
elif name in ["povey"]:
|
||||
window = np.power(np.hanning(length), 0.85)
|
||||
else:
|
||||
raise ValueError(f"Unknown window function '{name}'")
|
||||
|
||||
@@ -281,6 +298,7 @@ def spectrogram(
|
||||
reference: float = 1.0,
|
||||
min_value: float = 1e-10,
|
||||
db_range: Optional[float] = None,
|
||||
remove_dc_offset: Optional[bool] = None,
|
||||
dtype: np.dtype = np.float32,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@@ -363,6 +381,9 @@ def spectrogram(
|
||||
db_range (`float`, *optional*):
|
||||
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
||||
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
||||
remove_dc_offset (`bool`, *optional*):
|
||||
Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
|
||||
order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
|
||||
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
|
||||
Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
|
||||
`np.complex64`.
|
||||
@@ -414,6 +435,9 @@ def spectrogram(
|
||||
for frame_idx in range(num_frames):
|
||||
buffer[:frame_length] = waveform[timestep : timestep + frame_length]
|
||||
|
||||
if remove_dc_offset:
|
||||
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
|
||||
|
||||
if preemphasis is not None:
|
||||
buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
|
||||
buffer[0] *= 1 - preemphasis
|
||||
|
||||
@@ -45,6 +45,10 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
|
||||
self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected))
|
||||
|
||||
inputs = np.array([60, 100, 200, 1000, 1001, 2000])
|
||||
expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
|
||||
self.assertTrue(np.allclose(hertz_to_mel(inputs, "kaldi"), expected))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
hertz_to_mel(100, mel_scale=None)
|
||||
|
||||
@@ -63,6 +67,10 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
expected = np.array([60, 100, 200, 1000, 1001, 2000])
|
||||
self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected))
|
||||
|
||||
inputs = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
|
||||
expected = np.array([60, 100, 200, 1000, 1001, 2000])
|
||||
self.assertTrue(np.allclose(mel_to_hertz(inputs, "kaldi"), expected))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
mel_to_hertz(100, mel_scale=None)
|
||||
|
||||
@@ -89,6 +97,18 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(mel_filters.shape, (513, 13))
|
||||
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=513,
|
||||
num_mel_filters=13,
|
||||
min_frequency=100,
|
||||
max_frequency=4000,
|
||||
sampling_rate=16000,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
triangularize_in_mel_space=True,
|
||||
)
|
||||
self.assertEqual(mel_filters.shape, (513, 13))
|
||||
|
||||
def test_mel_filter_bank_htk(self):
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=16,
|
||||
@@ -153,6 +173,39 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(mel_filters, expected))
|
||||
|
||||
def test_mel_filter_bank_kaldi(self):
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=16,
|
||||
num_mel_filters=4,
|
||||
min_frequency=0,
|
||||
max_frequency=2000,
|
||||
sampling_rate=4000,
|
||||
norm=None,
|
||||
mel_scale="kaldi",
|
||||
triangularize_in_mel_space=True,
|
||||
)
|
||||
# fmt: off
|
||||
expected = np.array(
|
||||
[[0.0000, 0.0000, 0.0000, 0.0000],
|
||||
[0.6086, 0.0000, 0.0000, 0.0000],
|
||||
[0.8689, 0.1311, 0.0000, 0.0000],
|
||||
[0.4110, 0.5890, 0.0000, 0.0000],
|
||||
[0.0036, 0.9964, 0.0000, 0.0000],
|
||||
[0.0000, 0.6366, 0.3634, 0.0000],
|
||||
[0.0000, 0.3027, 0.6973, 0.0000],
|
||||
[0.0000, 0.0000, 0.9964, 0.0036],
|
||||
[0.0000, 0.0000, 0.7135, 0.2865],
|
||||
[0.0000, 0.0000, 0.4507, 0.5493],
|
||||
[0.0000, 0.0000, 0.2053, 0.7947],
|
||||
[0.0000, 0.0000, 0.0000, 0.9752],
|
||||
[0.0000, 0.0000, 0.0000, 0.7585],
|
||||
[0.0000, 0.0000, 0.0000, 0.5539],
|
||||
[0.0000, 0.0000, 0.0000, 0.3599],
|
||||
[0.0000, 0.0000, 0.0000, 0.1756]]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(mel_filters, expected, atol=5e-5))
|
||||
|
||||
def test_mel_filter_bank_slaney_norm(self):
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=16,
|
||||
@@ -271,6 +324,58 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
self.assertEqual(spec.shape, (257, 732))
|
||||
self.assertTrue(np.allclose(spec[:64, 400], expected))
|
||||
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=256,
|
||||
num_mel_filters=400,
|
||||
min_frequency=20,
|
||||
max_frequency=8000,
|
||||
sampling_rate=16000,
|
||||
norm=None,
|
||||
mel_scale="kaldi",
|
||||
triangularize_in_mel_space=True,
|
||||
)
|
||||
|
||||
mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "povey", periodic=False),
|
||||
frame_length=400,
|
||||
hop_length=160,
|
||||
fft_length=512,
|
||||
power=2.0,
|
||||
center=False,
|
||||
pad_mode="reflect",
|
||||
onesided=True,
|
||||
preemphasis=0.97,
|
||||
mel_filters=mel_filters,
|
||||
log_mel="log",
|
||||
mel_floor=1.1920928955078125e-07,
|
||||
remove_dc_offset=True,
|
||||
)
|
||||
self.assertEqual(spec.shape, (400, 584))
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,
|
||||
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
|
||||
-6.52463769, -7.73677889, -15.94238515, -15.94238515,
|
||||
-15.94238515, -15.94238515, -4.18650018, -3.37195286,
|
||||
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
|
||||
-4.70190154, -2.4217066 , -15.94238515, -15.94238515,
|
||||
-15.94238515, -15.94238515, -5.62755239, -3.53385194,
|
||||
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
|
||||
-9.43303023, -8.77480925, -15.94238515, -15.94238515,
|
||||
-15.94238515, -15.94238515, -4.2951092 , -5.51585994,
|
||||
-15.94238515, -15.94238515, -15.94238515, -4.40151721,
|
||||
-3.95228878, -15.94238515, -15.94238515, -15.94238515,
|
||||
-6.10365415, -4.59494697, -15.94238515, -15.94238515,
|
||||
-15.94238515, -8.10727767, -6.2585298 , -15.94238515,
|
||||
-15.94238515, -15.94238515, -5.60161702, -4.47217004,
|
||||
-15.94238515, -15.94238515, -15.94238515, -5.91641988]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))
|
||||
|
||||
def test_spectrogram_center_padding(self):
|
||||
waveform = self._load_datasamples(1)[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user