🚨🚨🚨 Fix rescale ViVit Efficientnet (#25174)
* Fix rescaling bug * Add tests * Update integration tests * Fix up * Update src/transformers/image_transforms.py * Update test - new possible order in list
This commit is contained in:
@@ -193,3 +193,17 @@ class EfficientNetImageProcessorTest(ImageProcessingSavingTestMixin, unittest.Te
|
||||
self.image_processor_tester.size["width"],
|
||||
),
|
||||
)
|
||||
|
||||
def test_rescale(self):
|
||||
# EfficientNet optionally rescales between -1 and 1 instead of the usual 0 and 1
|
||||
image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
|
||||
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
|
||||
rescaled_image = image_processor.rescale(image, scale=1 / 255)
|
||||
expected_image = image.astype(np.float32) * (2 / 255.0) - 1
|
||||
self.assertTrue(np.allclose(rescaled_image, expected_image))
|
||||
|
||||
rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
|
||||
expected_image = image.astype(np.float32) / 255.0
|
||||
self.assertTrue(np.allclose(rescaled_image, expected_image))
|
||||
|
||||
@@ -212,3 +212,17 @@ class VivitImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
|
||||
def test_rescale(self):
|
||||
# ViVit optionally rescales between -1 and 1 instead of the usual 0 and 1
|
||||
image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
|
||||
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
|
||||
rescaled_image = image_processor.rescale(image, scale=1 / 255)
|
||||
expected_image = image.astype(np.float32) * (2 / 255.0) - 1
|
||||
self.assertTrue(np.allclose(rescaled_image, expected_image))
|
||||
|
||||
rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
|
||||
expected_image = image.astype(np.float32) / 255.0
|
||||
self.assertTrue(np.allclose(rescaled_image, expected_image))
|
||||
|
||||
@@ -345,6 +345,6 @@ class VivitModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
# taken from original model
|
||||
expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658]).to(torch_device)
|
||||
expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user