Freeze FlaxWav2Vec2 Feature Encoder (#15873)
* Freeze FlaxWav2Vec2 Feature Encoder * add to all module apply * add backprop test
This commit is contained in:
@@ -229,6 +229,47 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
def test_freeze_feature_encoder(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
input_values = inputs_dict["input_values"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
|
||||
model = FlaxWav2Vec2ForPreTraining(config)
|
||||
|
||||
outputs = model(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
freeze_feature_encoder=False,
|
||||
)
|
||||
|
||||
outputs_frozen = model(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
freeze_feature_encoder=True,
|
||||
)
|
||||
|
||||
# dummy loss function
|
||||
def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8):
|
||||
# compute cosine similarity of projected and projected_quantized states
|
||||
cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon)
|
||||
loss = cosine_sim.sum()
|
||||
return loss
|
||||
|
||||
# transform the loss function to get the gradients
|
||||
grad_fn = jax.value_and_grad(compute_loss)
|
||||
|
||||
# compute loss and gradients for unfrozen model
|
||||
loss, grads = grad_fn(outputs.projected_states, outputs.projected_quantized_states)
|
||||
|
||||
# compare to loss and gradients for frozen model
|
||||
loss_frozen, grads_frozen = grad_fn(outputs_frozen.projected_states, outputs_frozen.projected_quantized_states)
|
||||
|
||||
self.assertLessEqual(np.abs(loss - loss_frozen), 1e-5)
|
||||
self.assertEqual(grads.shape, grads_frozen.shape)
|
||||
max_diff = np.amax(np.abs(grads - grads_frozen))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user