Porting the torchaudio kaldi fbank implementation to audio_utils (#26182)
* add kaldi fbank * make style * add herz_to_mel_kaldi tests * add mel to hertz kaldi test * integration tests * correct test and remove comment * make style * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change parameter name * Apply suggestions from Arthur review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update remove_dc_offset description * fix bug + make style * fix error in using np.exp instead of np.power * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -45,6 +45,10 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
|
||||
self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected))
|
||||
|
||||
inputs = np.array([60, 100, 200, 1000, 1001, 2000])
|
||||
expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
|
||||
self.assertTrue(np.allclose(hertz_to_mel(inputs, "kaldi"), expected))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
hertz_to_mel(100, mel_scale=None)
|
||||
|
||||
@@ -63,6 +67,10 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
expected = np.array([60, 100, 200, 1000, 1001, 2000])
|
||||
self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected))
|
||||
|
||||
inputs = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
|
||||
expected = np.array([60, 100, 200, 1000, 1001, 2000])
|
||||
self.assertTrue(np.allclose(mel_to_hertz(inputs, "kaldi"), expected))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
mel_to_hertz(100, mel_scale=None)
|
||||
|
||||
@@ -89,6 +97,18 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(mel_filters.shape, (513, 13))
|
||||
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=513,
|
||||
num_mel_filters=13,
|
||||
min_frequency=100,
|
||||
max_frequency=4000,
|
||||
sampling_rate=16000,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
triangularize_in_mel_space=True,
|
||||
)
|
||||
self.assertEqual(mel_filters.shape, (513, 13))
|
||||
|
||||
def test_mel_filter_bank_htk(self):
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=16,
|
||||
@@ -153,6 +173,39 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(mel_filters, expected))
|
||||
|
||||
def test_mel_filter_bank_kaldi(self):
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=16,
|
||||
num_mel_filters=4,
|
||||
min_frequency=0,
|
||||
max_frequency=2000,
|
||||
sampling_rate=4000,
|
||||
norm=None,
|
||||
mel_scale="kaldi",
|
||||
triangularize_in_mel_space=True,
|
||||
)
|
||||
# fmt: off
|
||||
expected = np.array(
|
||||
[[0.0000, 0.0000, 0.0000, 0.0000],
|
||||
[0.6086, 0.0000, 0.0000, 0.0000],
|
||||
[0.8689, 0.1311, 0.0000, 0.0000],
|
||||
[0.4110, 0.5890, 0.0000, 0.0000],
|
||||
[0.0036, 0.9964, 0.0000, 0.0000],
|
||||
[0.0000, 0.6366, 0.3634, 0.0000],
|
||||
[0.0000, 0.3027, 0.6973, 0.0000],
|
||||
[0.0000, 0.0000, 0.9964, 0.0036],
|
||||
[0.0000, 0.0000, 0.7135, 0.2865],
|
||||
[0.0000, 0.0000, 0.4507, 0.5493],
|
||||
[0.0000, 0.0000, 0.2053, 0.7947],
|
||||
[0.0000, 0.0000, 0.0000, 0.9752],
|
||||
[0.0000, 0.0000, 0.0000, 0.7585],
|
||||
[0.0000, 0.0000, 0.0000, 0.5539],
|
||||
[0.0000, 0.0000, 0.0000, 0.3599],
|
||||
[0.0000, 0.0000, 0.0000, 0.1756]]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(mel_filters, expected, atol=5e-5))
|
||||
|
||||
def test_mel_filter_bank_slaney_norm(self):
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=16,
|
||||
@@ -271,6 +324,58 @@ class AudioUtilsFunctionTester(unittest.TestCase):
|
||||
self.assertEqual(spec.shape, (257, 732))
|
||||
self.assertTrue(np.allclose(spec[:64, 400], expected))
|
||||
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=256,
|
||||
num_mel_filters=400,
|
||||
min_frequency=20,
|
||||
max_frequency=8000,
|
||||
sampling_rate=16000,
|
||||
norm=None,
|
||||
mel_scale="kaldi",
|
||||
triangularize_in_mel_space=True,
|
||||
)
|
||||
|
||||
mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
|
||||
|
||||
spec = spectrogram(
|
||||
waveform,
|
||||
window_function(400, "povey", periodic=False),
|
||||
frame_length=400,
|
||||
hop_length=160,
|
||||
fft_length=512,
|
||||
power=2.0,
|
||||
center=False,
|
||||
pad_mode="reflect",
|
||||
onesided=True,
|
||||
preemphasis=0.97,
|
||||
mel_filters=mel_filters,
|
||||
log_mel="log",
|
||||
mel_floor=1.1920928955078125e-07,
|
||||
remove_dc_offset=True,
|
||||
)
|
||||
self.assertEqual(spec.shape, (400, 584))
|
||||
|
||||
# fmt: off
|
||||
expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,
|
||||
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
|
||||
-6.52463769, -7.73677889, -15.94238515, -15.94238515,
|
||||
-15.94238515, -15.94238515, -4.18650018, -3.37195286,
|
||||
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
|
||||
-4.70190154, -2.4217066 , -15.94238515, -15.94238515,
|
||||
-15.94238515, -15.94238515, -5.62755239, -3.53385194,
|
||||
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
|
||||
-9.43303023, -8.77480925, -15.94238515, -15.94238515,
|
||||
-15.94238515, -15.94238515, -4.2951092 , -5.51585994,
|
||||
-15.94238515, -15.94238515, -15.94238515, -4.40151721,
|
||||
-3.95228878, -15.94238515, -15.94238515, -15.94238515,
|
||||
-6.10365415, -4.59494697, -15.94238515, -15.94238515,
|
||||
-15.94238515, -8.10727767, -6.2585298 , -15.94238515,
|
||||
-15.94238515, -15.94238515, -5.60161702, -4.47217004,
|
||||
-15.94238515, -15.94238515, -15.94238515, -5.91641988]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))
|
||||
|
||||
def test_spectrogram_center_padding(self):
|
||||
waveform = self._load_datasamples(1)[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user