@@ -20,8 +20,9 @@ import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import ImageGPTConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers import ImageGPTConfig
|
||||
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
@@ -38,6 +39,11 @@ if is_torch_available():
|
||||
ImageGPTModel,
|
||||
)
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import ImageGPTFeatureExtractor
|
||||
|
||||
|
||||
class ImageGPTModelTester:
|
||||
def __init__(
|
||||
@@ -496,3 +502,38 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCa
|
||||
models_equal = False
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class ImageGPTModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return ImageGPTFeatureExtractor.from_pretrained("openai/imagegpt-small") if is_vision_available() else None
|
||||
|
||||
@slow
|
||||
def test_inference_causal_lm_head(self):
|
||||
model = ImageGPTForCausalLM.from_pretrained("openai/imagegpt-small").to(torch_device)
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1024, 512))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[2.3445, 2.6889, 2.7313], [1.0530, 1.2416, 0.5699], [0.2205, 0.7749, 0.3953]]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user