[tests] remove pt_tf equivalence tests (#36253)
This commit is contained in:
@@ -24,7 +24,7 @@ import numpy as np
|
||||
import requests
|
||||
|
||||
from transformers import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
|
||||
from transformers.testing_utils import is_flaky, is_pt_tf_cross_test, require_torch, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import is_flaky, require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -166,18 +166,6 @@ class GroupViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_batching_equivalence(self):
|
||||
super().test_batching_equivalence()
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self):
|
||||
import tensorflow as tf
|
||||
|
||||
seed = 338
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
tf.random.set_seed(seed)
|
||||
return super().test_pt_tf_model_equivalence()
|
||||
|
||||
def test_model_get_set_embeddings(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -595,22 +583,6 @@ class GroupViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
# overwritten from parent as this equivalent test needs a specific `seed` and hard to get a good one!
|
||||
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-5, name="outputs", attributes=None):
|
||||
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self):
|
||||
import tensorflow as tf
|
||||
|
||||
seed = 163
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
tf.random.set_seed(seed)
|
||||
return super().test_pt_tf_model_equivalence()
|
||||
|
||||
# override as the `logit_scale` parameter initilization is different for GROUPVIT
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user