From 60eb416a13a65a55a7fbeadffbf7aa8dea6e6eed Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 10 Sep 2021 15:27:16 +0200 Subject: [PATCH] [Wav2Vec2] Fix normalization for non-padded tensors (#13512) * finalize * Apply suggestions from code review * finish cleaner implementation * more tests * small fix * finish * up --- .../feature_extraction_sequence_utils.py | 2 +- .../feature_extraction_speech_to_text.py | 48 +++++++++------ .../wav2vec2/feature_extraction_wav2vec2.py | 58 ++++++++++++------- .../test_feature_extraction_speech_to_text.py | 51 ++++++++++++---- tests/test_feature_extraction_wav2vec2.py | 44 ++++++++++---- 5 files changed, 144 insertions(+), 59 deletions(-) diff --git a/src/transformers/feature_extraction_sequence_utils.py b/src/transformers/feature_extraction_sequence_utils.py index 8a6c1a9af3..69a5511208 100644 --- a/src/transformers/feature_extraction_sequence_utils.py +++ b/src/transformers/feature_extraction_sequence_utils.py @@ -341,7 +341,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin): return processed_features - def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs): + def _get_padding_strategies(self, padding=False, max_length=None): """ Find the correct padding strategy """ diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py index e3f28e441b..ccc53cb6f2 100644 --- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py +++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py @@ -93,10 +93,13 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): @staticmethod def utterance_cmvn( - x: np.ndarray, input_length: int, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True + x: np.ndarray, + input_length: int, + normalize_means: Optional[bool] = True, + normalize_vars: Optional[bool] = True, + padding_value: float = 0.0, ) -> np.ndarray: # make sure we normalie float32 arrays - mean = x[:input_length].mean(axis=0) square_sums = (x[:input_length] ** 2).sum(axis=0) @@ -107,15 +110,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): std = np.sqrt(np.maximum(var, 1e-10)) x = np.divide(x, std) + if x.shape[0] > input_length: + x[input_length:] = padding_value + # make sure array is in float32 x = x.astype(np.float32) return x - def normalize(self, input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]: + def normalize( + self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None + ) -> List[np.ndarray]: + lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features] return [ - self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars) - for x, n in zip(input_values, input_lengths) + self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars, self.padding_value) + for x, n in zip(input_features, lengths) ] def __call__( @@ -197,7 +206,6 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) ) - # make sure input is in list format if is_batched and not isinstance(raw_speech[0], np.ndarray): raw_speech = [np.asarray(speech) for speech in raw_speech] elif not is_batched and not isinstance(raw_speech, np.ndarray): @@ -225,21 +233,25 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): **kwargs, ) - if "attention_mask" in padded_inputs: - input_lengths = padded_inputs["attention_mask"].sum(-1) - else: - padded_input_values = padded_inputs["input_features"] - input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])] + # make sure list is in array format + input_features = padded_inputs.get("input_features") + if isinstance(input_features[0], list): + padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] + + attention_mask = padded_inputs.get("attention_mask") + if attention_mask is not None: + padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask] # Utterance-level cepstral mean and variance normalization if self.do_ceptral_normalize: - input_features = padded_inputs["input_features"] - - # make sure list is in array format - if isinstance(input_features[0], list): - input_features = [np.asarray(feature, dtype=np.float32) for feature in input_features] - - padded_inputs["input_features"] = self.normalize(input_features, input_lengths=input_lengths) + attention_mask = ( + np.array(attention_mask, dtype=np.bool) + if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD + else None + ) + padded_inputs["input_features"] = self.normalize( + padded_inputs["input_features"], attention_mask=attention_mask + ) if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py index 01c6966637..5afe85411e 100644 --- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py +++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -79,13 +79,25 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): self.do_normalize = do_normalize @staticmethod - def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]: + def zero_mean_unit_var_norm( + input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0 + ) -> List[np.ndarray]: """ Every array in the list is normalized to have zero mean and unit variance """ - normed_input_values = [ - (x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths) - ] + if attention_mask is not None: + attention_mask = np.array(attention_mask, np.bool) + normed_input_values = [] + + for vector, length in zip(input_values, attention_mask.sum(-1)): + normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) + if length > normed_slice.shape[0]: + normed_slice[length:] = padding_value + + normed_input_values.append(normed_slice) + else: + normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] + return normed_input_values def __call__( @@ -172,14 +184,6 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) ) - # make sure input is in list format - if is_batched and not isinstance(raw_speech[0], np.ndarray): - raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] - elif not is_batched and not isinstance(raw_speech, np.ndarray): - raw_speech = np.asarray(raw_speech, dtype=np.float32) - elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64: - raw_speech = raw_speech.astype(np.float32) - # always return batch if not is_batched: raw_speech = [raw_speech] @@ -196,19 +200,33 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): return_attention_mask=return_attention_mask, ) - if "attention_mask" in padded_inputs: - input_lengths = padded_inputs["attention_mask"].sum(-1) - else: - padded_input_values = padded_inputs["input_values"] - input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])] + # convert input values to correct format + input_values = padded_inputs["input_values"] + if not isinstance(input_values[0], np.ndarray): + padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values] + elif ( + not isinstance(input_values, np.ndarray) + and isinstance(input_values[0], np.ndarray) + and input_values[0].dtype is np.float64 + ): + padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values] + elif isinstance(input_values, np.ndarray) and input_values.dtype is np.float64: + padded_inputs["input_values"] = input_values.astype(np.float32) - if isinstance(padded_inputs["input_values"][0], np.ndarray): - padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]] + # convert attention_mask to correct format + attention_mask = padded_inputs.get("attention_mask") + if attention_mask is not None: + padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask] # zero-mean and unit-variance normalization if self.do_normalize: + attention_mask = ( + attention_mask + if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD + else None + ) padded_inputs["input_values"] = self.zero_mean_unit_var_norm( - padded_inputs["input_values"], input_lengths=input_lengths + padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value ) if return_tensors is not None: diff --git a/tests/test_feature_extraction_speech_to_text.py b/tests/test_feature_extraction_speech_to_text.py index 6ed23a7157..5d160081d3 100644 --- a/tests/test_feature_extraction_speech_to_text.py +++ b/tests/test_feature_extraction_speech_to_text.py @@ -136,18 +136,49 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt def test_cepstral_mean_and_variance_normalization(self): feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] - inputs = feature_extractor(speech_inputs, padding=True, return_tensors="np", return_attention_mask=True) - input_features = inputs.input_features - attention_mask = inputs.attention_mask - fbank_feat_lengths = np.sum(attention_mask == 1, axis=1) - def _check_zero_mean_unit_variance(input_vector): - self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3)) - self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3)) + paddings = ["longest", "max_length", "do_not_pad"] + max_lengths = [None, 16, None] + var_tolerances = [1e-3, 1e-3, 1e-1] + for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances): - _check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]]) - _check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]]) - _check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]]) + inputs = feature_extractor( + speech_inputs, padding=padding, max_length=max_length, return_attention_mask=True + ) + input_features = inputs.input_features + attention_mask = inputs.attention_mask + fbank_feat_lengths = [np.sum(x) for x in attention_mask] + + def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3): + self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3)) + self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol)) + + _check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol) + _check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol) + _check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol) + + def test_cepstral_mean_and_variance_normalization_np(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] + + paddings = ["longest", "max_length", "do_not_pad"] + max_lengths = [None, 16, None] + var_tolerances = [1e-3, 1e-3, 1e-1] + for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances): + inputs = feature_extractor( + speech_inputs, max_length=max_length, padding=padding, return_tensors="np", return_attention_mask=True + ) + input_features = inputs.input_features + attention_mask = inputs.attention_mask + fbank_feat_lengths = [np.sum(x) for x in attention_mask] + + def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3): + self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3)) + self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol)) + + _check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol) + _check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol) + _check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol) def test_cepstral_mean_and_variance_normalization_trunc(self): feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) diff --git a/tests/test_feature_extraction_wav2vec2.py b/tests/test_feature_extraction_wav2vec2.py index 2bbf9bd58b..81c8d384fd 100644 --- a/tests/test_feature_extraction_wav2vec2.py +++ b/tests/test_feature_extraction_wav2vec2.py @@ -120,21 +120,45 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) - def test_zero_mean_unit_variance_normalization(self): + def test_zero_mean_unit_variance_normalization_np(self): feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] - processed = feat_extract(speech_inputs, padding="longest", return_tensors="np") - input_values = processed.input_values - def _check_zero_mean_unit_variance(input_vector): - self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3) - self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3) + paddings = ["longest", "max_length", "do_not_pad"] + max_lengths = [None, 1600, None] + for max_length, padding in zip(max_lengths, paddings): + processed = feat_extract(speech_inputs, padding=padding, max_length=max_length, return_tensors="np") + input_values = processed.input_values - _check_zero_mean_unit_variance(input_values[0, :800]) - _check_zero_mean_unit_variance(input_values[1, :1000]) - _check_zero_mean_unit_variance(input_values[2]) + def _check_zero_mean_unit_variance(input_vector): + self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3) + self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3) - def test_zero_mean_unit_variance_normalization_trunc(self): + _check_zero_mean_unit_variance(input_values[0][:800]) + _check_zero_mean_unit_variance(input_values[1][:1000]) + _check_zero_mean_unit_variance(input_values[2][:1200]) + + def test_zero_mean_unit_variance_normalization(self): + feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + lengths = range(800, 1400, 200) + speech_inputs = [floats_list((1, x))[0] for x in lengths] + + paddings = ["longest", "max_length", "do_not_pad"] + max_lengths = [None, 1600, None] + + for max_length, padding in zip(max_lengths, paddings): + processed = feat_extract(speech_inputs, max_length=max_length, padding=padding) + input_values = processed.input_values + + def _check_zero_mean_unit_variance(input_vector): + self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3) + self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3) + + _check_zero_mean_unit_variance(input_values[0][:800]) + _check_zero_mean_unit_variance(input_values[1][:1000]) + _check_zero_mean_unit_variance(input_values[2][:1200]) + + def test_zero_mean_unit_variance_normalization_trunc_np(self): feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] processed = feat_extract(