Update expected slices for pillow > 9 (#16117)
* Update expected slices for pillow > 9 * Add expected slices depending on pillow version * Add different slices depending on pillow version for other models Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
from packaging import version
|
||||
|
||||
from transformers import ViltConfig, is_torch_available, is_vision_available
|
||||
from transformers.file_utils import cached_property
|
||||
@@ -41,6 +42,7 @@ if is_torch_available():
|
||||
from transformers.models.vilt.modeling_vilt import VILT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from transformers import ViltProcessor
|
||||
@@ -603,5 +605,17 @@ class ViltModelIntegrationTest(unittest.TestCase):
|
||||
expected_shape = torch.Size([1, 2])
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([-2.4013, 2.9342]).to(torch_device)
|
||||
is_pillow_less_than_9 = version.parse(PIL.__version__) < version.parse("9.0.0")
|
||||
|
||||
if is_pillow_less_than_9:
|
||||
expected_slice = torch.tensor(
|
||||
[-2.4013, 2.9342],
|
||||
device=torch_device,
|
||||
)
|
||||
else:
|
||||
expected_slice = torch.tensor(
|
||||
[-2.3713, 2.9168],
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user