From b4b562d83415d842c081f571e0ec325f40f276aa Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 16 Jul 2021 18:07:08 +0100 Subject: [PATCH] [Wav2Vec2] Padded vectors should not allowed to be sampled (#12764) * fix_torch_device_generate_test * remove @ * finish * correct script * correct script --- .../wav2vec2/run_wav2vec2_pretrain_flax.py | 2 + .../models/wav2vec2/modeling_flax_wav2vec2.py | 14 +++-- .../models/wav2vec2/modeling_wav2vec2.py | 55 ++++++++++++------- tests/test_modeling_flax_wav2vec2.py | 42 ++++++++++++++ tests/test_modeling_wav2vec2.py | 31 +++++++++++ 5 files changed, 117 insertions(+), 27 deletions(-) diff --git a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py index 774b0674d2..7eb286b496 100755 --- a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py +++ b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py @@ -176,6 +176,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining: batch_size = batch["input_values"].shape[0] + attention_mask = None if batch["attention_mask"] is not None: output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1)) attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8) @@ -198,6 +199,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining: batch["sampled_negative_indices"] = _sample_negative_indices( (batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)), self.model.config.num_negatives, + attention_mask=attention_mask, ) return batch diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py index 1e463234a2..e95e21f909 100644 --- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py @@ -174,7 +174,7 @@ def _compute_mask_indices( return spec_aug_mask -def _sample_negative_indices(features_shape: Tuple, num_negatives: int): +def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attention_mask: Optional[np.ndarray] = None): """ Sample `num_negatives` vectors from feature vectors. """ @@ -186,11 +186,13 @@ def _sample_negative_indices(features_shape: Tuple, num_negatives: int): ) # get `num_negatives` random vector indices from the same utterance - sampled_negative_indices = np.random.randint( - low=0, - high=sequence_length - 1, - size=(batch_size, num_negatives * sequence_length), - ) + sampled_negative_indices = [] + for batch_idx in range(batch_size): + high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1 + sampled_indices_slice = np.random.randint(0, high, size=(num_negatives * sequence_length,)) + sampled_negative_indices.append(sampled_indices_slice) + + sampled_negative_indices = np.asarray(sampled_negative_indices, dtype=np.int32) # generate indices of the positive vectors themselves, repeat them `num_negatives` times feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten() diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ea9b6cc592..3a51897da7 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -877,6 +877,18 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): return input_lengths + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + WAV_2_VEC_2_START_DOCSTRING = r""" Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations @@ -1044,19 +1056,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): extract_features = extract_features.transpose(1, 2) if attention_mask is not None: - # compute real output lengths according to convolution formula - output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) - - attention_mask = torch.zeros( - extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device - ) - - # these two operations makes sure that all values - # before the output lengths indices are attended to - attention_mask[ - (torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1) - ] = 1 - attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + # compute reduced attention_mask correponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) hidden_states, extract_features = self.feature_projection(extract_features) hidden_states = self._mask_hidden_states( @@ -1111,7 +1112,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): self.wav2vec2.feature_extractor._freeze_parameters() @staticmethod - def _sample_negatives(features: torch.FloatTensor, num_negatives: int): + def _sample_negatives( + features: torch.FloatTensor, num_negatives: int, attention_mask: Optional[torch.LongTensor] = None + ): """ Sample `num_negatives` vectors from feature vectors. """ @@ -1125,12 +1128,15 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): with torch.no_grad(): # get `num_negatives` random vector indices from the same utterance - sampled_negative_indices = torch.randint( - low=0, - high=sequence_length - 1, - size=(batch_size, num_negatives * sequence_length), - device=features.device, - ) + sampled_negative_indices = [] + for batch_idx in range(batch_size): + high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1 + sampled_indices_slice = torch.randint( + 0, high, size=(num_negatives * sequence_length,), device=features.device + ) + sampled_negative_indices.append(sampled_indices_slice) + + sampled_negative_indices = torch.stack(sampled_negative_indices) # generate indices of the positive vectors themselves, repeat them `num_negatives` times feature_indices = ( @@ -1263,7 +1269,14 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): if self.training: # for training, we sample negatives # 3. sample K negatives (distractors) quantized states for contrastive loss - negative_quantized_features = self._sample_negatives(quantized_features, self.config.num_negatives) + # if attention_mask is passed, make sure that padded feature vectors cannot be sampled + if attention_mask is not None: + # compute reduced attention_mask correponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + negative_quantized_features = self._sample_negatives( + quantized_features, self.config.num_negatives, attention_mask=attention_mask + ) # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa` # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf diff --git a/tests/test_modeling_flax_wav2vec2.py b/tests/test_modeling_flax_wav2vec2.py index b8d27b6190..66dd1e0611 100644 --- a/tests/test_modeling_flax_wav2vec2.py +++ b/tests/test_modeling_flax_wav2vec2.py @@ -306,6 +306,48 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase): # => this means that `unique()` yields a single value for `hidden_size` dim self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1)) + def test_sample_negatives_with_attn_mask(self): + batch_size = 2 + sequence_length = 10 + hidden_size = 4 + num_negatives = 3 + + features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape( + sequence_length, hidden_size + ) # each value in vector consits of same value + + # second half of last input tensor is padded + attention_mask = np.ones((batch_size, sequence_length), dtype=np.int8) + attention_mask[-1, sequence_length // 2 :] = 0 + + forbidden_indices = ( + np.arange(sequence_length // 2, sequence_length, dtype=np.int32) + (batch_size - 1) * sequence_length + ).tolist() + + features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size)) + + negative_indices = _sample_negative_indices(features.shape, num_negatives, attention_mask=attention_mask) + + # make sure that no padding tokens are sampled + self.assertTrue(all([idx not in negative_indices for idx in forbidden_indices])) + + features = features.reshape(-1, hidden_size) # BTC => (BxT)C + # take negative vectors from sampled indices + sampled_negatives = features[negative_indices.reshape(-1)] + negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose( + 2, 0, 1, 3 + ) + + self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size)) + + # make sure no negatively sampled vector is actually a positive one + for negative in negatives: + self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0) + + # make sure that full vectors are sampled and not just slices of vectors + # => this means that `unique()` yields a single value for `hidden_size` dim + self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1)) + @require_flax @require_datasets diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index cb54132dd5..8b269b88bc 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -633,6 +633,37 @@ class Wav2Vec2UtilsTest(unittest.TestCase): # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1)) + def test_sample_negatives_with_attn_mask(self): + batch_size = 2 + sequence_length = 10 + hidden_size = 4 + num_negatives = 3 + + # second half of last input tensor is padded + attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device) + attention_mask[-1, sequence_length // 2 :] = 0 + + features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view( + sequence_length, hidden_size + ) # each value in vector consits of same value + features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous() + + # replace masked feature vectors with -100 to test that those are not sampled + features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100) + + negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask) + + self.assertTrue((negatives >= 0).all().item()) + + self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size)) + + # make sure no negatively sampled vector is actually a positive one + for negative in negatives: + self.assertTrue(((negative - features) == 0).sum() == 0.0) + + # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim + self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1)) + @require_torch @require_datasets