From 0f5488f79fabfaa0c49226c96409ab11d661396b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 7 Oct 2021 18:07:32 +0200 Subject: [PATCH] [Wav2Vec2] Fix mask_feature_prob (#13921) * up * overwrite hubert --- .../models/hubert/modeling_hubert.py | 1 - .../models/wav2vec2/modeling_wav2vec2.py | 1 - tests/test_modeling_wav2vec2.py | 93 +++++++++++++++++++ 3 files changed, 93 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index ddca07f597..00cfb89600 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -877,7 +877,6 @@ class HubertModel(HubertPreTrainedModel): mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, device=hidden_states.device, - attention_mask=attention_mask, ) hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 7ede44928b..6b4a252282 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1014,7 +1014,6 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, device=hidden_states.device, - attention_mask=attention_mask, ) hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index d1094a81d2..9c0dc7ee9e 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -17,6 +17,7 @@ import math import unittest +import numpy as np import pytest from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask @@ -433,6 +434,52 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: module.masked_spec_embed.data.fill_(3) + def test_mask_feature_prob_ctc(self): + model = Wav2Vec2ForCTC.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", mask_feature_prob=0.2, mask_feature_length=2 + ) + model.to(torch_device).train() + processor = Wav2Vec2Processor.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True + ) + + batch_duration_in_seconds = [1, 3, 2, 6] + input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds] + + batch = processor( + input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt" + ) + + logits = model( + input_values=batch["input_values"].to(torch_device), + attention_mask=batch["attention_mask"].to(torch_device), + ).logits + + self.assertEqual(logits.shape, (4, 1498, 32)) + + def test_mask_time_prob_ctc(self): + model = Wav2Vec2ForCTC.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", mask_time_prob=0.2, mask_time_length=2 + ) + model.to(torch_device).train() + processor = Wav2Vec2Processor.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True + ) + + batch_duration_in_seconds = [1, 3, 2, 6] + input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds] + + batch = processor( + input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt" + ) + + logits = model( + input_values=batch["input_values"].to(torch_device), + attention_mask=batch["attention_mask"].to(torch_device), + ).logits + + self.assertEqual(logits.shape, (4, 1498, 32)) + @slow def test_model_from_pretrained(self): model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") @@ -620,6 +667,52 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): # loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted self.assertTrue(loss.detach().item() <= loss_more_masked.detach().item()) + def test_mask_feature_prob_ctc(self): + model = Wav2Vec2ForCTC.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", mask_feature_prob=0.2, mask_feature_length=2 + ) + model.to(torch_device).train() + processor = Wav2Vec2Processor.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True + ) + + batch_duration_in_seconds = [1, 3, 2, 6] + input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds] + + batch = processor( + input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt" + ) + + logits = model( + input_values=batch["input_values"].to(torch_device), + attention_mask=batch["attention_mask"].to(torch_device), + ).logits + + self.assertEqual(logits.shape, (4, 1498, 32)) + + def test_mask_time_prob_ctc(self): + model = Wav2Vec2ForCTC.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", mask_time_prob=0.2, mask_time_length=2 + ) + model.to(torch_device).train() + processor = Wav2Vec2Processor.from_pretrained( + "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True + ) + + batch_duration_in_seconds = [1, 3, 2, 6] + input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds] + + batch = processor( + input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt" + ) + + logits = model( + input_values=batch["input_values"].to(torch_device), + attention_mask=batch["attention_mask"].to(torch_device), + ).logits + + self.assertEqual(logits.shape, (4, 1498, 32)) + @slow def test_model_from_pretrained(self): model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")