Freeze FlaxWav2Vec2 Feature Encoder (#15873)

* Freeze FlaxWav2Vec2 Feature Encoder

* add to all module apply

* add backprop test
This commit is contained in:
Sanchit Gandhi
2022-03-03 14:17:13 +01:00
committed by GitHub
parent 7b3bd1f21a
commit 3c4fbc616f
2 changed files with 54 additions and 2 deletions

View File

@@ -229,6 +229,47 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertEqual(jitted_output.shape, output.shape)
def test_freeze_feature_encoder(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_values = inputs_dict["input_values"]
attention_mask = inputs_dict["attention_mask"]
model = FlaxWav2Vec2ForPreTraining(config)
outputs = model(
input_values,
attention_mask=attention_mask,
freeze_feature_encoder=False,
)
outputs_frozen = model(
input_values,
attention_mask=attention_mask,
freeze_feature_encoder=True,
)
# dummy loss function
def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8):
# compute cosine similarity of projected and projected_quantized states
cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon)
loss = cosine_sim.sum()
return loss
# transform the loss function to get the gradients
grad_fn = jax.value_and_grad(compute_loss)
# compute loss and gradients for unfrozen model
loss, grads = grad_fn(outputs.projected_states, outputs.projected_quantized_states)
# compare to loss and gradients for frozen model
loss_frozen, grads_frozen = grad_fn(outputs_frozen.projected_states, outputs_frozen.projected_quantized_states)
self.assertLessEqual(np.abs(loss - loss_frozen), 1e-5)
self.assertEqual(grads.shape, grads_frozen.shape)
max_diff = np.amax(np.abs(grads - grads_frozen))
self.assertLessEqual(max_diff, 1e-5)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes: