[tests] remove flax-pt equivalence and cross tests (#36283)

This commit is contained in:
Joao Gante
2025-02-19 15:13:27 +00:00
committed by GitHub
parent fa8cdccd91
commit 99adc74462
39 changed files with 33 additions and 3103 deletions

View File

@@ -22,7 +22,7 @@ from datasets import load_dataset
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import Data2VecAudioConfig, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_soundfile, require_torch, slow, torch_device
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init
@@ -442,16 +442,6 @@ class Data2VecAudioModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
def test_model_get_set_embeddings(self):
pass
@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_flax_to_pt(self):
pass
@is_pt_flax_cross_test
# non-robust architecture does not exist in Flax
def test_equivalence_pt_to_flax(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True