[Whisper] Add rescaling function with do_normalize (#21263)
* add `zero_mean_unit_var_norm` function * normalize before MEL computation * fixup * add simple test * quality * Update tests/models/whisper/test_feature_extraction_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * fixup * use attention masks if padding was applied * Update based on review Co-authored-by: bofeng huang <bofenghuang7@gmail.com> --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: bofeng huang <bofenghuang7@gmail.com>
This commit is contained in:
@@ -21,6 +21,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import is_speech_available
|
||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
|
||||
@@ -198,8 +199,6 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
@@ -222,3 +221,12 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
feaure_extractor = WhisperFeatureExtractor()
|
||||
input_features = feaure_extractor(input_speech, return_tensors="pt").input_features
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
audio = self._load_datasamples(1)[0]
|
||||
audio = ((audio - audio.min()) / (audio.max() - audio.min())) * 65535 # Rescale to [0, 65535] to show issue
|
||||
audio = feat_extract.zero_mean_unit_var_norm([audio], attention_mask=None)[0]
|
||||
|
||||
self.assertTrue(np.all(np.mean(audio) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3))
|
||||
|
||||
Reference in New Issue
Block a user