[Wav2Vec2] Fix normalization for non-padded tensors (#13512)
* finalize * Apply suggestions from code review * finish cleaner implementation * more tests * small fix * finish * up
This commit is contained in:
committed by
patrickvonplaten
parent
d12bbe4942
commit
60eb416a13
@@ -341,7 +341,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
|
|||||||
|
|
||||||
return processed_features
|
return processed_features
|
||||||
|
|
||||||
def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs):
|
def _get_padding_strategies(self, padding=False, max_length=None):
|
||||||
"""
|
"""
|
||||||
Find the correct padding strategy
|
Find the correct padding strategy
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -93,10 +93,13 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def utterance_cmvn(
|
def utterance_cmvn(
|
||||||
x: np.ndarray, input_length: int, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True
|
x: np.ndarray,
|
||||||
|
input_length: int,
|
||||||
|
normalize_means: Optional[bool] = True,
|
||||||
|
normalize_vars: Optional[bool] = True,
|
||||||
|
padding_value: float = 0.0,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
# make sure we normalie float32 arrays
|
# make sure we normalie float32 arrays
|
||||||
|
|
||||||
mean = x[:input_length].mean(axis=0)
|
mean = x[:input_length].mean(axis=0)
|
||||||
square_sums = (x[:input_length] ** 2).sum(axis=0)
|
square_sums = (x[:input_length] ** 2).sum(axis=0)
|
||||||
|
|
||||||
@@ -107,15 +110,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
std = np.sqrt(np.maximum(var, 1e-10))
|
std = np.sqrt(np.maximum(var, 1e-10))
|
||||||
x = np.divide(x, std)
|
x = np.divide(x, std)
|
||||||
|
|
||||||
|
if x.shape[0] > input_length:
|
||||||
|
x[input_length:] = padding_value
|
||||||
|
|
||||||
# make sure array is in float32
|
# make sure array is in float32
|
||||||
x = x.astype(np.float32)
|
x = x.astype(np.float32)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def normalize(self, input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
|
def normalize(
|
||||||
|
self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
|
||||||
|
) -> List[np.ndarray]:
|
||||||
|
lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
|
||||||
return [
|
return [
|
||||||
self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars)
|
self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars, self.padding_value)
|
||||||
for x, n in zip(input_values, input_lengths)
|
for x, n in zip(input_features, lengths)
|
||||||
]
|
]
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -197,7 +206,6 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure input is in list format
|
|
||||||
if is_batched and not isinstance(raw_speech[0], np.ndarray):
|
if is_batched and not isinstance(raw_speech[0], np.ndarray):
|
||||||
raw_speech = [np.asarray(speech) for speech in raw_speech]
|
raw_speech = [np.asarray(speech) for speech in raw_speech]
|
||||||
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
||||||
@@ -225,21 +233,25 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "attention_mask" in padded_inputs:
|
# make sure list is in array format
|
||||||
input_lengths = padded_inputs["attention_mask"].sum(-1)
|
input_features = padded_inputs.get("input_features")
|
||||||
else:
|
if isinstance(input_features[0], list):
|
||||||
padded_input_values = padded_inputs["input_features"]
|
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
||||||
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
|
|
||||||
|
attention_mask = padded_inputs.get("attention_mask")
|
||||||
|
if attention_mask is not None:
|
||||||
|
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
|
||||||
|
|
||||||
# Utterance-level cepstral mean and variance normalization
|
# Utterance-level cepstral mean and variance normalization
|
||||||
if self.do_ceptral_normalize:
|
if self.do_ceptral_normalize:
|
||||||
input_features = padded_inputs["input_features"]
|
attention_mask = (
|
||||||
|
np.array(attention_mask, dtype=np.bool)
|
||||||
# make sure list is in array format
|
if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
|
||||||
if isinstance(input_features[0], list):
|
else None
|
||||||
input_features = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
)
|
||||||
|
padded_inputs["input_features"] = self.normalize(
|
||||||
padded_inputs["input_features"] = self.normalize(input_features, input_lengths=input_lengths)
|
padded_inputs["input_features"], attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
if return_tensors is not None:
|
if return_tensors is not None:
|
||||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||||
|
|||||||
@@ -79,13 +79,25 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
|
|||||||
self.do_normalize = do_normalize
|
self.do_normalize = do_normalize
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
|
def zero_mean_unit_var_norm(
|
||||||
|
input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
|
||||||
|
) -> List[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Every array in the list is normalized to have zero mean and unit variance
|
Every array in the list is normalized to have zero mean and unit variance
|
||||||
"""
|
"""
|
||||||
normed_input_values = [
|
if attention_mask is not None:
|
||||||
(x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths)
|
attention_mask = np.array(attention_mask, np.bool)
|
||||||
]
|
normed_input_values = []
|
||||||
|
|
||||||
|
for vector, length in zip(input_values, attention_mask.sum(-1)):
|
||||||
|
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
|
||||||
|
if length > normed_slice.shape[0]:
|
||||||
|
normed_slice[length:] = padding_value
|
||||||
|
|
||||||
|
normed_input_values.append(normed_slice)
|
||||||
|
else:
|
||||||
|
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
|
||||||
|
|
||||||
return normed_input_values
|
return normed_input_values
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -172,14 +184,6 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
|
|||||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure input is in list format
|
|
||||||
if is_batched and not isinstance(raw_speech[0], np.ndarray):
|
|
||||||
raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
|
|
||||||
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
|
||||||
raw_speech = np.asarray(raw_speech, dtype=np.float32)
|
|
||||||
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
|
|
||||||
raw_speech = raw_speech.astype(np.float32)
|
|
||||||
|
|
||||||
# always return batch
|
# always return batch
|
||||||
if not is_batched:
|
if not is_batched:
|
||||||
raw_speech = [raw_speech]
|
raw_speech = [raw_speech]
|
||||||
@@ -196,19 +200,33 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
|
|||||||
return_attention_mask=return_attention_mask,
|
return_attention_mask=return_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "attention_mask" in padded_inputs:
|
# convert input values to correct format
|
||||||
input_lengths = padded_inputs["attention_mask"].sum(-1)
|
input_values = padded_inputs["input_values"]
|
||||||
else:
|
if not isinstance(input_values[0], np.ndarray):
|
||||||
padded_input_values = padded_inputs["input_values"]
|
padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values]
|
||||||
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
|
elif (
|
||||||
|
not isinstance(input_values, np.ndarray)
|
||||||
|
and isinstance(input_values[0], np.ndarray)
|
||||||
|
and input_values[0].dtype is np.float64
|
||||||
|
):
|
||||||
|
padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
|
||||||
|
elif isinstance(input_values, np.ndarray) and input_values.dtype is np.float64:
|
||||||
|
padded_inputs["input_values"] = input_values.astype(np.float32)
|
||||||
|
|
||||||
if isinstance(padded_inputs["input_values"][0], np.ndarray):
|
# convert attention_mask to correct format
|
||||||
padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]]
|
attention_mask = padded_inputs.get("attention_mask")
|
||||||
|
if attention_mask is not None:
|
||||||
|
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
|
||||||
|
|
||||||
# zero-mean and unit-variance normalization
|
# zero-mean and unit-variance normalization
|
||||||
if self.do_normalize:
|
if self.do_normalize:
|
||||||
|
attention_mask = (
|
||||||
|
attention_mask
|
||||||
|
if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
|
||||||
|
else None
|
||||||
|
)
|
||||||
padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
|
padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
|
||||||
padded_inputs["input_values"], input_lengths=input_lengths
|
padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_tensors is not None:
|
if return_tensors is not None:
|
||||||
|
|||||||
@@ -136,18 +136,49 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
|||||||
def test_cepstral_mean_and_variance_normalization(self):
|
def test_cepstral_mean_and_variance_normalization(self):
|
||||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||||
inputs = feature_extractor(speech_inputs, padding=True, return_tensors="np", return_attention_mask=True)
|
|
||||||
|
paddings = ["longest", "max_length", "do_not_pad"]
|
||||||
|
max_lengths = [None, 16, None]
|
||||||
|
var_tolerances = [1e-3, 1e-3, 1e-1]
|
||||||
|
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
|
||||||
|
|
||||||
|
inputs = feature_extractor(
|
||||||
|
speech_inputs, padding=padding, max_length=max_length, return_attention_mask=True
|
||||||
|
)
|
||||||
input_features = inputs.input_features
|
input_features = inputs.input_features
|
||||||
attention_mask = inputs.attention_mask
|
attention_mask = inputs.attention_mask
|
||||||
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
|
fbank_feat_lengths = [np.sum(x) for x in attention_mask]
|
||||||
|
|
||||||
def _check_zero_mean_unit_variance(input_vector):
|
def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
|
||||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
|
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
|
||||||
|
|
||||||
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
|
_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
|
||||||
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
|
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
|
||||||
_check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]])
|
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
|
||||||
|
|
||||||
|
def test_cepstral_mean_and_variance_normalization_np(self):
|
||||||
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
|
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||||
|
|
||||||
|
paddings = ["longest", "max_length", "do_not_pad"]
|
||||||
|
max_lengths = [None, 16, None]
|
||||||
|
var_tolerances = [1e-3, 1e-3, 1e-1]
|
||||||
|
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
|
||||||
|
inputs = feature_extractor(
|
||||||
|
speech_inputs, max_length=max_length, padding=padding, return_tensors="np", return_attention_mask=True
|
||||||
|
)
|
||||||
|
input_features = inputs.input_features
|
||||||
|
attention_mask = inputs.attention_mask
|
||||||
|
fbank_feat_lengths = [np.sum(x) for x in attention_mask]
|
||||||
|
|
||||||
|
def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
|
||||||
|
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||||
|
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
|
||||||
|
|
||||||
|
_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
|
||||||
|
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
|
||||||
|
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
|
||||||
|
|
||||||
def test_cepstral_mean_and_variance_normalization_trunc(self):
|
def test_cepstral_mean_and_variance_normalization_trunc(self):
|
||||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
|
|||||||
@@ -120,21 +120,45 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
|||||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||||
|
|
||||||
def test_zero_mean_unit_variance_normalization(self):
|
def test_zero_mean_unit_variance_normalization_np(self):
|
||||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||||
processed = feat_extract(speech_inputs, padding="longest", return_tensors="np")
|
|
||||||
|
paddings = ["longest", "max_length", "do_not_pad"]
|
||||||
|
max_lengths = [None, 1600, None]
|
||||||
|
for max_length, padding in zip(max_lengths, paddings):
|
||||||
|
processed = feat_extract(speech_inputs, padding=padding, max_length=max_length, return_tensors="np")
|
||||||
input_values = processed.input_values
|
input_values = processed.input_values
|
||||||
|
|
||||||
def _check_zero_mean_unit_variance(input_vector):
|
def _check_zero_mean_unit_variance(input_vector):
|
||||||
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
||||||
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
||||||
|
|
||||||
_check_zero_mean_unit_variance(input_values[0, :800])
|
_check_zero_mean_unit_variance(input_values[0][:800])
|
||||||
_check_zero_mean_unit_variance(input_values[1, :1000])
|
_check_zero_mean_unit_variance(input_values[1][:1000])
|
||||||
_check_zero_mean_unit_variance(input_values[2])
|
_check_zero_mean_unit_variance(input_values[2][:1200])
|
||||||
|
|
||||||
def test_zero_mean_unit_variance_normalization_trunc(self):
|
def test_zero_mean_unit_variance_normalization(self):
|
||||||
|
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
|
lengths = range(800, 1400, 200)
|
||||||
|
speech_inputs = [floats_list((1, x))[0] for x in lengths]
|
||||||
|
|
||||||
|
paddings = ["longest", "max_length", "do_not_pad"]
|
||||||
|
max_lengths = [None, 1600, None]
|
||||||
|
|
||||||
|
for max_length, padding in zip(max_lengths, paddings):
|
||||||
|
processed = feat_extract(speech_inputs, max_length=max_length, padding=padding)
|
||||||
|
input_values = processed.input_values
|
||||||
|
|
||||||
|
def _check_zero_mean_unit_variance(input_vector):
|
||||||
|
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
||||||
|
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
||||||
|
|
||||||
|
_check_zero_mean_unit_variance(input_values[0][:800])
|
||||||
|
_check_zero_mean_unit_variance(input_values[1][:1000])
|
||||||
|
_check_zero_mean_unit_variance(input_values[2][:1200])
|
||||||
|
|
||||||
|
def test_zero_mean_unit_variance_normalization_trunc_np(self):
|
||||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||||
processed = feat_extract(
|
processed = feat_extract(
|
||||||
|
|||||||
Reference in New Issue
Block a user