Fix a couple of typos and add an illustrative test (#26941)
* fix a typo and add an illustrative test * appease black * reduce code duplication and add Annotion type back with a pending deprecation warning * remove unused code * change warning type * black formatting fix * change enum deprecation approach to support 3.8 and earlier * add stacklevel * fix black issue * fix ruff issues * fix ruff issues * move tests to own mixin * include yolos * fix black formatting issue * fix black formatting issue * use logger instead of warnings and include target version for deprecation
This commit is contained in:
@@ -21,7 +21,7 @@ import unittest
|
||||
from transformers.testing_utils import require_torch, require_vision, slow
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
from ...test_image_processing_common import AnnotationFormatTestMixin, ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -127,7 +127,7 @@ class DetrImageProcessingTester(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class DetrImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
class DetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = DetrImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
@@ -159,6 +159,63 @@ class DetrImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84})
|
||||
self.assertEqual(image_processor.do_pad, False)
|
||||
|
||||
def test_should_raise_if_annotation_format_invalid(self):
|
||||
image_processor_dict = self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
|
||||
detection_target = json.loads(f.read())
|
||||
|
||||
annotations = {"image_id": 39769, "annotations": detection_target}
|
||||
|
||||
params = {
|
||||
"images": Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
|
||||
"annotations": annotations,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
|
||||
image_processor_params = {**image_processor_dict, **{"format": "_INVALID_FORMAT_"}}
|
||||
image_processor = self.image_processing_class(**image_processor_params)
|
||||
|
||||
with self.assertRaises(ValueError) as e:
|
||||
image_processor(**params)
|
||||
|
||||
self.assertTrue(str(e.exception).startswith("_INVALID_FORMAT_ is not a valid AnnotationFormat"))
|
||||
|
||||
def test_valid_coco_detection_annotations(self):
|
||||
# prepare image and target
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
|
||||
target = json.loads(f.read())
|
||||
|
||||
params = {"image_id": 39769, "annotations": target}
|
||||
|
||||
# encode them
|
||||
image_processing = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
|
||||
|
||||
# legal encodings (single image)
|
||||
_ = image_processing(images=image, annotations=params, return_tensors="pt")
|
||||
_ = image_processing(images=image, annotations=[params], return_tensors="pt")
|
||||
|
||||
# legal encodings (batch of one image)
|
||||
_ = image_processing(images=[image], annotations=params, return_tensors="pt")
|
||||
_ = image_processing(images=[image], annotations=[params], return_tensors="pt")
|
||||
|
||||
# legal encoding (batch of more than one image)
|
||||
n = 5
|
||||
_ = image_processing(images=[image] * n, annotations=[params] * n, return_tensors="pt")
|
||||
|
||||
# example of an illegal encoding (missing the 'image_id' key)
|
||||
with self.assertRaises(ValueError) as e:
|
||||
image_processing(images=image, annotations={"annotations": target}, return_tensors="pt")
|
||||
|
||||
self.assertTrue(str(e.exception).startswith("Invalid COCO detection annotations"))
|
||||
|
||||
# example of an illegal encoding (unequal lengths of images and annotations)
|
||||
with self.assertRaises(ValueError) as e:
|
||||
image_processing(images=[image] * n, annotations=[params] * (n - 1), return_tensors="pt")
|
||||
|
||||
self.assertTrue(str(e.exception) == "The number of images (5) and annotations (4) do not match.")
|
||||
|
||||
@slow
|
||||
def test_call_pytorch_with_coco_detection_annotations(self):
|
||||
# prepare image and target
|
||||
|
||||
Reference in New Issue
Block a user