[Wav2Vec2] Fix normalization for non-padded tensors (#13512)
* finalize * Apply suggestions from code review * finish cleaner implementation * more tests * small fix * finish * up
This commit is contained in:
committed by
GitHub
parent
c63fcabfe9
commit
d7b3b709d0
@@ -341,7 +341,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
|
||||
|
||||
return processed_features
|
||||
|
||||
def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs):
|
||||
def _get_padding_strategies(self, padding=False, max_length=None):
|
||||
"""
|
||||
Find the correct padding strategy
|
||||
"""
|
||||
|
||||
@@ -93,10 +93,13 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
||||
|
||||
@staticmethod
|
||||
def utterance_cmvn(
|
||||
x: np.ndarray, input_length: int, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True
|
||||
x: np.ndarray,
|
||||
input_length: int,
|
||||
normalize_means: Optional[bool] = True,
|
||||
normalize_vars: Optional[bool] = True,
|
||||
padding_value: float = 0.0,
|
||||
) -> np.ndarray:
|
||||
# make sure we normalie float32 arrays
|
||||
|
||||
mean = x[:input_length].mean(axis=0)
|
||||
square_sums = (x[:input_length] ** 2).sum(axis=0)
|
||||
|
||||
@@ -107,15 +110,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
||||
std = np.sqrt(np.maximum(var, 1e-10))
|
||||
x = np.divide(x, std)
|
||||
|
||||
if x.shape[0] > input_length:
|
||||
x[input_length:] = padding_value
|
||||
|
||||
# make sure array is in float32
|
||||
x = x.astype(np.float32)
|
||||
|
||||
return x
|
||||
|
||||
def normalize(self, input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
|
||||
def normalize(
|
||||
self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
|
||||
) -> List[np.ndarray]:
|
||||
lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
|
||||
return [
|
||||
self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars)
|
||||
for x, n in zip(input_values, input_lengths)
|
||||
self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars, self.padding_value)
|
||||
for x, n in zip(input_features, lengths)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
@@ -197,7 +206,6 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||
)
|
||||
|
||||
# make sure input is in list format
|
||||
if is_batched and not isinstance(raw_speech[0], np.ndarray):
|
||||
raw_speech = [np.asarray(speech) for speech in raw_speech]
|
||||
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
||||
@@ -225,21 +233,25 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if "attention_mask" in padded_inputs:
|
||||
input_lengths = padded_inputs["attention_mask"].sum(-1)
|
||||
else:
|
||||
padded_input_values = padded_inputs["input_features"]
|
||||
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
|
||||
# make sure list is in array format
|
||||
input_features = padded_inputs.get("input_features")
|
||||
if isinstance(input_features[0], list):
|
||||
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
||||
|
||||
attention_mask = padded_inputs.get("attention_mask")
|
||||
if attention_mask is not None:
|
||||
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
|
||||
|
||||
# Utterance-level cepstral mean and variance normalization
|
||||
if self.do_ceptral_normalize:
|
||||
input_features = padded_inputs["input_features"]
|
||||
|
||||
# make sure list is in array format
|
||||
if isinstance(input_features[0], list):
|
||||
input_features = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
||||
|
||||
padded_inputs["input_features"] = self.normalize(input_features, input_lengths=input_lengths)
|
||||
attention_mask = (
|
||||
np.array(attention_mask, dtype=np.bool)
|
||||
if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
|
||||
else None
|
||||
)
|
||||
padded_inputs["input_features"] = self.normalize(
|
||||
padded_inputs["input_features"], attention_mask=attention_mask
|
||||
)
|
||||
|
||||
if return_tensors is not None:
|
||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||
|
||||
@@ -79,13 +79,25 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
@staticmethod
|
||||
def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
|
||||
def zero_mean_unit_var_norm(
|
||||
input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Every array in the list is normalized to have zero mean and unit variance
|
||||
"""
|
||||
normed_input_values = [
|
||||
(x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths)
|
||||
]
|
||||
if attention_mask is not None:
|
||||
attention_mask = np.array(attention_mask, np.bool)
|
||||
normed_input_values = []
|
||||
|
||||
for vector, length in zip(input_values, attention_mask.sum(-1)):
|
||||
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
|
||||
if length > normed_slice.shape[0]:
|
||||
normed_slice[length:] = padding_value
|
||||
|
||||
normed_input_values.append(normed_slice)
|
||||
else:
|
||||
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
|
||||
|
||||
return normed_input_values
|
||||
|
||||
def __call__(
|
||||
@@ -172,14 +184,6 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
|
||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||
)
|
||||
|
||||
# make sure input is in list format
|
||||
if is_batched and not isinstance(raw_speech[0], np.ndarray):
|
||||
raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
|
||||
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
||||
raw_speech = np.asarray(raw_speech, dtype=np.float32)
|
||||
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
|
||||
raw_speech = raw_speech.astype(np.float32)
|
||||
|
||||
# always return batch
|
||||
if not is_batched:
|
||||
raw_speech = [raw_speech]
|
||||
@@ -196,19 +200,33 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
|
||||
return_attention_mask=return_attention_mask,
|
||||
)
|
||||
|
||||
if "attention_mask" in padded_inputs:
|
||||
input_lengths = padded_inputs["attention_mask"].sum(-1)
|
||||
else:
|
||||
padded_input_values = padded_inputs["input_values"]
|
||||
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
|
||||
# convert input values to correct format
|
||||
input_values = padded_inputs["input_values"]
|
||||
if not isinstance(input_values[0], np.ndarray):
|
||||
padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values]
|
||||
elif (
|
||||
not isinstance(input_values, np.ndarray)
|
||||
and isinstance(input_values[0], np.ndarray)
|
||||
and input_values[0].dtype is np.float64
|
||||
):
|
||||
padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
|
||||
elif isinstance(input_values, np.ndarray) and input_values.dtype is np.float64:
|
||||
padded_inputs["input_values"] = input_values.astype(np.float32)
|
||||
|
||||
if isinstance(padded_inputs["input_values"][0], np.ndarray):
|
||||
padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]]
|
||||
# convert attention_mask to correct format
|
||||
attention_mask = padded_inputs.get("attention_mask")
|
||||
if attention_mask is not None:
|
||||
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
|
||||
|
||||
# zero-mean and unit-variance normalization
|
||||
if self.do_normalize:
|
||||
attention_mask = (
|
||||
attention_mask
|
||||
if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
|
||||
else None
|
||||
)
|
||||
padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
|
||||
padded_inputs["input_values"], input_lengths=input_lengths
|
||||
padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
|
||||
)
|
||||
|
||||
if return_tensors is not None:
|
||||
|
||||
@@ -136,18 +136,49 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
def test_cepstral_mean_and_variance_normalization(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
inputs = feature_extractor(speech_inputs, padding=True, return_tensors="np", return_attention_mask=True)
|
||||
|
||||
paddings = ["longest", "max_length", "do_not_pad"]
|
||||
max_lengths = [None, 16, None]
|
||||
var_tolerances = [1e-3, 1e-3, 1e-1]
|
||||
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
|
||||
|
||||
inputs = feature_extractor(
|
||||
speech_inputs, padding=padding, max_length=max_length, return_attention_mask=True
|
||||
)
|
||||
input_features = inputs.input_features
|
||||
attention_mask = inputs.attention_mask
|
||||
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
|
||||
fbank_feat_lengths = [np.sum(x) for x in attention_mask]
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
|
||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
|
||||
|
||||
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
|
||||
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
|
||||
_check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]])
|
||||
_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
|
||||
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
|
||||
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
|
||||
|
||||
def test_cepstral_mean_and_variance_normalization_np(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
|
||||
paddings = ["longest", "max_length", "do_not_pad"]
|
||||
max_lengths = [None, 16, None]
|
||||
var_tolerances = [1e-3, 1e-3, 1e-1]
|
||||
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
|
||||
inputs = feature_extractor(
|
||||
speech_inputs, max_length=max_length, padding=padding, return_tensors="np", return_attention_mask=True
|
||||
)
|
||||
input_features = inputs.input_features
|
||||
attention_mask = inputs.attention_mask
|
||||
fbank_feat_lengths = [np.sum(x) for x in attention_mask]
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
|
||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
|
||||
|
||||
_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
|
||||
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
|
||||
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
|
||||
|
||||
def test_cepstral_mean_and_variance_normalization_trunc(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
|
||||
@@ -120,21 +120,45 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
def test_zero_mean_unit_variance_normalization(self):
|
||||
def test_zero_mean_unit_variance_normalization_np(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
processed = feat_extract(speech_inputs, padding="longest", return_tensors="np")
|
||||
|
||||
paddings = ["longest", "max_length", "do_not_pad"]
|
||||
max_lengths = [None, 1600, None]
|
||||
for max_length, padding in zip(max_lengths, paddings):
|
||||
processed = feat_extract(speech_inputs, padding=padding, max_length=max_length, return_tensors="np")
|
||||
input_values = processed.input_values
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
||||
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
||||
|
||||
_check_zero_mean_unit_variance(input_values[0, :800])
|
||||
_check_zero_mean_unit_variance(input_values[1, :1000])
|
||||
_check_zero_mean_unit_variance(input_values[2])
|
||||
_check_zero_mean_unit_variance(input_values[0][:800])
|
||||
_check_zero_mean_unit_variance(input_values[1][:1000])
|
||||
_check_zero_mean_unit_variance(input_values[2][:1200])
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc(self):
|
||||
def test_zero_mean_unit_variance_normalization(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
lengths = range(800, 1400, 200)
|
||||
speech_inputs = [floats_list((1, x))[0] for x in lengths]
|
||||
|
||||
paddings = ["longest", "max_length", "do_not_pad"]
|
||||
max_lengths = [None, 1600, None]
|
||||
|
||||
for max_length, padding in zip(max_lengths, paddings):
|
||||
processed = feat_extract(speech_inputs, max_length=max_length, padding=padding)
|
||||
input_values = processed.input_values
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
||||
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
||||
|
||||
_check_zero_mean_unit_variance(input_values[0][:800])
|
||||
_check_zero_mean_unit_variance(input_values[1][:1000])
|
||||
_check_zero_mean_unit_variance(input_values[2][:1200])
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc_np(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
processed = feat_extract(
|
||||
|
||||
Reference in New Issue
Block a user