diff --git a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py index a0a7d38f85..774b0674d2 100755 --- a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py +++ b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py @@ -174,11 +174,23 @@ class FlaxDataCollatorForWav2Vec2Pretraining: ) mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1]) + batch_size = batch["input_values"].shape[0] + + if batch["attention_mask"] is not None: + output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1)) + attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + attention_mask[(np.arange(attention_mask.shape[0]), output_lengths - 1)] = 1 + attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool") + # sample randomly masked indices batch["mask_time_indices"] = _compute_mask_indices( - (batch["input_values"].shape[0], mask_indices_seq_length), + (batch_size, mask_indices_seq_length), self.model.config.mask_time_prob, self.model.config.mask_time_length, + attention_mask=attention_mask, min_masks=2, ) diff --git a/examples/research_projects/wav2vec2/run_pretrain.py b/examples/research_projects/wav2vec2/run_pretrain.py index 491537b2eb..02dee12953 100755 --- a/examples/research_projects/wav2vec2/run_pretrain.py +++ b/examples/research_projects/wav2vec2/run_pretrain.py @@ -172,12 +172,33 @@ class DataCollatorForWav2Vec2Pretraining: ) mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1]) + batch_size = batch["input_values"].shape[0] + + # make sure that no loss is computed on padded inputs + if batch["attention_mask"] is not None: + # compute real output lengths according to convolution formula + output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1)).to( + torch.long + ) + + attention_mask = torch.zeros( + (batch_size, mask_indices_seq_length), dtype=torch.long, device=batch["input_values"].device + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + attention_mask[ + (torch.arange(attention_mask.shape[0], device=batch["input_values"].device), output_lengths - 1) + ] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + # sample randomly masked indices batch["mask_time_indices"] = _compute_mask_indices( - (batch["input_values"].shape[0], mask_indices_seq_length), + (batch_size, mask_indices_seq_length), self.model.config.mask_time_prob, self.model.config.mask_time_length, device=batch["input_values"].device, + attention_mask=attention_mask, min_masks=2, ) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index a946f78079..05626267ff 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -47,6 +47,7 @@ def _compute_mask_indices( mask_prob: float, mask_length: int, device: torch.device, + attention_mask: Optional[torch.tensor] = None, min_masks: int = 0, ) -> torch.tensor: """ @@ -813,7 +814,10 @@ class HubertModel(HubertPreTrainedModel): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states def _mask_hidden_states( - self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, ): """ Masks extracted features along time axis and/or along feature axis according to `SpecAugment @@ -836,6 +840,7 @@ class HubertModel(HubertPreTrainedModel): mask_prob=self.config.mask_time_prob, mask_length=self.config.mask_time_length, device=hidden_states.device, + attention_mask=attention_mask, min_masks=2, ) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) @@ -847,6 +852,7 @@ class HubertModel(HubertPreTrainedModel): mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, device=hidden_states.device, + attention_mask=attention_mask, ) hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py index 12764a40ac..1e463234a2 100644 --- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py @@ -107,6 +107,7 @@ def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, mask_length: int, + attention_mask: Optional[np.ndarray] = None, min_masks: int = 0, ) -> np.ndarray: """ @@ -166,6 +167,10 @@ def _compute_mask_indices( # scatter indices to mask np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + if attention_mask is not None: + # make sure padded input ids cannot be masked + spec_aug_mask = np.where(attention_mask, spec_aug_mask, False) + return spec_aug_mask @@ -873,6 +878,7 @@ class FlaxWav2Vec2Module(nn.Module): """ extract_features = self.feature_extractor(input_values) + # make sure that no loss is computed on padded inputs if attention_mask is not None: # compute real output lengths according to convolution formula output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4")) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 3997813667..ea9b6cc592 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -120,6 +120,7 @@ def _compute_mask_indices( mask_prob: float, mask_length: int, device: torch.device, + attention_mask: Optional[torch.tensor] = None, min_masks: int = 0, ) -> torch.tensor: """ @@ -179,6 +180,10 @@ def _compute_mask_indices( # scatter indices to mask spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True) + if attention_mask is not None: + # make sure padded input ids cannot be masked + spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False) + return spec_aug_mask @@ -950,7 +955,10 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): self.init_weights() def _mask_hidden_states( - self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, ): """ Masks extracted features along time axis and/or along feature axis according to `SpecAugment @@ -973,6 +981,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): mask_prob=self.config.mask_time_prob, mask_length=self.config.mask_time_length, device=hidden_states.device, + attention_mask=attention_mask, min_masks=2, ) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) @@ -984,6 +993,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, device=hidden_states.device, + attention_mask=attention_mask, ) hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 @@ -1049,7 +1059,9 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() hidden_states, extract_features = self.feature_projection(extract_features) - hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) encoder_outputs = self.encoder( hidden_states, diff --git a/tests/test_modeling_flax_wav2vec2.py b/tests/test_modeling_flax_wav2vec2.py index 9b33a1d2ba..b8d27b6190 100644 --- a/tests/test_modeling_flax_wav2vec2.py +++ b/tests/test_modeling_flax_wav2vec2.py @@ -245,6 +245,24 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase): for batch_sum in mask.sum(axis=-1): self.assertTrue(int(batch_sum) <= mask_prob * sequence_length) + def test_compute_mask_indices_attn_mask_overlap(self): + batch_size = 4 + sequence_length = 80 + mask_prob = 0.5 + mask_length = 4 + + attention_mask = np.ones((batch_size, sequence_length), dtype=np.int32) + attention_mask[:2, sequence_length // 2 :] = 0 + + mask = _compute_mask_indices( + (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask + ) + + for batch_sum in mask.sum(axis=-1): + self.assertTrue(int(batch_sum) <= mask_prob * sequence_length) + + self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0) + def test_compute_perplexity(self): probs = np.arange(100).reshape(2, 5, 10) / 100 diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index 206a0cbeed..cb54132dd5 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -580,6 +580,24 @@ class Wav2Vec2UtilsTest(unittest.TestCase): for batch_sum in mask.sum(axis=-1): self.assertTrue(int(batch_sum) <= mask_prob * sequence_length) + def test_compute_mask_indices_attn_mask_overlap(self): + batch_size = 4 + sequence_length = 80 + mask_prob = 0.5 + mask_length = 4 + + attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device) + attention_mask[:2, sequence_length // 2 :] = 0 + + mask = _compute_mask_indices( + (batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask + ) + + for batch_sum in mask.sum(axis=-1): + self.assertTrue(int(batch_sum) <= mask_prob * sequence_length) + + self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0) + def test_compute_perplexity(self): probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100