[tests] remove flax-pt equivalence and cross tests (#36283)

This commit is contained in:
Joao Gante
2025-02-19 15:13:27 +00:00
committed by GitHub
parent fa8cdccd91
commit 99adc74462
39 changed files with 33 additions and 3103 deletions

View File

@@ -24,7 +24,6 @@ from datasets import load_dataset
from transformers import Wav2Vec2ConformerConfig, is_torch_available
from transformers.testing_utils import (
is_flaky,
is_pt_flax_cross_test,
require_torch,
require_torch_accelerator,
require_torch_fp16,
@@ -535,16 +534,6 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest
def test_model_get_set_embeddings(self):
pass
@is_pt_flax_cross_test
@unittest.skip(reason="Non-robust architecture does not exist in Flax")
def test_equivalence_flax_to_pt(self):
pass
@is_pt_flax_cross_test
@unittest.skip(reason="Non-robust architecture does not exist in Flax")
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True