Byebye test_batching_equivalence's flakiness (#35729)

* fix

* fix

* skip

* better error message

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-01-21 13:11:33 +01:00
committed by GitHub
parent 78f5ee0217
commit fd8d61fdb2
18 changed files with 92 additions and 50 deletions

View File

@@ -24,7 +24,7 @@ import numpy as np
import requests
from transformers import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
from transformers.testing_utils import is_pt_tf_cross_test, require_torch, require_vision, slow, torch_device
from transformers.testing_utils import is_flaky, is_pt_tf_cross_test, require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -162,6 +162,10 @@ class GroupViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
@is_flaky(description="The `index` computed with `max()` in `hard_softmax` is not stable.")
def test_batching_equivalence(self):
super().test_batching_equivalence()
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
import tensorflow as tf
@@ -571,6 +575,10 @@ class GroupViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def test_config(self):
self.config_tester.run_common_tests()
@is_flaky(description="The `index` computed with `max()` in `hard_softmax` is not stable.")
def test_batching_equivalence(self):
super().test_batching_equivalence()
@unittest.skip(reason="hidden_states are tested in individual model tests")
def test_hidden_states_output(self):
pass