[Whisper] Freeze params of encoder (#19527)
* [Whisper] Freeze params of encoder * add tests
This commit is contained in:
@@ -182,9 +182,12 @@ class WhisperModelTester:
|
||||
|
||||
return input_lengths
|
||||
|
||||
def create_and_check_model_forward(self, config, inputs_dict):
|
||||
def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False):
|
||||
model = WhisperModel(config=config).to(torch_device).eval()
|
||||
|
||||
if freeze_encoder:
|
||||
model.freeze_encoder()
|
||||
|
||||
input_features = inputs_dict["input_features"]
|
||||
decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
|
||||
@@ -289,6 +292,26 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||
|
||||
def test_model_forward_with_frozen_encoder(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs, freeze_encoder=True)
|
||||
|
||||
def test_requires_grad_with_frozen_encoder(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.freeze_encoder()
|
||||
|
||||
try:
|
||||
encoder_grads = [param.requires_grad for param in model.encoder.parameters()]
|
||||
decoder_grads = [param.requires_grad for param in model.decoder.parameters()]
|
||||
except AttributeError:
|
||||
encoder_grads = [param.requires_grad for param in model.model.encoder.parameters()]
|
||||
decoder_grads = [param.requires_grad for param in model.model.decoder.parameters()]
|
||||
|
||||
self.assertFalse(all(encoder_grads))
|
||||
self.assertTrue(all(decoder_grads))
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user