Fix image post-processing for OWLv2 (#30686)
* feat: add note about owlv2
* fix: post processing coordinates
* remove: workaround document
* fix: extra quotes
* update: owlv2 docstrings
* fix: copies check
* feat: add unit test for resize
* Update tests/models/owlv2/test_image_processor_owlv2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision, slow
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
@@ -25,7 +25,10 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import Owlv2ImageProcessor
|
||||
from transformers import AutoProcessor, Owlv2ForObjectDetection, Owlv2ImageProcessor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class Owlv2ImageProcessingTester(unittest.TestCase):
|
||||
@@ -120,6 +123,25 @@ class Owlv2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
mean_value = round(pixel_values.mean().item(), 4)
|
||||
self.assertEqual(mean_value, 0.2353)
|
||||
|
||||
@slow
def test_image_processor_integration_test_resize(self):
    """Integration check of OWLv2 box post-processing against reference coordinates.

    Runs the ``google/owlv2-base-patch16-ensemble`` checkpoint on the COCO
    sample fixture with the text query ``"cat"``, post-processes the detections
    with ``target_sizes`` built from the image's own size, and compares the
    first two returned boxes with the recorded reference values.
    """
    checkpoint = "google/owlv2-base-patch16-ensemble"
    processor = AutoProcessor.from_pretrained(checkpoint)
    model = Owlv2ForObjectDetection.from_pretrained(checkpoint)

    # Open the fixture through a context manager so the file handle is
    # released as soon as the processor has consumed the pixels.
    with Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") as image:
        # ``image.size`` is reversed before being handed to post-processing.
        target_sizes = torch.tensor([image.size[::-1]])
        inputs = processor(text=["cat"], images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_object_detection(outputs, threshold=0.2, target_sizes=target_sizes)[0]

    boxes = results["boxes"]
    expected_boxes = torch.tensor(
        [
            [341.66656494140625, 23.38756561279297, 642.321044921875, 371.3482971191406],
            [6.753320693969727, 51.96149826049805, 326.61810302734375, 473.12982177734375],
        ],
        dtype=boxes.dtype,
        device=boxes.device,
    )

    # Guard the slice below: with fewer than two detections the comparison
    # would otherwise broadcast instead of failing clearly.
    self.assertGreaterEqual(len(boxes), 2)
    # Model outputs are floating point; exact equality is brittle across
    # hardware and library versions, so compare within a small tolerance.
    self.assertTrue(
        torch.allclose(boxes[:2], expected_boxes, atol=1e-2),
        f"Predicted boxes {boxes[:2].tolist()} differ from expected {expected_boxes.tolist()}",
    )
|
||||
|
||||
@unittest.skip("OWLv2 doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user