Update expected slices for pillow > 9 (#16117)
* Update expected slices for pillow > 9 * Add expected slices depending on pillow version * Add different slices depending on pillow version for other models Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -18,6 +18,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
from packaging import version
|
||||
|
||||
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
@@ -51,6 +52,7 @@ if is_torch_available():
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from transformers import TrOCRProcessor, ViTFeatureExtractor
|
||||
@@ -687,9 +689,18 @@ class TrOCRModelIntegrationTest(unittest.TestCase):
|
||||
expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[-5.6816, -5.8388, 1.1398, -6.9034, 6.8505, -2.4393, 1.2284, -1.0232, -1.9661, -3.9210]
|
||||
).to(torch_device)
|
||||
is_pillow_less_than_9 = version.parse(PIL.__version__) < version.parse("9.0.0")
|
||||
|
||||
if is_pillow_less_than_9:
|
||||
expected_slice = torch.tensor(
|
||||
[-5.6816, -5.8388, 1.1398, -6.9034, 6.8505, -2.4393, 1.2284, -1.0232, -1.9661, -3.9210],
|
||||
device=torch_device,
|
||||
)
|
||||
else:
|
||||
expected_slice = torch.tensor(
|
||||
[-5.6844, -5.8372, 1.1518, -6.8984, 6.8587, -2.4453, 1.2347, -1.0241, -1.9649, -3.9109],
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user