From e231c729063d88b0c2bad4ac2d461c4e46eac9ab Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 25 Mar 2022 17:46:53 +0100 Subject: [PATCH] [FlaxSpeechEncoderDecoder] Fix feature extractor gradient test (#16407) --- .../test_modeling_flax_speech_encoder_decoder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 4ceea974f3..0c0295ddba 100644 --- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -389,14 +389,14 @@ class FlaxEncoderDecoderMixin: feature_extractor_grads, feature_extractor_grads_frozen ): self.assertTrue((feature_extractor_grad_frozen == 0.0).all()) - self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-10) + self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-5) # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor' grads = tuple(grads[k] for k in grads if "feature_extractor" not in k) grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k) for grad, grad_frozen in zip(grads, grads_frozen): - self.assert_almost_equals(grad, grad_frozen, 1e-10) + self.assert_almost_equals(grad, grad_frozen, 1e-5) def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): @@ -507,7 +507,7 @@ class FlaxEncoderDecoderMixin: self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).") def assert_difference(self, a: np.ndarray, b: np.ndarray, tol: float): - diff = np.abs((a - b)).min() + diff = np.abs((a - b)).max() self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).") @is_pt_flax_cross_test