[tests] remove flax-pt equivalence and cross tests (#36283)

This commit is contained in:
Joao Gante
2025-02-19 15:13:27 +00:00
committed by GitHub
parent fa8cdccd91
commit 99adc74462
39 changed files with 33 additions and 3103 deletions

View File

@@ -24,9 +24,7 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import (
CaptureLogger,
is_flaky,
is_librosa_available,
is_pt_flax_cross_test,
is_pyctcdecode_available,
require_flax,
require_librosa,
@@ -350,11 +348,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
outputs = model(np.ones((1, 1024), dtype="f4"))
self.assertIsNotNone(outputs)
@is_pt_flax_cross_test
@is_flaky()
def test_equivalence_pt_to_flax(self):
super().test_equivalence_pt_to_flax()
@require_flax
class FlaxWav2Vec2UtilsTest(unittest.TestCase):

View File

@@ -31,7 +31,6 @@ from transformers.testing_utils import (
CaptureLogger,
cleanup,
is_flaky,
is_pt_flax_cross_test,
is_pyctcdecode_available,
is_torchaudio_available,
require_flash_attn,
@@ -569,16 +568,6 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def test_model_get_set_embeddings(self):
pass
@is_pt_flax_cross_test
@unittest.skip(reason="Non-rubst architecture does not exist in Flax")
def test_equivalence_flax_to_pt(self):
pass
@is_pt_flax_cross_test
@unittest.skip(reason="Non-rubst architecture does not exist in Flax")
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True