From bbd150e92f84db72e7507d0c3ce69474b2948839 Mon Sep 17 00:00:00 2001
From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Date: Thu, 13 Oct 2022 09:50:02 +0100
Subject: [PATCH] [Whisper] Freeze params of encoder (#19527)

* [Whisper] Freeze params of encoder

* add tests
---
Review notes (this section is not part of the commit message):
- test_requires_grad_with_frozen_encoder now uses assertFalse(any(...)) so it
  fails if even one encoder parameter still requires grad; the previous
  assertFalse(all(...)) passed as soon as a single parameter was frozen.
- The change is line-count neutral, so hunk headers and the diffstat are
  unchanged; the blob ids on the test file's index line are kept from the
  original submission.

 .../models/whisper/modeling_whisper.py        | 19 ++++++++++++++
 tests/models/whisper/test_modeling_whisper.py | 25 ++++++++++++++++++-
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 25867582dd..35cc608e31 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -609,6 +609,11 @@ class WhisperEncoder(WhisperPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
     def forward(
         self,
         input_features,
@@ -991,6 +996,13 @@ class WhisperModel(WhisperPreTrainedModel):
     def get_decoder(self):
         return self.decoder
 
+    def freeze_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
+        not be updated during training.
+        """
+        self.encoder._freeze_parameters()
+
     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_PROCESSOR_FOR_DOC,
@@ -1109,6 +1121,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.proj_out = new_embeddings
 
+    def freeze_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
+        not be updated during training.
+        """
+        self.model.encoder._freeze_parameters()
+
     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index bef46d5014..7907aaa1eb 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -182,9 +182,12 @@ class WhisperModelTester:
 
         return input_lengths
 
-    def create_and_check_model_forward(self, config, inputs_dict):
+    def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False):
         model = WhisperModel(config=config).to(torch_device).eval()
 
+        if freeze_encoder:
+            model.freeze_encoder()
+
         input_features = inputs_dict["input_features"]
         decoder_input_ids = inputs_dict["decoder_input_ids"]
 
@@ -289,6 +292,26 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_forward(*config_and_inputs)
 
+    def test_model_forward_with_frozen_encoder(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_forward(*config_and_inputs, freeze_encoder=True)
+
+    def test_requires_grad_with_frozen_encoder(self):
+        config = self.model_tester.get_config()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.freeze_encoder()
+
+            try:
+                encoder_grads = [param.requires_grad for param in model.encoder.parameters()]
+                decoder_grads = [param.requires_grad for param in model.decoder.parameters()]
+            except AttributeError:
+                encoder_grads = [param.requires_grad for param in model.model.encoder.parameters()]
+                decoder_grads = [param.requires_grad for param in model.model.decoder.parameters()]
+
+            self.assertFalse(any(encoder_grads))
+            self.assertTrue(all(decoder_grads))
+
     def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)