Temporarily increase tol for PT-FLAX whisper tests (#23288)
This commit is contained in:
@@ -248,6 +248,10 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
self.assertEqual(jitted_output.shape, output.shape)
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
|
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||||
|
# We override with a slightly higher tol value, as test recently became flaky
|
||||||
|
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||||
|
|
||||||
# overwrite because of `input_features`
|
# overwrite because of `input_features`
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
def test_save_load_bf16_to_base_pt(self):
|
def test_save_load_bf16_to_base_pt(self):
|
||||||
|
|||||||
@@ -828,6 +828,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
# We override with a slightly higher tol value, as test recently became flaky
|
# We override with a slightly higher tol value, as test recently became flaky
|
||||||
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||||
|
|
||||||
|
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||||
|
# We override with a slightly higher tol value, as test recently became flaky
|
||||||
|
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||||
|
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
def test_equivalence_pt_to_flax(self):
|
def test_equivalence_pt_to_flax(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user