accelerate support for OwlViT (#20411)

* `accelerate` support for `OwlViT`

- added `accelerate` support
- added slow `fp16` tests

* apply suggestions
This commit is contained in:
Younes Belkada
2022-11-25 11:20:44 +01:00
committed by GitHub
parent afce73bd9d
commit a1d4563f7a
2 changed files with 40 additions and 4 deletions

View File

@@ -24,7 +24,7 @@ import numpy as np
import requests
from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -778,3 +778,28 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
[[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
@slow
@require_torch_gpu
def test_inference_one_shot_object_detection_fp16(self):
model_name = "google/owlvit-base-patch32"
model = OwlViTForObjectDetection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
processor = OwlViTProcessor.from_pretrained(model_name)
image = prepare_img()
query_image = prepare_img()
inputs = processor(
images=image,
query_images=query_image,
max_length=16,
padding="max_length",
return_tensors="pt",
).to(torch_device)
with torch.no_grad():
outputs = model.image_guided_detection(**inputs)
# No need to check the logits, we just check inference runs fine.
num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))