Less flaky for TimmBackboneModelTest::test_batching_equivalence (#35971)

* fix

* remove is_flaky

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-01-30 16:56:26 +01:00
committed by GitHub
parent e320d5542e
commit 5757681837
2 changed files with 6 additions and 8 deletions

View File

@@ -18,7 +18,7 @@ import inspect
import unittest import unittest
from transformers import AutoBackbone from transformers import AutoBackbone
from transformers.testing_utils import is_flaky, require_timm, require_torch, torch_device from transformers.testing_utils import require_timm, require_torch, torch_device
from transformers.utils.import_utils import is_torch_available from transformers.utils.import_utils import is_torch_available
from ...test_backbone_common import BackboneTesterMixin from ...test_backbone_common import BackboneTesterMixin
@@ -115,11 +115,9 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
@is_flaky( # `TimmBackbone` has no `_init_weights`. Timm's way of weight init. seems to give larger magnitude in the intermediate values during `forward`.
description="`TimmBackbone` has no `_init_weights`. Timm's way of weight init. seems to give larger magnitude in the intermediate values during `forward`." def test_batching_equivalence(self, atol=1e-4, rtol=1e-4):
) super().test_batching_equivalence(atol=atol, rtol=rtol)
def test_batching_equivalence(self):
super().test_batching_equivalence()
def test_timm_transformer_backbone_equivalence(self): def test_timm_transformer_backbone_equivalence(self):
timm_checkpoint = "resnet18" timm_checkpoint = "resnet18"

View File

@@ -768,7 +768,7 @@ class ModelTesterMixin:
else: else:
check_determinism(first, second) check_determinism(first, second)
def test_batching_equivalence(self): def test_batching_equivalence(self, atol=1e-5, rtol=1e-5):
""" """
Tests that the model supports batching and that the output is the nearly the same for the same input in Tests that the model supports batching and that the output is the nearly the same for the same input in
different batch sizes. different batch sizes.
@@ -812,7 +812,7 @@ class ModelTesterMixin:
torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}" torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
) )
try: try:
torch.testing.assert_close(batched_row, single_row_object, atol=1e-5, rtol=1e-5) torch.testing.assert_close(batched_row, single_row_object, atol=atol, rtol=rtol)
except AssertionError as e: except AssertionError as e:
msg = f"Batched and Single row outputs are not equal in {model_name} for key={key}.\n\n" msg = f"Batched and Single row outputs are not equal in {model_name} for key={key}.\n\n"
msg += str(e) msg += str(e)