[tests] remove pt_tf equivalence tests (#36253)
This commit is contained in:
@@ -18,7 +18,6 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import (
|
||||
is_pt_tf_cross_test,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torchvision,
|
||||
@@ -340,42 +339,3 @@ class SamProcessorEquivalenceTest(unittest.TestCase):
|
||||
def prepare_image_inputs(self):
|
||||
"""This function prepares a list of PIL images."""
|
||||
return prepare_image_inputs()
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_post_process_masks_equivalence(self):
|
||||
image_processor = self.get_image_processor()
|
||||
|
||||
processor = SamProcessor(image_processor=image_processor)
|
||||
dummy_masks = np.random.randint(0, 2, size=(1, 3, 5, 5)).astype(np.float32)
|
||||
tf_dummy_masks = [tf.convert_to_tensor(dummy_masks)]
|
||||
pt_dummy_masks = [torch.tensor(dummy_masks)]
|
||||
|
||||
original_sizes = [[1764, 2646]]
|
||||
|
||||
reshaped_input_size = [[683, 1024]]
|
||||
tf_masks = processor.post_process_masks(
|
||||
tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf"
|
||||
)
|
||||
pt_masks = processor.post_process_masks(
|
||||
pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt"
|
||||
)
|
||||
|
||||
self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_image_processor_equivalence(self):
|
||||
image_processor = self.get_image_processor()
|
||||
|
||||
processor = SamProcessor(image_processor=image_processor)
|
||||
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
pt_input_feat_extract = image_processor(image_input, return_tensors="pt")["pixel_values"].numpy()
|
||||
pt_input_processor = processor(images=image_input, return_tensors="pt")["pixel_values"].numpy()
|
||||
|
||||
tf_input_feat_extract = image_processor(image_input, return_tensors="tf")["pixel_values"].numpy()
|
||||
tf_input_processor = processor(images=image_input, return_tensors="tf")["pixel_values"].numpy()
|
||||
|
||||
self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor))
|
||||
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract))
|
||||
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor))
|
||||
|
||||
Reference in New Issue
Block a user