[tests] remove pt_tf equivalence tests (#36253)

This commit is contained in:
Joao Gante
2025-02-19 11:55:11 +00:00
committed by GitHub
parent 1a81d774b1
commit 0863eef248
60 changed files with 56 additions and 2438 deletions

View File

@@ -24,7 +24,7 @@ import numpy as np
import requests
from transformers import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
from transformers.testing_utils import is_flaky, is_pt_tf_cross_test, require_torch, require_vision, slow, torch_device
from transformers.testing_utils import is_flaky, require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -166,18 +166,6 @@ class GroupViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
def test_batching_equivalence(self):
super().test_batching_equivalence()
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
import tensorflow as tf
seed = 338
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
tf.random.set_seed(seed)
return super().test_pt_tf_model_equivalence()
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -595,22 +583,6 @@ class GroupViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def test_model_get_set_embeddings(self):
pass
# overwritten from parent as this equivalent test needs a specific `seed` and hard to get a good one!
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-5, name="outputs", attributes=None):
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
import tensorflow as tf
seed = 163
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
tf.random.set_seed(seed)
return super().test_pt_tf_model_equivalence()
# override as the `logit_scale` parameter initilization is different for GROUPVIT
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()