From 1a62b25caf06cd4a13af2db1e94abce9969a1d9b Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Mon, 7 Mar 2022 18:10:15 +0100 Subject: [PATCH] Backprop Test for Freeze FlaxWav2Vec2 Feature Encoder (#15938) * Backprop Test for Freeze FlaxWav2Vec2 Feature Encoder * remove jnp.ndarray type suggestion * assert frozen grads are precisely zero --- tests/wav2vec2/test_modeling_flax_wav2vec2.py | 68 +++++++++++++------ 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/tests/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/wav2vec2/test_modeling_flax_wav2vec2.py index 42f904b4cc..064e89b7d7 100644 --- a/tests/wav2vec2/test_modeling_flax_wav2vec2.py +++ b/tests/wav2vec2/test_modeling_flax_wav2vec2.py @@ -37,6 +37,7 @@ if is_flax_available(): import jax import jax.numpy as jnp import optax + from flax.traverse_util import flatten_dict from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor from transformers.models.wav2vec2.modeling_flax_wav2vec2 import ( FlaxWav2Vec2ForCTC, @@ -236,23 +237,22 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase): attention_mask = inputs_dict["attention_mask"] model = FlaxWav2Vec2ForPreTraining(config) - - outputs = model( - input_values, - attention_mask=attention_mask, - freeze_feature_encoder=False, - ) - - outputs_frozen = model( - input_values, - attention_mask=attention_mask, - freeze_feature_encoder=True, - ) + params = model.params # dummy loss function - def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8): + def compute_loss( + params, input_values, attention_mask, freeze_feature_encoder: bool = False, epsilon: float = 1e-8 + ): + outputs = model( + input_values, + attention_mask=attention_mask, + freeze_feature_encoder=freeze_feature_encoder, + params=params, + ) # compute cosine similarity of projected and projected_quantized states - cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon) + cosine_sim = optax.cosine_similarity( + outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon + ) loss = cosine_sim.sum() return loss @@ -260,15 +260,43 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase): grad_fn = jax.value_and_grad(compute_loss) # compute loss and gradients for unfrozen model - loss, grads = grad_fn(outputs.projected_states, outputs.projected_quantized_states) + loss, grads = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=False) # compare to loss and gradients for frozen model - loss_frozen, grads_frozen = grad_fn(outputs_frozen.projected_states, outputs_frozen.projected_quantized_states) + loss_frozen, grads_frozen = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=True) - self.assertLessEqual(np.abs(loss - loss_frozen), 1e-5) - self.assertEqual(grads.shape, grads_frozen.shape) - max_diff = np.amax(np.abs(grads - grads_frozen)) - self.assertLessEqual(max_diff, 1e-5) + self.assert_almost_equals(loss, loss_frozen, 1e-5) + + grads = flatten_dict(grads) + grads_frozen = flatten_dict(grads_frozen) + + # ensure that the dicts of gradients contain the same keys + self.assertEqual(grads.keys(), grads_frozen.keys()) + + # ensure that the gradients of the frozen layers are precisely zero and that they differ to the gradients of the unfrozen layers + feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k) + feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k) + + for feature_extractor_grad, feature_extractor_grad_frozen in zip( + feature_extractor_grads, feature_extractor_grads_frozen + ): + self.assertTrue((feature_extractor_grad_frozen == 0.0).all()) + self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-7) + + # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor' + grads = tuple(grads[k] for k in grads if "feature_extractor" not in k) + grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k) + + for grad, grad_frozen in zip(grads, grads_frozen): + self.assert_almost_equals(grad, grad_frozen, 1e-7) + + def assert_difference(self, a, b, tol: float): + diff = jnp.abs((a - b)).min() + self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).") + + def assert_almost_equals(self, a, b, tol: float): + diff = jnp.abs((a - b)).max() + self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).") @slow def test_model_from_pretrained(self):