Less flaky for TimmBackboneModelTest::test_batching_equivalence (#35971)
* fix * remove is_flaky * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -18,7 +18,7 @@ import inspect
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoBackbone
|
from transformers import AutoBackbone
|
||||||
from transformers.testing_utils import is_flaky, require_timm, require_torch, torch_device
|
from transformers.testing_utils import require_timm, require_torch, torch_device
|
||||||
from transformers.utils.import_utils import is_torch_available
|
from transformers.utils.import_utils import is_torch_available
|
||||||
|
|
||||||
from ...test_backbone_common import BackboneTesterMixin
|
from ...test_backbone_common import BackboneTesterMixin
|
||||||
@@ -115,11 +115,9 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
@is_flaky(
|
# `TimmBackbone` has no `_init_weights`. Timm's way of weight init. seems to give larger magnitude in the intermediate values during `forward`.
|
||||||
description="`TimmBackbone` has no `_init_weights`. Timm's way of weight init. seems to give larger magnitude in the intermediate values during `forward`."
|
def test_batching_equivalence(self, atol=1e-4, rtol=1e-4):
|
||||||
)
|
super().test_batching_equivalence(atol=atol, rtol=rtol)
|
||||||
def test_batching_equivalence(self):
|
|
||||||
super().test_batching_equivalence()
|
|
||||||
|
|
||||||
def test_timm_transformer_backbone_equivalence(self):
|
def test_timm_transformer_backbone_equivalence(self):
|
||||||
timm_checkpoint = "resnet18"
|
timm_checkpoint = "resnet18"
|
||||||
|
|||||||
@@ -768,7 +768,7 @@ class ModelTesterMixin:
|
|||||||
else:
|
else:
|
||||||
check_determinism(first, second)
|
check_determinism(first, second)
|
||||||
|
|
||||||
def test_batching_equivalence(self):
|
def test_batching_equivalence(self, atol=1e-5, rtol=1e-5):
|
||||||
"""
|
"""
|
||||||
Tests that the model supports batching and that the output is the nearly the same for the same input in
|
Tests that the model supports batching and that the output is the nearly the same for the same input in
|
||||||
different batch sizes.
|
different batch sizes.
|
||||||
@@ -812,7 +812,7 @@ class ModelTesterMixin:
|
|||||||
torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
|
torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
torch.testing.assert_close(batched_row, single_row_object, atol=1e-5, rtol=1e-5)
|
torch.testing.assert_close(batched_row, single_row_object, atol=atol, rtol=rtol)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
msg = f"Batched and Single row outputs are not equal in {model_name} for key={key}.\n\n"
|
msg = f"Batched and Single row outputs are not equal in {model_name} for key={key}.\n\n"
|
||||||
msg += str(e)
|
msg += str(e)
|
||||||
|
|||||||
Reference in New Issue
Block a user