Update expected slices for pillow > 9 (#16117)

* Update expected slices for pillow > 9

* Add expected slices depending on pillow version

* Add different slices depending on pillow version for other models

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-03-18 09:46:45 +01:00
committed by GitHub
parent 12d1f07770
commit ec4e421b7d
3 changed files with 51 additions and 11 deletions

View File

@@ -17,6 +17,7 @@
import unittest
from datasets import load_dataset
from packaging import version
from transformers import ViltConfig, is_torch_available, is_vision_available
from transformers.file_utils import cached_property
@@ -41,6 +42,7 @@ if is_torch_available():
from transformers.models.vilt.modeling_vilt import VILT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
import PIL
from PIL import Image
from transformers import ViltProcessor
@@ -603,5 +605,17 @@ class ViltModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size([1, 2])
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-2.4013, 2.9342]).to(torch_device)
is_pillow_less_than_9 = version.parse(PIL.__version__) < version.parse("9.0.0")
if is_pillow_less_than_9:
expected_slice = torch.tensor(
[-2.4013, 2.9342],
device=torch_device,
)
else:
expected_slice = torch.tensor(
[-2.3713, 2.9168],
device=torch_device,
)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))