From f82ee109e6e58e19c21e631a2354af3b00da9a3c Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 11 May 2023 10:04:07 +0100 Subject: [PATCH] Temporary tolerance fix for flaky whipser PT-TF equiv. test (#23257) * Temp tol fix for flaky whipser test * Add equivalent update to TF tests --- tests/models/whisper/test_modeling_tf_whisper.py | 4 ++++ tests/models/whisper/test_modeling_whisper.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index d4abd8f5f0..a52994899a 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -400,6 +400,10 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC check_hidden_states_output(inputs_dict, config, model_class) + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None): + # We override with a slightly higher tol value, as test recently became flaky + super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes) + def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 0b5b375e9d..f66cafed97 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -824,6 +824,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi self.assertTrue(models_equal) + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None): + # We override with a slightly higher tol value, as test recently became flaky + super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes) + @is_pt_flax_cross_test def test_equivalence_pt_to_flax(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()