accelerate support for OwlViT (#20411)
* `accelerate` support for `OwlViT` - added `accelerate` support - added slow `fp16` tests * apply suggestions
This commit is contained in:
@@ -24,7 +24,7 @@ import numpy as np
|
||||
|
||||
import requests
|
||||
from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -778,3 +778,28 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
[[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_inference_one_shot_object_detection_fp16(self):
|
||||
model_name = "google/owlvit-base-patch32"
|
||||
model = OwlViTForObjectDetection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
|
||||
|
||||
processor = OwlViTProcessor.from_pretrained(model_name)
|
||||
|
||||
image = prepare_img()
|
||||
query_image = prepare_img()
|
||||
inputs = processor(
|
||||
images=image,
|
||||
query_images=query_image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.image_guided_detection(**inputs)
|
||||
|
||||
# No need to check the logits, we just check inference runs fine.
|
||||
num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
Reference in New Issue
Block a user