Add dithering to the Speech2TextFeatureExtractor API. (#34638)
* Add dithering to the `Speech2TextFeatureExtractor` API.
- as in Kaldi: 4a8b7f6732/src/feat/feature-window.cc (L145)
- with dithering and no fixed random seed, the features become non-deterministic due
to the small Gaussian noise added to the audio (i.e. two runs produce slightly
different outputs)
* update the PR
- also add dithering to WhisperFeatureExtractor
- not added to Wav2Vec2FeatureExtractor (it does no FBANK computation)
* add unit-tests for dithering, fix docstrings
* ruff
* utils/check_copies.py --fix_and_overwrite
* update code, add seed to unit-test
* adding explanation of dithering
This commit is contained in:
@@ -390,6 +390,7 @@ def spectrogram(
|
|||||||
center: bool = True,
|
center: bool = True,
|
||||||
pad_mode: str = "reflect",
|
pad_mode: str = "reflect",
|
||||||
onesided: bool = True,
|
onesided: bool = True,
|
||||||
|
dither: float = 0.0,
|
||||||
preemphasis: Optional[float] = None,
|
preemphasis: Optional[float] = None,
|
||||||
mel_filters: Optional[np.ndarray] = None,
|
mel_filters: Optional[np.ndarray] = None,
|
||||||
mel_floor: float = 1e-10,
|
mel_floor: float = 1e-10,
|
||||||
@@ -460,6 +461,12 @@ def spectrogram(
|
|||||||
onesided (`bool`, *optional*, defaults to `True`):
|
onesided (`bool`, *optional*, defaults to `True`):
|
||||||
If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
|
If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
|
||||||
frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
|
frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
|
||||||
|
dither (`float`, *optional*, defaults to 0.0):
|
||||||
|
Adds dithering. In other words, adds a small Gaussian noise to each frame.
|
||||||
|
E.g. use 4.0 to add dithering with a normal distribution centered
|
||||||
|
around 0.0 with standard deviation 4.0, 0.0 means no dithering.
|
||||||
|
Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
|
||||||
|
values for signals with hard-zero sections, when VAD cutoff is present in the signal.
|
||||||
preemphasis (`float`, *optional*)
|
preemphasis (`float`, *optional*)
|
||||||
Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
|
Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
|
||||||
mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
|
mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
|
||||||
@@ -540,6 +547,9 @@ def spectrogram(
|
|||||||
for frame_idx in range(num_frames):
|
for frame_idx in range(num_frames):
|
||||||
buffer[:frame_length] = waveform[timestep : timestep + frame_length]
|
buffer[:frame_length] = waveform[timestep : timestep + frame_length]
|
||||||
|
|
||||||
|
if dither != 0.0:
|
||||||
|
buffer[:frame_length] += dither * np.random.randn(frame_length)
|
||||||
|
|
||||||
if remove_dc_offset:
|
if remove_dc_offset:
|
||||||
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
|
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
|
||||||
|
|
||||||
@@ -591,6 +601,7 @@ def spectrogram_batch(
|
|||||||
center: bool = True,
|
center: bool = True,
|
||||||
pad_mode: str = "reflect",
|
pad_mode: str = "reflect",
|
||||||
onesided: bool = True,
|
onesided: bool = True,
|
||||||
|
dither: float = 0.0,
|
||||||
preemphasis: Optional[float] = None,
|
preemphasis: Optional[float] = None,
|
||||||
mel_filters: Optional[np.ndarray] = None,
|
mel_filters: Optional[np.ndarray] = None,
|
||||||
mel_floor: float = 1e-10,
|
mel_floor: float = 1e-10,
|
||||||
@@ -653,6 +664,10 @@ def spectrogram_batch(
|
|||||||
The padding strategy when `center` is `True`.
|
The padding strategy when `center` is `True`.
|
||||||
onesided (`bool`, *optional*, defaults to `True`):
|
onesided (`bool`, *optional*, defaults to `True`):
|
||||||
If True, returns a one-sided spectrogram for real input signals.
|
If True, returns a one-sided spectrogram for real input signals.
|
||||||
|
dither (`float`, *optional*, defaults to 0.0):
|
||||||
|
Adds dithering. In other words, adds a small Gaussian noise to each frame.
|
||||||
|
E.g. use 4.0 to add dithering with a normal distribution centered
|
||||||
|
around 0.0 with standard deviation 4.0, 0.0 means no dithering.
|
||||||
preemphasis (`float`, *optional*):
|
preemphasis (`float`, *optional*):
|
||||||
Applies a pre-emphasis filter to each frame.
|
Applies a pre-emphasis filter to each frame.
|
||||||
mel_filters (`np.ndarray`, *optional*):
|
mel_filters (`np.ndarray`, *optional*):
|
||||||
@@ -741,6 +756,9 @@ def spectrogram_batch(
|
|||||||
timestep = frame_idx * hop_length
|
timestep = frame_idx * hop_length
|
||||||
buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]
|
buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]
|
||||||
|
|
||||||
|
if dither != 0.0:
|
||||||
|
buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)
|
||||||
|
|
||||||
if remove_dc_offset:
|
if remove_dc_offset:
|
||||||
buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)
|
buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,13 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
Number of Mel-frequency bins.
|
Number of Mel-frequency bins.
|
||||||
padding_value (`float`, *optional*, defaults to 0.0):
|
padding_value (`float`, *optional*, defaults to 0.0):
|
||||||
The value that is used to fill the padding vectors.
|
The value that is used to fill the padding vectors.
|
||||||
|
dither (`float`, *optional*, defaults to 0.0):
|
||||||
|
Adds dithering. In other words, adds a small Gaussian noise to each frame.
|
||||||
|
E.g. use 4.0 to add dithering with a normal distribution centered
|
||||||
|
around 0.0 with standard deviation 4.0 (assuming [-32k,+32k] range of kaldi waveform).
|
||||||
|
The value 0.0 means no dithering.
|
||||||
|
Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
|
||||||
|
values for signals with hard-zero sections, when VAD cutoff is present in the signal.
|
||||||
do_ceptral_normalize (`bool`, *optional*, defaults to `True`):
|
do_ceptral_normalize (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
|
Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
|
||||||
normalize_means (`bool`, *optional*, defaults to `True`):
|
normalize_means (`bool`, *optional*, defaults to `True`):
|
||||||
@@ -68,6 +75,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
num_mel_bins=80,
|
num_mel_bins=80,
|
||||||
padding_value=0.0,
|
padding_value=0.0,
|
||||||
|
dither=0.0,
|
||||||
do_ceptral_normalize=True,
|
do_ceptral_normalize=True,
|
||||||
normalize_means=True,
|
normalize_means=True,
|
||||||
normalize_vars=True,
|
normalize_vars=True,
|
||||||
@@ -75,6 +83,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
):
|
):
|
||||||
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
||||||
self.num_mel_bins = num_mel_bins
|
self.num_mel_bins = num_mel_bins
|
||||||
|
self.dither = dither
|
||||||
self.do_ceptral_normalize = do_ceptral_normalize
|
self.do_ceptral_normalize = do_ceptral_normalize
|
||||||
self.normalize_means = normalize_means
|
self.normalize_means = normalize_means
|
||||||
self.normalize_vars = normalize_vars
|
self.normalize_vars = normalize_vars
|
||||||
@@ -106,7 +115,12 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
|
waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
|
||||||
if is_speech_available():
|
if is_speech_available():
|
||||||
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
||||||
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
|
features = ta_kaldi.fbank(
|
||||||
|
waveform,
|
||||||
|
dither=self.dither,
|
||||||
|
num_mel_bins=self.num_mel_bins,
|
||||||
|
sample_frequency=self.sampling_rate,
|
||||||
|
)
|
||||||
features = features.numpy()
|
features = features.numpy()
|
||||||
else:
|
else:
|
||||||
waveform = np.squeeze(waveform)
|
waveform = np.squeeze(waveform)
|
||||||
@@ -118,6 +132,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
fft_length=512,
|
fft_length=512,
|
||||||
power=2.0,
|
power=2.0,
|
||||||
center=False,
|
center=False,
|
||||||
|
dither=self.dither,
|
||||||
preemphasis=0.97,
|
preemphasis=0.97,
|
||||||
mel_filters=self.mel_filters,
|
mel_filters=self.mel_filters,
|
||||||
log_mel="log",
|
log_mel="log",
|
||||||
|
|||||||
@@ -57,6 +57,14 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
Size of the Fourier transform.
|
Size of the Fourier transform.
|
||||||
padding_value (`float`, *optional*, defaults to 0.0):
|
padding_value (`float`, *optional*, defaults to 0.0):
|
||||||
Padding value used to pad the audio. Should correspond to silences.
|
Padding value used to pad the audio. Should correspond to silences.
|
||||||
|
dither (`float`, *optional*, defaults to 0.0):
|
||||||
|
Adds dithering. In other words, adds a small Gaussian noise to each frame.
|
||||||
|
E.g. use 0.0001 to add dithering with a normal distribution centered
|
||||||
|
around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
|
||||||
|
The value 0.0 means no dithering.
|
||||||
|
Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces
|
||||||
|
the high log_mel_fbank values for signals with hard-zero sections,
|
||||||
|
when VAD cutoff is present in the signal.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_input_names = ["input_features"]
|
model_input_names = ["input_features"]
|
||||||
@@ -69,6 +77,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
chunk_length=30,
|
chunk_length=30,
|
||||||
n_fft=400,
|
n_fft=400,
|
||||||
padding_value=0.0,
|
padding_value=0.0,
|
||||||
|
dither=0.0,
|
||||||
return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
|
return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -85,6 +94,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
self.n_samples = chunk_length * sampling_rate
|
self.n_samples = chunk_length * sampling_rate
|
||||||
self.nb_max_frames = self.n_samples // hop_length
|
self.nb_max_frames = self.n_samples // hop_length
|
||||||
self.sampling_rate = sampling_rate
|
self.sampling_rate = sampling_rate
|
||||||
|
self.dither = dither
|
||||||
self.mel_filters = mel_filter_bank(
|
self.mel_filters = mel_filter_bank(
|
||||||
num_frequency_bins=1 + n_fft // 2,
|
num_frequency_bins=1 + n_fft // 2,
|
||||||
num_mel_filters=feature_size,
|
num_mel_filters=feature_size,
|
||||||
@@ -114,6 +124,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
frame_length=self.n_fft,
|
frame_length=self.n_fft,
|
||||||
hop_length=self.hop_length,
|
hop_length=self.hop_length,
|
||||||
power=2.0,
|
power=2.0,
|
||||||
|
dither=self.dither,
|
||||||
mel_filters=self.mel_filters,
|
mel_filters=self.mel_filters,
|
||||||
log_mel="log10",
|
log_mel="log10",
|
||||||
)
|
)
|
||||||
@@ -132,6 +143,12 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
waveform = torch.from_numpy(waveform).to(device, torch.float32)
|
waveform = torch.from_numpy(waveform).to(device, torch.float32)
|
||||||
window = torch.hann_window(self.n_fft, device=device)
|
window = torch.hann_window(self.n_fft, device=device)
|
||||||
|
|
||||||
|
# Note: it would be better to dither the chunked waveform,
|
||||||
|
# so overlapping signal does not get the same dithering.
|
||||||
|
# But, chunking is happening inside pytorch, so it is here.
|
||||||
|
if self.dither != 0.0:
|
||||||
|
waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)
|
||||||
|
|
||||||
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
|
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
|
||||||
magnitudes = stft[..., :-1].abs() ** 2
|
magnitudes = stft[..., :-1].abs() ** 2
|
||||||
|
|
||||||
|
|||||||
@@ -144,6 +144,40 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
|||||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||||
|
|
||||||
|
def test_dither(self):
|
||||||
|
np.random.seed(42) # seed the dithering randn()
|
||||||
|
|
||||||
|
# Tests that features with and without little dithering are similar, but not the same
|
||||||
|
dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
|
||||||
|
dict_no_dither["dither"] = 0.0
|
||||||
|
|
||||||
|
dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
|
||||||
|
dict_dither["dither"] = 1.0
|
||||||
|
|
||||||
|
feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
|
||||||
|
feature_extractor_dither = self.feature_extraction_class(**dict_dither)
|
||||||
|
|
||||||
|
# create three inputs of length 800, 1000, and 1200
|
||||||
|
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||||
|
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
|
||||||
|
|
||||||
|
# compute features
|
||||||
|
input_features_no_dither = feature_extractor_no_dither(
|
||||||
|
np_speech_inputs, padding=True, return_tensors="np"
|
||||||
|
).input_features
|
||||||
|
input_features_dither = feature_extractor_dither(
|
||||||
|
np_speech_inputs, padding=True, return_tensors="np"
|
||||||
|
).input_features
|
||||||
|
|
||||||
|
# test there is a difference between features (there's added noise to input signal)
|
||||||
|
diff = input_features_dither - input_features_no_dither
|
||||||
|
|
||||||
|
# features are not identical
|
||||||
|
self.assertTrue(np.abs(diff).mean() > 1e-5)
|
||||||
|
# features are not too different
|
||||||
|
self.assertTrue(np.abs(diff).mean() <= 1e-3)
|
||||||
|
self.assertTrue(np.abs(diff).max() <= 1e-2)
|
||||||
|
|
||||||
def test_cepstral_mean_and_variance_normalization(self):
|
def test_cepstral_mean_and_variance_normalization(self):
|
||||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||||
|
|||||||
@@ -200,6 +200,40 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
|||||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||||
|
|
||||||
|
def test_dither(self):
|
||||||
|
np.random.seed(42) # seed the dithering randn()
|
||||||
|
|
||||||
|
# Tests that features with and without little dithering are similar, but not the same
|
||||||
|
dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
|
||||||
|
dict_no_dither["dither"] = 0.0
|
||||||
|
|
||||||
|
dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
|
||||||
|
dict_dither["dither"] = 0.00003 # approx. 1/32k
|
||||||
|
|
||||||
|
feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
|
||||||
|
feature_extractor_dither = self.feature_extraction_class(**dict_dither)
|
||||||
|
|
||||||
|
# create three inputs of length 800, 1000, and 1200
|
||||||
|
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||||
|
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
|
||||||
|
|
||||||
|
# compute features
|
||||||
|
input_features_no_dither = feature_extractor_no_dither(
|
||||||
|
np_speech_inputs, padding=True, return_tensors="np"
|
||||||
|
).input_features
|
||||||
|
input_features_dither = feature_extractor_dither(
|
||||||
|
np_speech_inputs, padding=True, return_tensors="np"
|
||||||
|
).input_features
|
||||||
|
|
||||||
|
# test there is a difference between features (there's added noise to input signal)
|
||||||
|
diff = input_features_dither - input_features_no_dither
|
||||||
|
|
||||||
|
# features are not identical
|
||||||
|
self.assertTrue(np.abs(diff).mean() > 1e-6)
|
||||||
|
# features are not too different
|
||||||
|
self.assertTrue(np.abs(diff).mean() <= 1e-4)
|
||||||
|
self.assertTrue(np.abs(diff).max() <= 1e-3)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_double_precision_pad(self):
|
def test_double_precision_pad(self):
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
Reference in New Issue
Block a user