[Whisper] Add rescaling function with do_normalize (#21263)
* add `zero_mean_unit_var_norm` function * normalize before MEL computation * fixup * add simple test * quality * Update tests/models/whisper/test_feature_extraction_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * fixup * use attention masks if padding was applied * Update based on review Co-authored-by: bofeng huang <bofenghuang7@gmail.com> --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: bofeng huang <bofenghuang7@gmail.com>
This commit is contained in:
@@ -215,6 +215,29 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
|
|
||||||
return log_spec
|
return log_spec
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
|
||||||
|
def zero_mean_unit_var_norm(
|
||||||
|
input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
|
||||||
|
) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Every array in the list is normalized to have zero mean and unit variance
|
||||||
|
"""
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = np.array(attention_mask, np.int32)
|
||||||
|
normed_input_values = []
|
||||||
|
|
||||||
|
for vector, length in zip(input_values, attention_mask.sum(-1)):
|
||||||
|
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
|
||||||
|
if length < normed_slice.shape[0]:
|
||||||
|
normed_slice[length:] = padding_value
|
||||||
|
|
||||||
|
normed_input_values.append(normed_slice)
|
||||||
|
else:
|
||||||
|
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
|
||||||
|
|
||||||
|
return normed_input_values
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
|
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
|
||||||
@@ -225,6 +248,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
padding: Optional[str] = "max_length",
|
padding: Optional[str] = "max_length",
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
sampling_rate: Optional[int] = None,
|
sampling_rate: Optional[int] = None,
|
||||||
|
do_normalize: Optional[bool] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""
|
"""
|
||||||
@@ -266,6 +290,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
pipeline.
|
pipeline.
|
||||||
padding_value (`float`, defaults to 0.0):
|
padding_value (`float`, defaults to 0.0):
|
||||||
The value that is used to fill the padding values / vectors.
|
The value that is used to fill the padding values / vectors.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
|
||||||
|
improve the performance of the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if sampling_rate is not None:
|
if sampling_rate is not None:
|
||||||
@@ -312,6 +339,18 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
# make sure list is in array format
|
# make sure list is in array format
|
||||||
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
|
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
|
||||||
|
|
||||||
|
if return_attention_mask:
|
||||||
|
# rescale from sample (48000) to feature (3000)
|
||||||
|
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
|
||||||
|
|
||||||
|
# zero-mean and unit-variance normalization
|
||||||
|
if do_normalize:
|
||||||
|
padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
|
||||||
|
padded_inputs["input_features"],
|
||||||
|
attention_mask=padded_inputs["attention_mask"],
|
||||||
|
padding_value=self.padding_value,
|
||||||
|
)
|
||||||
|
|
||||||
input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]
|
input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]
|
||||||
|
|
||||||
if isinstance(input_features[0], List):
|
if isinstance(input_features[0], List):
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers import is_speech_available
|
from transformers import is_speech_available
|
||||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
||||||
@@ -198,8 +199,6 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
|||||||
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
|
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
|
||||||
|
|
||||||
def _load_datasamples(self, num_samples):
|
def _load_datasamples(self, num_samples):
|
||||||
from datasets import load_dataset
|
|
||||||
|
|
||||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
# automatic decoding with librispeech
|
# automatic decoding with librispeech
|
||||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||||
@@ -222,3 +221,12 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
|||||||
feaure_extractor = WhisperFeatureExtractor()
|
feaure_extractor = WhisperFeatureExtractor()
|
||||||
input_features = feaure_extractor(input_speech, return_tensors="pt").input_features
|
input_features = feaure_extractor(input_speech, return_tensors="pt").input_features
|
||||||
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||||
|
|
||||||
|
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
|
||||||
|
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
|
audio = self._load_datasamples(1)[0]
|
||||||
|
audio = ((audio - audio.min()) / (audio.max() - audio.min())) * 65535 # Rescale to [0, 65535] to show issue
|
||||||
|
audio = feat_extract.zero_mean_unit_var_norm([audio], attention_mask=None)[0]
|
||||||
|
|
||||||
|
self.assertTrue(np.all(np.mean(audio) < 1e-3))
|
||||||
|
self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3))
|
||||||
|
|||||||
Reference in New Issue
Block a user