Patch summary (commentary only: this text precedes the first "diff --git" header,
where tools such as `git apply` and GNU `patch` skip leading non-patch text).

* modeling_flax_speech_encoder_decoder.py: removes the
  FlaxSpeechEncoderDecoderModule.freeze_feature_encoder() method and instead adds a
  `freeze_feature_encoder: bool = False` keyword to the module's __call__, which is
  forwarded to the `self.encoder(...)` call. The same keyword is added to two
  FlaxSpeechEncoderDecoderModel methods (one dispatching to `_encoder_forward`, one
  calling `self.module.apply` directly) and passed through.
* test_modeling_flax_speech_encoder_decoder.py: adds check_freeze_feature_encoder /
  test_freeze_feature_encoder, which compare loss and gradients with the flag off and
  on, adds the assert_difference helper, and rewords the assert_almost_equals message.

NOTE(review): line breaks and indentation below were reconstructed from a
whitespace-collapsed copy; the line counts of every hunk match its "@@" header, but
confirm the context-line indentation against the target files before applying.

diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
index 30767e425a..e00a57240a 100644
--- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
@@ -250,13 +250,6 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
     def _get_decoder_module(self):
         return self.decoder
 
-    def freeze_feature_encoder(self):
-        """
-        Calling this function will disable the gradient computation for the feature encoder of the speech encoder in
-        order that its parameters are not updated during training.
-        """
-        self.encoder.freeze_feature_encoder()
-
     def __call__(
         self,
         inputs,
@@ -269,6 +262,7 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
         output_hidden_states: bool = False,
         return_dict: bool = True,
         deterministic: bool = True,
+        freeze_feature_encoder: bool = False,
     ):
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
@@ -278,6 +272,7 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
                 deterministic=deterministic,
+                freeze_feature_encoder=freeze_feature_encoder,
             )
 
         encoder_hidden_states = encoder_outputs[0]
@@ -448,6 +443,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         train: bool = False,
+        freeze_feature_encoder: bool = False,
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
@@ -493,6 +489,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             deterministic=not train,
+            freeze_feature_encoder=freeze_feature_encoder,
             rngs=rngs,
             method=_encoder_forward,
         )
@@ -644,6 +641,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         train: bool = False,
+        freeze_feature_encoder: bool = False,
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
@@ -705,6 +703,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             deterministic=not train,
+            freeze_feature_encoder=freeze_feature_encoder,
             rngs=rngs,
         )
 
diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
index f204dae530..7bf7e0af0a 100644
--- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
+++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
@@ -28,6 +28,10 @@ from ..wav2vec2.test_modeling_flax_wav2vec2 import FlaxWav2Vec2ModelTester
 
 
 if is_flax_available():
+    import jax
+    import jax.numpy as jnp
+    from flax.training.common_utils import onehot
+    from flax.traverse_util import flatten_dict
     from transformers import (
         FlaxBartForCausalLM,
         FlaxGPT2LMHeadModel,
@@ -275,6 +279,84 @@ class FlaxEncoderDecoderMixin:
         generated_sequences = generated_output.sequences
         self.assertEqual(generated_sequences.shape, (inputs.shape[0],) + (decoder_config.max_length,))
 
+    def check_freeze_feature_encoder(
+        self,
+        config,
+        inputs,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
+        enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
+        params = enc_dec_model.params
+
+        def cross_entropy(logits, labels):
+            return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
+
+        # define a dummy loss function for computing the loss over a forward pass
+        def compute_loss(
+            params,
+            inputs,
+            attention_mask,
+            decoder_input_ids,
+            decoder_attention_mask,
+            freeze_feature_encoder: bool = False,
+        ):
+            outputs_enc_dec = enc_dec_model(
+                inputs=inputs,
+                attention_mask=attention_mask,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                freeze_feature_encoder=freeze_feature_encoder,
+                params=params,
+            )
+            logits = outputs_enc_dec.logits
+            vocab_size = logits.shape[-1]
+            loss = cross_entropy(logits, onehot(labels=decoder_input_ids, num_classes=vocab_size)).sum()
+            return loss
+
+        # transform the loss function to get the gradients
+        grad_fn = jax.value_and_grad(compute_loss)
+
+        # compute the loss and gradients for the unfrozen model
+        loss, grads = grad_fn(
+            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=False
+        )
+
+        # compare to the loss and gradients for the frozen model
+        loss_frozen, grads_frozen = grad_fn(
+            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=True
+        )
+
+        self.assert_almost_equals(loss, loss_frozen, 1e-5)
+
+        grads = flatten_dict(grads)
+        grads_frozen = flatten_dict(grads_frozen)
+
+        # ensure that the dicts of gradients contain the same keys
+        self.assertEqual(grads.keys(), grads_frozen.keys())
+
+        # ensure that the gradients of the frozen layers are precisely zero and that they differ to the gradients of the unfrozen layers
+        feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k)
+        feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k)
+
+        for feature_extractor_grad, feature_extractor_grad_frozen in zip(
+            feature_extractor_grads, feature_extractor_grads_frozen
+        ):
+            self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
+            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-8)
+
+        # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
+        grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
+        grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
+
+        for grad, grad_frozen in zip(grads, grads_frozen):
+            self.assert_almost_equals(grad, grad_frozen, 1e-8)
+
     def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
 
         pt_model.to(torch_device)
@@ -367,13 +449,21 @@ class FlaxEncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
 
+    def test_freeze_feature_encoder(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_freeze_feature_encoder(**input_ids_dict)
+
     def test_encoder_decoder_model_generate(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)
 
     def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
         diff = np.abs((a - b)).max()
-        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
+        self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).")
+
+    def assert_difference(self, a: np.ndarray, b: np.ndarray, tol: float):
+        diff = np.abs((a - b)).min()
+        self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).")
 
     @is_pt_flax_cross_test
     def test_pt_flax_equivalence(self):