[Whisper] Add SpecAugment (#21298)
* Return and rescale attention_mask * Add SpecAugment to Whisper modeling * Fix test * Update docstring * Add SpecAug related parameters to model config * Add the _mask_input_features function to doc * Fix quality * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Remove dev comments * Add test * Resolve conflict * feat: mask {feature, time} prob fast tests * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: sanchit-gandhi <sanchit@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -383,6 +383,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
|
||||
expected_arg_names = [
|
||||
"input_features",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
@@ -909,6 +910,34 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
def test_mask_feature_prob(self):
    """Run a training-mode forward pass with feature masking enabled.

    Sets ``mask_feature_prob=0.2`` and ``mask_feature_length=2`` on the
    common test config, instantiates every class in
    ``self.all_model_classes`` and checks the shape of
    ``encoder_last_hidden_state``.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
    config.mask_feature_prob = 0.2
    config.mask_feature_length = 2

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.train()

        # forward pass
        encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
        # `assertTrue(value, msg)` would treat the tuple as the failure
        # message and pass for any non-empty shape; compare explicitly.
        # NOTE(review): (13, 30, 16) presumably mirrors the model tester's
        # batch size / encoder length / hidden size -- confirm.
        self.assertEqual(tuple(encoder_last_hidden_state.shape), (13, 30, 16))
|
||||
|
||||
def test_mask_time_prob(self):
    """Run a training-mode forward pass with time masking enabled.

    Sets ``mask_time_prob=0.2`` and ``mask_time_length=2`` on the common
    test config, instantiates every class in ``self.all_model_classes``
    and checks the shape of ``encoder_last_hidden_state``.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
    config.mask_time_prob = 0.2
    config.mask_time_length = 2

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.train()

        # forward pass
        encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
        # `assertTrue(value, msg)` would treat the tuple as the failure
        # message and pass for any non-empty shape; compare explicitly.
        # NOTE(review): (13, 30, 16) presumably mirrors the model tester's
        # batch size / encoder length / hidden size -- confirm.
        self.assertEqual(tuple(encoder_last_hidden_state.shape), (13, 30, 16))
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@@ -1289,3 +1318,38 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
def test_tiny_specaugment_librispeech(self):
    """Compare SpecAugment-enabled ``openai/whisper-tiny`` outputs to reference values.

    Seeds with ``set_seed(0)``, loads the checkpoint with
    ``apply_spec_augment=True``, switches it to training mode on CPU, runs
    one sample from ``self._load_datasamples(1)`` through the model and
    checks the first 30 values at index ``[0, 0]`` of the first returned
    element against ``EXPECTED_LOGITS`` (``atol=1e-4``).
    """
    device = "cpu"
    set_seed(0)

    # Load the checkpoint with SpecAugment switched on in its config.
    model = WhisperModel.from_pretrained("openai/whisper-tiny", apply_spec_augment=True)
    # SpecAugment is only enabled in training mode.
    model.train()
    model.to(device)

    speech = self._load_datasamples(1)
    extractor = WhisperFeatureExtractor()
    features = extractor(speech, return_tensors="pt").input_features
    prompt_ids = torch.tensor([[50258, 50259, 50359]])

    with torch.no_grad():
        # return_dict=False makes the model return a plain tuple.
        outputs = model(
            features,
            decoder_input_ids=prompt_ids,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
            use_cache=False,
        )

    # fmt: off
    EXPECTED_LOGITS = torch.tensor(
        [
            0.9362, -4.7105, 5.0879, 3.9642, 1.0013, -6.0096, 4.7285, -3.1847,
            -0.8648, 1.9631, 6.2653, 3.6936, 0.3575, -4.5818, 3.0564, 7.8712,
            2.9951, 0.6848, 9.9497, -2.6638, 1.1571, -6.8546, -1.4333, -7.7584,
            1.1200, 3.9030, 4.4655, -4.4919, -1.1703, 9.6241
        ]
    )
    # fmt: on

    observed = outputs[0][0, 0, :30].cpu()
    is_close = torch.allclose(observed, EXPECTED_LOGITS, atol=1e-4)
    self.assertTrue(is_close)
|
||||
|
||||
Reference in New Issue
Block a user