Backprop Test for Freeze FlaxWav2Vec2 Feature Encoder (#15938)
* Backprop Test for Freeze FlaxWav2Vec2 Feature Encoder * remove jnp.ndarray type suggestion * assert frozen grads are precisely zero
This commit is contained in:
@@ -37,6 +37,7 @@ if is_flax_available():
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import optax
|
import optax
|
||||||
|
from flax.traverse_util import flatten_dict
|
||||||
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||||
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
|
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
|
||||||
FlaxWav2Vec2ForCTC,
|
FlaxWav2Vec2ForCTC,
|
||||||
@@ -236,23 +237,22 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
attention_mask = inputs_dict["attention_mask"]
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
|
||||||
model = FlaxWav2Vec2ForPreTraining(config)
|
model = FlaxWav2Vec2ForPreTraining(config)
|
||||||
|
params = model.params
|
||||||
|
|
||||||
|
# dummy loss function
|
||||||
|
def compute_loss(
|
||||||
|
params, input_values, attention_mask, freeze_feature_encoder: bool = False, epsilon: float = 1e-8
|
||||||
|
):
|
||||||
outputs = model(
|
outputs = model(
|
||||||
input_values,
|
input_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
freeze_feature_encoder=False,
|
freeze_feature_encoder=freeze_feature_encoder,
|
||||||
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs_frozen = model(
|
|
||||||
input_values,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
freeze_feature_encoder=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# dummy loss function
|
|
||||||
def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8):
|
|
||||||
# compute cosine similarity of projected and projected_quantized states
|
# compute cosine similarity of projected and projected_quantized states
|
||||||
cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon)
|
cosine_sim = optax.cosine_similarity(
|
||||||
|
outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
|
||||||
|
)
|
||||||
loss = cosine_sim.sum()
|
loss = cosine_sim.sum()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@@ -260,15 +260,43 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
grad_fn = jax.value_and_grad(compute_loss)
|
grad_fn = jax.value_and_grad(compute_loss)
|
||||||
|
|
||||||
# compute loss and gradients for unfrozen model
|
# compute loss and gradients for unfrozen model
|
||||||
loss, grads = grad_fn(outputs.projected_states, outputs.projected_quantized_states)
|
loss, grads = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=False)
|
||||||
|
|
||||||
# compare to loss and gradients for frozen model
|
# compare to loss and gradients for frozen model
|
||||||
loss_frozen, grads_frozen = grad_fn(outputs_frozen.projected_states, outputs_frozen.projected_quantized_states)
|
loss_frozen, grads_frozen = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=True)
|
||||||
|
|
||||||
self.assertLessEqual(np.abs(loss - loss_frozen), 1e-5)
|
self.assert_almost_equals(loss, loss_frozen, 1e-5)
|
||||||
self.assertEqual(grads.shape, grads_frozen.shape)
|
|
||||||
max_diff = np.amax(np.abs(grads - grads_frozen))
|
grads = flatten_dict(grads)
|
||||||
self.assertLessEqual(max_diff, 1e-5)
|
grads_frozen = flatten_dict(grads_frozen)
|
||||||
|
|
||||||
|
# ensure that the dicts of gradients contain the same keys
|
||||||
|
self.assertEqual(grads.keys(), grads_frozen.keys())
|
||||||
|
|
||||||
|
# ensure that the gradients of the frozen layers are precisely zero and that they differ to the gradients of the unfrozen layers
|
||||||
|
feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k)
|
||||||
|
feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k)
|
||||||
|
|
||||||
|
for feature_extractor_grad, feature_extractor_grad_frozen in zip(
|
||||||
|
feature_extractor_grads, feature_extractor_grads_frozen
|
||||||
|
):
|
||||||
|
self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
|
||||||
|
self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-7)
|
||||||
|
|
||||||
|
# ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
|
||||||
|
grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
|
||||||
|
grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
|
||||||
|
|
||||||
|
for grad, grad_frozen in zip(grads, grads_frozen):
|
||||||
|
self.assert_almost_equals(grad, grad_frozen, 1e-7)
|
||||||
|
|
||||||
|
def assert_difference(self, a, b, tol: float):
|
||||||
|
diff = jnp.abs((a - b)).min()
|
||||||
|
self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).")
|
||||||
|
|
||||||
|
def assert_almost_equals(self, a, b, tol: float):
|
||||||
|
diff = jnp.abs((a - b)).max()
|
||||||
|
self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user