[audio utils] fix fft_bin_width computation (#36603)
* fix fft_bin_width computation * update docstring + enforce correct params * update test with correct value * udpate test * update feature extractors for concerned models * update * make * udpate docstring * udpate docstring
This commit is contained in:
@@ -293,7 +293,7 @@ def mel_filter_bank(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_frequency_bins (`int`):
|
num_frequency_bins (`int`):
|
||||||
Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
|
Number of frequency bins (should be the same as `n_fft // 2 + 1` where `n_fft` is the size of the Fourier Transform used to compute the spectrogram).
|
||||||
num_mel_filters (`int`):
|
num_mel_filters (`int`):
|
||||||
Number of mel filters to generate.
|
Number of mel filters to generate.
|
||||||
min_frequency (`float`):
|
min_frequency (`float`):
|
||||||
@@ -317,6 +317,12 @@ def mel_filter_bank(
|
|||||||
if norm is not None and norm != "slaney":
|
if norm is not None and norm != "slaney":
|
||||||
raise ValueError('norm must be one of None or "slaney"')
|
raise ValueError('norm must be one of None or "slaney"')
|
||||||
|
|
||||||
|
if num_frequency_bins < 2:
|
||||||
|
raise ValueError(f"Require num_frequency_bins: {num_frequency_bins} >= 2")
|
||||||
|
|
||||||
|
if min_frequency > max_frequency:
|
||||||
|
raise ValueError(f"Require min_frequency: {min_frequency} <= max_frequency: {max_frequency}")
|
||||||
|
|
||||||
# center points of the triangular mel filters
|
# center points of the triangular mel filters
|
||||||
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
|
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
|
||||||
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
|
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
|
||||||
@@ -325,7 +331,7 @@ def mel_filter_bank(
|
|||||||
|
|
||||||
if triangularize_in_mel_space:
|
if triangularize_in_mel_space:
|
||||||
# frequencies of FFT bins in Hz, but filters triangularized in mel space
|
# frequencies of FFT bins in Hz, but filters triangularized in mel space
|
||||||
fft_bin_width = sampling_rate / (num_frequency_bins * 2)
|
fft_bin_width = sampling_rate / ((num_frequency_bins - 1) * 2)
|
||||||
fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
|
fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
|
||||||
filter_freqs = mel_freqs
|
filter_freqs = mel_freqs
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
|
|
||||||
if not is_speech_available():
|
if not is_speech_available():
|
||||||
mel_filters = mel_filter_bank(
|
mel_filters = mel_filter_bank(
|
||||||
num_frequency_bins=256,
|
num_frequency_bins=257,
|
||||||
num_mel_filters=self.num_mel_bins,
|
num_mel_filters=self.num_mel_bins,
|
||||||
min_frequency=20,
|
min_frequency=20,
|
||||||
max_frequency=sampling_rate // 2,
|
max_frequency=sampling_rate // 2,
|
||||||
@@ -101,7 +101,7 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
triangularize_in_mel_space=True,
|
triangularize_in_mel_space=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
|
self.mel_filters = mel_filters
|
||||||
self.window = window_function(400, "hann", periodic=False)
|
self.window = window_function(400, "hann", periodic=False)
|
||||||
|
|
||||||
def _extract_fbank_features(
|
def _extract_fbank_features(
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
mel_filters = mel_filter_bank(
|
mel_filters = mel_filter_bank(
|
||||||
num_frequency_bins=256,
|
num_frequency_bins=257,
|
||||||
num_mel_filters=self.num_mel_bins,
|
num_mel_filters=self.num_mel_bins,
|
||||||
min_frequency=20,
|
min_frequency=20,
|
||||||
max_frequency=sampling_rate // 2,
|
max_frequency=sampling_rate // 2,
|
||||||
@@ -84,7 +84,7 @@ class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
triangularize_in_mel_space=True,
|
triangularize_in_mel_space=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
|
self.mel_filters = mel_filters
|
||||||
self.window = window_function(400, "povey", periodic=False)
|
self.window = window_function(400, "povey", periodic=False)
|
||||||
|
|
||||||
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
|
|
||||||
if not is_speech_available():
|
if not is_speech_available():
|
||||||
mel_filters = mel_filter_bank(
|
mel_filters = mel_filter_bank(
|
||||||
num_frequency_bins=256,
|
num_frequency_bins=257,
|
||||||
num_mel_filters=self.num_mel_bins,
|
num_mel_filters=self.num_mel_bins,
|
||||||
min_frequency=20,
|
min_frequency=20,
|
||||||
max_frequency=sampling_rate // 2,
|
max_frequency=sampling_rate // 2,
|
||||||
@@ -101,7 +101,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
triangularize_in_mel_space=True,
|
triangularize_in_mel_space=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
|
self.mel_filters = mel_filters
|
||||||
self.window = window_function(400, "povey", periodic=False)
|
self.window = window_function(400, "povey", periodic=False)
|
||||||
|
|
||||||
def _extract_fbank_features(
|
def _extract_fbank_features(
|
||||||
|
|||||||
@@ -194,26 +194,38 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
|||||||
triangularize_in_mel_space=True,
|
triangularize_in_mel_space=True,
|
||||||
)
|
)
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
# here the expected values from torchaudio.compliance.kaldi.get_mel_banks
|
||||||
|
# note that we compute values in float64 while they do it in float32
|
||||||
expected = np.array(
|
expected = np.array(
|
||||||
[[0.0000, 0.0000, 0.0000, 0.0000],
|
[
|
||||||
[0.6086, 0.0000, 0.0000, 0.0000],
|
[0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.0000000000000000],
|
||||||
[0.8689, 0.1311, 0.0000, 0.0000],
|
[0.6457883715629578, 0.0000000000000000, 0.0000000000000000, 0.0000000000000000],
|
||||||
[0.4110, 0.5890, 0.0000, 0.0000],
|
[0.8044781088829041, 0.1955219060182571, 0.0000000000000000, 0.0000000000000000],
|
||||||
[0.0036, 0.9964, 0.0000, 0.0000],
|
[0.3258901536464691, 0.6741098165512085, 0.0000000000000000, 0.0000000000000000],
|
||||||
[0.0000, 0.6366, 0.3634, 0.0000],
|
[0.0000000000000000, 0.9021250009536743, 0.0978749766945839, 0.0000000000000000],
|
||||||
[0.0000, 0.3027, 0.6973, 0.0000],
|
[0.0000000000000000, 0.5219038724899292, 0.4780961275100708, 0.0000000000000000],
|
||||||
[0.0000, 0.0000, 0.9964, 0.0036],
|
[0.0000000000000000, 0.1771058291196823, 0.8228941559791565, 0.0000000000000000],
|
||||||
[0.0000, 0.0000, 0.7135, 0.2865],
|
[0.0000000000000000, 0.0000000000000000, 0.8616894483566284, 0.1383105516433716],
|
||||||
[0.0000, 0.0000, 0.4507, 0.5493],
|
[0.0000000000000000, 0.0000000000000000, 0.5710380673408508, 0.4289619624614716],
|
||||||
[0.0000, 0.0000, 0.2053, 0.7947],
|
[0.0000000000000000, 0.0000000000000000, 0.3015440106391907, 0.6984559893608093],
|
||||||
[0.0000, 0.0000, 0.0000, 0.9752],
|
[0.0000000000000000, 0.0000000000000000, 0.0503356307744980, 0.9496643543243408],
|
||||||
[0.0000, 0.0000, 0.0000, 0.7585],
|
[0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.8150880336761475],
|
||||||
[0.0000, 0.0000, 0.0000, 0.5539],
|
[0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.5938932299613953],
|
||||||
[0.0000, 0.0000, 0.0000, 0.3599],
|
[0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.3851676583290100],
|
||||||
[0.0000, 0.0000, 0.0000, 0.1756]]
|
[0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.1875794380903244],
|
||||||
|
],
|
||||||
|
dtype=np.float64,
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
self.assertTrue(np.allclose(mel_filters, expected, atol=5e-5))
|
|
||||||
|
# kaldi implementation does not compute values for last fft bin
|
||||||
|
# indeed, they enforce max_frequency <= sampling_rate / 2 and
|
||||||
|
# therefore they know that last fft bin filter bank values will be all 0
|
||||||
|
# and pad after with zeros
|
||||||
|
# to comply with our API for `mel_filter_bank`, we need to also pad here
|
||||||
|
expected = np.pad(expected, ((0, 1), (0, 0)))
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(mel_filters, expected))
|
||||||
|
|
||||||
def test_mel_filter_bank_slaney_norm(self):
|
def test_mel_filter_bank_slaney_norm(self):
|
||||||
mel_filters = mel_filter_bank(
|
mel_filters = mel_filter_bank(
|
||||||
@@ -369,7 +381,7 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
|||||||
self.assertTrue(np.allclose(spec[:64, 400], expected))
|
self.assertTrue(np.allclose(spec[:64, 400], expected))
|
||||||
|
|
||||||
mel_filters = mel_filter_bank(
|
mel_filters = mel_filter_bank(
|
||||||
num_frequency_bins=256,
|
num_frequency_bins=257,
|
||||||
num_mel_filters=400,
|
num_mel_filters=400,
|
||||||
min_frequency=20,
|
min_frequency=20,
|
||||||
max_frequency=8000,
|
max_frequency=8000,
|
||||||
@@ -379,8 +391,6 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
|||||||
triangularize_in_mel_space=True,
|
triangularize_in_mel_space=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
|
|
||||||
|
|
||||||
spec = spectrogram(
|
spec = spectrogram(
|
||||||
waveform,
|
waveform,
|
||||||
window_function(400, "povey", periodic=False),
|
window_function(400, "povey", periodic=False),
|
||||||
@@ -510,7 +520,7 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
|||||||
self.assertTrue(np.allclose(spec_list[2][:64, 400], expected3))
|
self.assertTrue(np.allclose(spec_list[2][:64, 400], expected3))
|
||||||
|
|
||||||
mel_filters = mel_filter_bank(
|
mel_filters = mel_filter_bank(
|
||||||
num_frequency_bins=256,
|
num_frequency_bins=257,
|
||||||
num_mel_filters=400,
|
num_mel_filters=400,
|
||||||
min_frequency=20,
|
min_frequency=20,
|
||||||
max_frequency=8000,
|
max_frequency=8000,
|
||||||
@@ -520,8 +530,6 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
|||||||
triangularize_in_mel_space=True,
|
triangularize_in_mel_space=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
|
|
||||||
|
|
||||||
spec_list = spectrogram_batch(
|
spec_list = spectrogram_batch(
|
||||||
waveform_list,
|
waveform_list,
|
||||||
window_function(400, "povey", periodic=False),
|
window_function(400, "povey", periodic=False),
|
||||||
|
|||||||
Reference in New Issue
Block a user