Add dithering to the Speech2TextFeatureExtractor API. (#34638)

* Add dithering to the `Speech2TextFeatureExtractor` API.

- in kaldi : 4a8b7f6732/src/feat/feature-window.cc (L145)
- with dithering without a seed, the features become non-deterministic due
  to small Gaussian noise added to the audio (i.e. 2 runs lead to little
  different outputs)

* update the PR

- add dithering also for WhisperFeatureExtractor
- not adding to Wav2Vec2FeatureExtractor (no FBANK computation)

* add unit-tests for dithering, fix docstrings

* ruff

* utils/check_copies.py --fix_and_overwrite

* update code, add seed to unit-test

* adding explanation of dithering
This commit is contained in:
Karel Vesely
2025-02-19 11:50:02 +01:00
committed by GitHub
parent 9f51dc2535
commit 1a81d774b1
5 changed files with 119 additions and 1 deletions

View File

@@ -390,6 +390,7 @@ def spectrogram(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
dither: float = 0.0,
preemphasis: Optional[float] = None,
mel_filters: Optional[np.ndarray] = None,
mel_floor: float = 1e-10,
@@ -460,6 +461,12 @@ def spectrogram(
onesided (`bool`, *optional*, defaults to `True`):
If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 4.0 to add dithering with a normal distribution centered
around 0.0 with standard deviation 4.0, 0.0 means no dithering.
Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
values for signals with hard-zero sections, when VAD cutoff is present in the signal.
preemphasis (`float`, *optional*)
Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
@@ -540,6 +547,9 @@ def spectrogram(
for frame_idx in range(num_frames):
buffer[:frame_length] = waveform[timestep : timestep + frame_length]
if dither != 0.0:
buffer[:frame_length] += dither * np.random.randn(frame_length)
if remove_dc_offset:
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
@@ -591,6 +601,7 @@ def spectrogram_batch(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
dither: float = 0.0,
preemphasis: Optional[float] = None,
mel_filters: Optional[np.ndarray] = None,
mel_floor: float = 1e-10,
@@ -653,6 +664,10 @@ def spectrogram_batch(
The padding strategy when `center` is `True`.
onesided (`bool`, *optional*, defaults to `True`):
If True, returns a one-sided spectrogram for real input signals.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 4.0 to add dithering with a normal distribution centered
around 0.0 with standard deviation 4.0, 0.0 means no dithering.
preemphasis (`float`, *optional*):
Applies a pre-emphasis filter to each frame.
mel_filters (`np.ndarray`, *optional*):
@@ -741,6 +756,9 @@ def spectrogram_batch(
timestep = frame_idx * hop_length
buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]
if dither != 0.0:
buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)
if remove_dc_offset:
buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)

View File

@@ -52,6 +52,13 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
Number of Mel-frequency bins.
padding_value (`float`, *optional*, defaults to 0.0):
The value that is used to fill the padding vectors.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 4.0 to add dithering with a normal distribution centered
around 0.0 with standard deviation 4.0 (assuming [-32k,+32k] range of kaldi waveform).
The value 0.0 means no dithering.
Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
values for signals with hard-zero sections, when VAD cutoff is present in the signal.
do_ceptral_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
normalize_means (`bool`, *optional*, defaults to `True`):
@@ -68,6 +75,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
sampling_rate=16000,
num_mel_bins=80,
padding_value=0.0,
dither=0.0,
do_ceptral_normalize=True,
normalize_means=True,
normalize_vars=True,
@@ -75,6 +83,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.num_mel_bins = num_mel_bins
self.dither = dither
self.do_ceptral_normalize = do_ceptral_normalize
self.normalize_means = normalize_means
self.normalize_vars = normalize_vars
@@ -106,7 +115,12 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
features = ta_kaldi.fbank(
waveform,
dither=self.dither,
num_mel_bins=self.num_mel_bins,
sample_frequency=self.sampling_rate,
)
features = features.numpy()
else:
waveform = np.squeeze(waveform)
@@ -118,6 +132,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
fft_length=512,
power=2.0,
center=False,
dither=self.dither,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",

View File

@@ -57,6 +57,14 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
Size of the Fourier transform.
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 0.0001 to add dithering with a normal distribution centered
around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
The value 0.0 means no dithering.
Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces
the high log_mel_fbank values for signals with hard-zero sections,
when VAD cutoff is present in the signal.
"""
model_input_names = ["input_features"]
@@ -69,6 +77,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
chunk_length=30,
n_fft=400,
padding_value=0.0,
dither=0.0,
return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
**kwargs,
):
@@ -85,6 +94,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate
self.dither = dither
self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + n_fft // 2,
num_mel_filters=feature_size,
@@ -114,6 +124,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
frame_length=self.n_fft,
hop_length=self.hop_length,
power=2.0,
dither=self.dither,
mel_filters=self.mel_filters,
log_mel="log10",
)
@@ -132,6 +143,12 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
waveform = torch.from_numpy(waveform).to(device, torch.float32)
window = torch.hann_window(self.n_fft, device=device)
# Note: it would be better to dither the chunked waveform,
# so overlapping signal does not get the same dithering.
# But, chunking is happening inside pytorch, so it is here.
if self.dither != 0.0:
waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

View File

@@ -144,6 +144,40 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_dither(self):
np.random.seed(42) # seed the dithering randn()
# Tests that features with and without little dithering are similar, but not the same
dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
dict_no_dither["dither"] = 0.0
dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
dict_dither["dither"] = 1.0
feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
feature_extractor_dither = self.feature_extraction_class(**dict_dither)
# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
# compute features
input_features_no_dither = feature_extractor_no_dither(
np_speech_inputs, padding=True, return_tensors="np"
).input_features
input_features_dither = feature_extractor_dither(
np_speech_inputs, padding=True, return_tensors="np"
).input_features
# test there is a difference between features (there's added noise to input signal)
diff = input_features_dither - input_features_no_dither
# features are not identical
self.assertTrue(np.abs(diff).mean() > 1e-5)
# features are not too different
self.assertTrue(np.abs(diff).mean() <= 1e-3)
self.assertTrue(np.abs(diff).max() <= 1e-2)
def test_cepstral_mean_and_variance_normalization(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]

View File

@@ -200,6 +200,40 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_dither(self):
np.random.seed(42) # seed the dithering randn()
# Tests that features with and without little dithering are similar, but not the same
dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
dict_no_dither["dither"] = 0.0
dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
dict_dither["dither"] = 0.00003 # approx. 1/32k
feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
feature_extractor_dither = self.feature_extraction_class(**dict_dither)
# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
# compute features
input_features_no_dither = feature_extractor_no_dither(
np_speech_inputs, padding=True, return_tensors="np"
).input_features
input_features_dither = feature_extractor_dither(
np_speech_inputs, padding=True, return_tensors="np"
).input_features
# test there is a difference between features (there's added noise to input signal)
diff = input_features_dither - input_features_no_dither
# features are not identical
self.assertTrue(np.abs(diff).mean() > 1e-6)
# features are not too different
self.assertTrue(np.abs(diff).mean() <= 1e-4)
self.assertTrue(np.abs(diff).max() <= 1e-3)
@require_torch
def test_double_precision_pad(self):
import torch