From 575768183759c5d3dc052a0af26818a81521e82a Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:56:26 +0100 Subject: [PATCH] Less flaky for `TimmBackboneModelTest::test_batching_equivalence` (#35971) * fix * remove is_flaky * fix --------- Co-authored-by: ydshieh --- .../timm_backbone/test_modeling_timm_backbone.py | 10 ++++------ tests/test_modeling_common.py | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py index 296a38c176..737b1ea3c5 100644 --- a/tests/models/timm_backbone/test_modeling_timm_backbone.py +++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py @@ -18,7 +18,7 @@ import inspect import unittest from transformers import AutoBackbone -from transformers.testing_utils import is_flaky, require_timm, require_torch, torch_device +from transformers.testing_utils import require_timm, require_torch, torch_device from transformers.utils.import_utils import is_torch_available from ...test_backbone_common import BackboneTesterMixin @@ -115,11 +115,9 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste def test_config(self): self.config_tester.run_common_tests() - @is_flaky( - description="`TimmBackbone` has no `_init_weights`. Timm's way of weight init. seems to give larger magnitude in the intermediate values during `forward`." - ) - def test_batching_equivalence(self): - super().test_batching_equivalence() + # `TimmBackbone` has no `_init_weights`. Timm's way of weight init. seems to give larger magnitude in the intermediate values during `forward`. + def test_batching_equivalence(self, atol=1e-4, rtol=1e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) def test_timm_transformer_backbone_equivalence(self): timm_checkpoint = "resnet18" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index ba996a966c..0f47767e41 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -768,7 +768,7 @@ class ModelTesterMixin: else: check_determinism(first, second) - def test_batching_equivalence(self): + def test_batching_equivalence(self, atol=1e-5, rtol=1e-5): """ Tests that the model supports batching and that the output is the nearly the same for the same input in different batch sizes. @@ -812,7 +812,7 @@ class ModelTesterMixin: torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}" ) try: - torch.testing.assert_close(batched_row, single_row_object, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(batched_row, single_row_object, atol=atol, rtol=rtol) except AssertionError as e: msg = f"Batched and Single row outputs are not equal in {model_name} for key={key}.\n\n" msg += str(e)