Freeze Feature Encoder in FlaxSpeechEncoderDecoder (#15997)
* Freeze Feature Encoder in FlaxSpeechEncoderDecoder * add backprop test
This commit is contained in:
@@ -250,13 +250,6 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
|
|||||||
def _get_decoder_module(self):
|
def _get_decoder_module(self):
|
||||||
return self.decoder
|
return self.decoder
|
||||||
|
|
||||||
def freeze_feature_encoder(self):
|
|
||||||
"""
|
|
||||||
Calling this function will disable the gradient computation for the feature encoder of the speech encoder in
|
|
||||||
order that its parameters are not updated during training.
|
|
||||||
"""
|
|
||||||
self.encoder.freeze_feature_encoder()
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -269,6 +262,7 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
|
|||||||
output_hidden_states: bool = False,
|
output_hidden_states: bool = False,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
|
freeze_feature_encoder: bool = False,
|
||||||
):
|
):
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is None:
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
@@ -278,6 +272,7 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
|
freeze_feature_encoder=freeze_feature_encoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
@@ -448,6 +443,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
|
freeze_feature_encoder: bool = False,
|
||||||
params: dict = None,
|
params: dict = None,
|
||||||
dropout_rng: PRNGKey = None,
|
dropout_rng: PRNGKey = None,
|
||||||
):
|
):
|
||||||
@@ -493,6 +489,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
deterministic=not train,
|
deterministic=not train,
|
||||||
|
freeze_feature_encoder=freeze_feature_encoder,
|
||||||
rngs=rngs,
|
rngs=rngs,
|
||||||
method=_encoder_forward,
|
method=_encoder_forward,
|
||||||
)
|
)
|
||||||
@@ -644,6 +641,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
|
freeze_feature_encoder: bool = False,
|
||||||
params: dict = None,
|
params: dict = None,
|
||||||
dropout_rng: PRNGKey = None,
|
dropout_rng: PRNGKey = None,
|
||||||
):
|
):
|
||||||
@@ -705,6 +703,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
deterministic=not train,
|
deterministic=not train,
|
||||||
|
freeze_feature_encoder=freeze_feature_encoder,
|
||||||
rngs=rngs,
|
rngs=rngs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ from ..wav2vec2.test_modeling_flax_wav2vec2 import FlaxWav2Vec2ModelTester
|
|||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from flax.training.common_utils import onehot
|
||||||
|
from flax.traverse_util import flatten_dict
|
||||||
from transformers import (
|
from transformers import (
|
||||||
FlaxBartForCausalLM,
|
FlaxBartForCausalLM,
|
||||||
FlaxGPT2LMHeadModel,
|
FlaxGPT2LMHeadModel,
|
||||||
@@ -275,6 +279,84 @@ class FlaxEncoderDecoderMixin:
|
|||||||
generated_sequences = generated_output.sequences
|
generated_sequences = generated_output.sequences
|
||||||
self.assertEqual(generated_sequences.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
self.assertEqual(generated_sequences.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
||||||
|
|
||||||
|
def check_freeze_feature_encoder(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
inputs,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||||
|
enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
|
||||||
|
params = enc_dec_model.params
|
||||||
|
|
||||||
|
def cross_entropy(logits, labels):
|
||||||
|
return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
|
||||||
|
|
||||||
|
# define a dummy loss function for computing the loss over a forward pass
|
||||||
|
def compute_loss(
|
||||||
|
params,
|
||||||
|
inputs,
|
||||||
|
attention_mask,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
freeze_feature_encoder: bool = False,
|
||||||
|
):
|
||||||
|
outputs_enc_dec = enc_dec_model(
|
||||||
|
inputs=inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
freeze_feature_encoder=freeze_feature_encoder,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
logits = outputs_enc_dec.logits
|
||||||
|
vocab_size = logits.shape[-1]
|
||||||
|
loss = cross_entropy(logits, onehot(labels=decoder_input_ids, num_classes=vocab_size)).sum()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
# transform the loss function to get the gradients
|
||||||
|
grad_fn = jax.value_and_grad(compute_loss)
|
||||||
|
|
||||||
|
# compute the loss and gradients for the unfrozen model
|
||||||
|
loss, grads = grad_fn(
|
||||||
|
params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# compare to the loss and gradients for the frozen model
|
||||||
|
loss_frozen, grads_frozen = grad_fn(
|
||||||
|
params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assert_almost_equals(loss, loss_frozen, 1e-5)
|
||||||
|
|
||||||
|
grads = flatten_dict(grads)
|
||||||
|
grads_frozen = flatten_dict(grads_frozen)
|
||||||
|
|
||||||
|
# ensure that the dicts of gradients contain the same keys
|
||||||
|
self.assertEqual(grads.keys(), grads_frozen.keys())
|
||||||
|
|
||||||
|
# ensure that the gradients of the frozen layers are precisely zero and that they differ to the gradients of the unfrozen layers
|
||||||
|
feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k)
|
||||||
|
feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k)
|
||||||
|
|
||||||
|
for feature_extractor_grad, feature_extractor_grad_frozen in zip(
|
||||||
|
feature_extractor_grads, feature_extractor_grads_frozen
|
||||||
|
):
|
||||||
|
self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
|
||||||
|
self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-8)
|
||||||
|
|
||||||
|
# ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
|
||||||
|
grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
|
||||||
|
grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
|
||||||
|
|
||||||
|
for grad, grad_frozen in zip(grads, grads_frozen):
|
||||||
|
self.assert_almost_equals(grad, grad_frozen, 1e-8)
|
||||||
|
|
||||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||||
|
|
||||||
pt_model.to(torch_device)
|
pt_model.to(torch_device)
|
||||||
@@ -367,13 +449,21 @@ class FlaxEncoderDecoderMixin:
|
|||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||||
|
|
||||||
|
def test_freeze_feature_encoder(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_freeze_feature_encoder(**input_ids_dict)
|
||||||
|
|
||||||
def test_encoder_decoder_model_generate(self):
|
def test_encoder_decoder_model_generate(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||||
|
|
||||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||||
diff = np.abs((a - b)).max()
|
diff = np.abs((a - b)).max()
|
||||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).")
|
||||||
|
|
||||||
|
def assert_difference(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||||
|
diff = np.abs((a - b)).min()
|
||||||
|
self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).")
|
||||||
|
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
def test_pt_flax_equivalence(self):
|
def test_pt_flax_equivalence(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user