From 925fc57b70fe21d4edd457925985011d632d63ce Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Mon, 28 Mar 2022 11:56:54 +0200 Subject: [PATCH] [Flax] Improve Robustness of Back-Prop Tests (#16418) * [Flax] Improve Robustness of Back-Prop Tests * check equality of logits/outputs * make fixup --- ...st_modeling_flax_speech_encoder_decoder.py | 32 +++++++++--------- tests/wav2vec2/test_modeling_flax_wav2vec2.py | 33 +++++++++---------- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 0c0295ddba..113e867f3a 100644 --- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -360,20 +360,24 @@ class FlaxEncoderDecoderMixin: logits = outputs_enc_dec.logits vocab_size = logits.shape[-1] loss = cross_entropy(logits, onehot(labels=decoder_input_ids, num_classes=vocab_size)).sum() - return loss + return (loss, logits) # transform the loss function to get the gradients - grad_fn = jax.value_and_grad(compute_loss) + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) - # compute the loss and gradients for the unfrozen model - loss, grads = grad_fn(params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=False) + # compute the loss, logits, and gradients for the unfrozen model + (loss, logits), grads = grad_fn( + params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=False + ) - # compare to the loss and gradients for the frozen model - loss_frozen, grads_frozen = grad_fn( + # compare to the loss, logits and gradients for the frozen model + (loss_frozen, logits_frozen), grads_frozen = grad_fn( params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=True ) - self.assert_almost_equals(loss, loss_frozen, 1e-5) + # ensure that the logits and losses remain precisely equal + self.assertTrue((logits == logits_frozen).all()) + self.assertEqual(loss, loss_frozen) grads = flatten_dict(grads) grads_frozen = flatten_dict(grads_frozen) @@ -381,7 +385,7 @@ class FlaxEncoderDecoderMixin: # ensure that the dicts of gradients contain the same keys self.assertEqual(grads.keys(), grads_frozen.keys()) - # ensure that the gradients of the frozen layers are precisely zero and that they differ to the gradients of the unfrozen layers + # ensure that the gradients of the feature extractor layers are precisely zero when frozen and contain non-zero entries when unfrozen feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k) feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k) @@ -389,14 +393,14 @@ class FlaxEncoderDecoderMixin: feature_extractor_grads, feature_extractor_grads_frozen ): self.assertTrue((feature_extractor_grad_frozen == 0.0).all()) - self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-5) + self.assertTrue((feature_extractor_grad > 0.0).any()) - # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor' + # ensure that the gradients of all unfrozen layers remain precisely equal, i.e. all layers excluding the frozen 'feature_extractor' grads = tuple(grads[k] for k in grads if "feature_extractor" not in k) grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k) for grad, grad_frozen in zip(grads, grads_frozen): - self.assert_almost_equals(grad, grad_frozen, 1e-5) + self.assertTrue((grad == grad_frozen).all()) def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): @@ -504,11 +508,7 @@ class FlaxEncoderDecoderMixin: def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float): diff = np.abs((a - b)).max() - self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).") - - def assert_difference(self, a: np.ndarray, b: np.ndarray, tol: float): - diff = np.abs((a - b)).max() - self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).") + self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") @is_pt_flax_cross_test def test_pt_flax_equivalence(self): diff --git a/tests/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/wav2vec2/test_modeling_flax_wav2vec2.py index 064e89b7d7..b182441cb6 100644 --- a/tests/wav2vec2/test_modeling_flax_wav2vec2.py +++ b/tests/wav2vec2/test_modeling_flax_wav2vec2.py @@ -254,18 +254,23 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase): outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon ) loss = cosine_sim.sum() - return loss + return loss, outputs.to_tuple() # transform the loss function to get the gradients - grad_fn = jax.value_and_grad(compute_loss) + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) - # compute loss and gradients for unfrozen model - loss, grads = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=False) + # compute loss, outputs and gradients for unfrozen model + (loss, outputs), grads = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=False) - # compare to loss and gradients for frozen model - loss_frozen, grads_frozen = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=True) + # compare to loss, outputs and gradients for frozen model + (loss_frozen, outputs_frozen), grads_frozen = grad_fn( + params, input_values, attention_mask, freeze_feature_encoder=True + ) - self.assert_almost_equals(loss, loss_frozen, 1e-5) + # ensure that the outputs and losses remain precisely equal + for output, output_frozen in zip(outputs, outputs_frozen): + self.assertTrue((output == output_frozen).all()) + self.assertEqual(loss, loss_frozen) grads = flatten_dict(grads) grads_frozen = flatten_dict(grads_frozen) @@ -273,7 +278,7 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase): # ensure that the dicts of gradients contain the same keys self.assertEqual(grads.keys(), grads_frozen.keys()) - # ensure that the gradients of the frozen layers are precisely zero and that they differ to the gradients of the unfrozen layers + # ensure that the gradients of the feature extractor layers are precisely zero when frozen and contain non-zero entries when unfrozen feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k) feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k) @@ -281,22 +286,14 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase): feature_extractor_grads, feature_extractor_grads_frozen ): self.assertTrue((feature_extractor_grad_frozen == 0.0).all()) - self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-7) + self.assertTrue((feature_extractor_grad > 0.0).any()) # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor' grads = tuple(grads[k] for k in grads if "feature_extractor" not in k) grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k) for grad, grad_frozen in zip(grads, grads_frozen): - self.assert_almost_equals(grad, grad_frozen, 1e-7) - - def assert_difference(self, a, b, tol: float): - diff = jnp.abs((a - b)).min() - self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).") - - def assert_almost_equals(self, a, b, tol: float): - diff = jnp.abs((a - b)).max() - self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).") + self.assertTrue((grad == grad_frozen).all()) @slow def test_model_from_pretrained(self):