From d979cf6efdff313f7ea1218914e506c725c453d4 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 13 Mar 2023 19:46:01 +0100 Subject: [PATCH] [`Whisper`] add `get_input_embeddings` to `WhisperForAudioClassification` (#22133) * add `get_input_embeddings` to `WhisperForAudioClassification` * add common tests * fix another common test * Update tests/models/whisper/test_modeling_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix style --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../models/whisper/modeling_whisper.py | 20 ++++++++++++- tests/models/whisper/test_modeling_whisper.py | 30 ++++++++++++++++--- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 94d0be4047..e29802d444 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -767,6 +767,12 @@ class WhisperEncoder(WhisperPreTrainedModel): param.requires_grad = False self._requires_grad = False + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + def forward( self, input_features, @@ -1023,7 +1029,10 @@ class WhisperDecoder(WhisperPreTrainedModel): ) # embed positions - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + if input_ids is not None: + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + else: + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1330,6 +1339,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): def 
set_output_embeddings(self, new_embeddings): self.proj_out = new_embeddings + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() + def freeze_encoder(self): """ Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will @@ -1635,6 +1647,12 @@ class WhisperForAudioClassification(WhisperPreTrainedModel): """ self.encoder._freeze_parameters() + def get_input_embeddings(self) -> nn.Module: + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, value: nn.Module): + self.encoder.set_input_embeddings(value) + @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index d4b252398f..8524c5b42c 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -357,9 +357,24 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi return config, input_ids, None, max_length - # not implemented currently def test_inputs_embeds(self): - pass + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + decoder_input_ids = inputs.pop("decoder_input_ids", None) + inputs.pop("decoder_attention_mask", None) + + wte = model.get_input_embeddings() + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] # training is not supported yet def test_training(self): @@ -1566,9 +1581,16 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. 
self.assertTrue((outputs_embeds == outputs).all()) - # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented + # Needs to override as the encoder input embedding is a Conv1d def test_model_common_attributes(self): - pass + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Conv1d)) + model.set_input_embeddings(torch.nn.Conv1d(10, 10, 3)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, torch.nn.Conv1d)) # WhisperEncoder cannot resize token embeddings since it has no tokens embeddings def test_resize_tokens_embeddings(self):