[tests] remove pt_tf equivalence tests (#36253)

This commit is contained in:
Joao Gante
2025-02-19 11:55:11 +00:00
committed by GitHub
parent 1a81d774b1
commit 0863eef248
60 changed files with 56 additions and 2438 deletions

View File

@@ -433,10 +433,6 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_hidden_states_output(self):
pass
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
# Use a slightly higher default tol to make the tests non-flaky
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)
@slow
def test_model_from_pretrained(self):
model_name = "facebook/sam-vit-huge"

View File

@@ -411,16 +411,6 @@ class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
model = TFSamModel.from_pretrained("facebook/sam-vit-base") # sam-vit-huge blows out our memory
self.assertIsNotNone(model)
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None):
super().check_pt_tf_outputs(
tf_outputs=tf_outputs,
pt_outputs=pt_outputs,
model_class=model_class,
tol=tol,
name=name,
attributes=attributes,
)
def prepare_image():
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"

View File

@@ -18,7 +18,6 @@ import unittest
import numpy as np
from transformers.testing_utils import (
is_pt_tf_cross_test,
require_tf,
require_torch,
require_torchvision,
@@ -340,42 +339,3 @@ class SamProcessorEquivalenceTest(unittest.TestCase):
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()
@is_pt_tf_cross_test
def test_post_process_masks_equivalence(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
dummy_masks = np.random.randint(0, 2, size=(1, 3, 5, 5)).astype(np.float32)
tf_dummy_masks = [tf.convert_to_tensor(dummy_masks)]
pt_dummy_masks = [torch.tensor(dummy_masks)]
original_sizes = [[1764, 2646]]
reshaped_input_size = [[683, 1024]]
tf_masks = processor.post_process_masks(
tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf"
)
pt_masks = processor.post_process_masks(
pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt"
)
self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))
@is_pt_tf_cross_test
def test_image_processor_equivalence(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
image_input = self.prepare_image_inputs()
pt_input_feat_extract = image_processor(image_input, return_tensors="pt")["pixel_values"].numpy()
pt_input_processor = processor(images=image_input, return_tensors="pt")["pixel_values"].numpy()
tf_input_feat_extract = image_processor(image_input, return_tensors="tf")["pixel_values"].numpy()
tf_input_processor = processor(images=image_input, return_tensors="tf")["pixel_values"].numpy()
self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor))
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract))
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor))