From 58bf882579da8a2844545ea77845eb4587803a86 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 12 Oct 2021 18:17:06 +0200 Subject: [PATCH] [Wav2Vec2] Make sure tensors are always bool for mask_indices (#13977) * correct long to bool * up * correct code --- .../models/hubert/modeling_hubert.py | 4 +-- .../models/wav2vec2/modeling_wav2vec2.py | 4 +-- tests/test_modeling_wav2vec2.py | 27 +++++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 5bc1fc2345..2aa1626871 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -907,7 +907,7 @@ class HubertModel(HubertPreTrainedModel): attention_mask=attention_mask, min_masks=2, ) - mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: @@ -917,7 +917,7 @@ class HubertModel(HubertPreTrainedModel): mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, ) - mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[ + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)[ :, None ].expand(-1, sequence_length, -1) hidden_states[mask_feature_indices] = 0 diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 7a566654ff..e62186eb99 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1100,7 +1100,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): attention_mask=attention_mask, min_masks=2, ) - mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: @@ -1110,7 +1110,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, ) - mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[ + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)[ :, None ].expand(-1, sequence_length, -1) hidden_states[mask_feature_indices] = 0 diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index 2a62ea97f1..13ef539d46 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -738,6 +738,33 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): self.assertEqual(logits.shape, (4, 1498, 32)) + def test_mask_time_feature_prob_ctc_single_batch(self): + model = Wav2Vec2ForCTC.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", + mask_time_prob=0.2, + mask_feature_prob=0.2, + mask_time_length=2, + mask_feature_length=2, + ) + model.to(torch_device).train() + processor = Wav2Vec2Processor.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True + ) + + batch_duration_in_seconds = [6] + input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds] + + batch = processor( + input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt" + ) + + logits = model( + input_values=batch["input_values"].to(torch_device), + attention_mask=batch["attention_mask"].to(torch_device), + ).logits + + self.assertEqual(logits.shape, (1, 1498, 32)) + @slow def test_model_from_pretrained(self): model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")