From e6122c3f40d3c1fec5c4966a58340fd62d55cb71 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 15 Jun 2023 13:09:31 +0100 Subject: [PATCH] Fix image segmentation tool bug (#23897) * Image segmentation tool bug * Remove resizing in the tests --- src/transformers/tools/image_segmentation.py | 1 - tests/tools/test_image_segmentation.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/transformers/tools/image_segmentation.py b/src/transformers/tools/image_segmentation.py index 4471b84905..b6cbf3eb3f 100644 --- a/src/transformers/tools/image_segmentation.py +++ b/src/transformers/tools/image_segmentation.py @@ -44,7 +44,6 @@ class ImageSegmentationTool(PipelineTool): super().__init__(*args, **kwargs) def encode(self, image: "Image", label: str): - self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]} return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt") def forward(self, inputs): diff --git a/tests/tools/test_image_segmentation.py b/tests/tools/test_image_segmentation.py index 8a741b501b..2f003f2c8b 100644 --- a/tests/tools/test_image_segmentation.py +++ b/tests/tools/test_image_segmentation.py @@ -33,21 +33,21 @@ class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin): self.remote_tool = load_tool("image-segmentation", remote=True) def test_exact_match_arg(self): - image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512)) + image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png") result = self.tool(image, "cat") self.assertTrue(isinstance(result, Image.Image)) def test_exact_match_arg_remote(self): - image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512)) + image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png") result = self.remote_tool(image, "cat") self.assertTrue(isinstance(result, Image.Image)) def test_exact_match_kwarg(self): - image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512)) + image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png") result = self.tool(image=image, label="cat") self.assertTrue(isinstance(result, Image.Image)) def test_exact_match_kwarg_remote(self): - image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512)) + image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png") result = self.remote_tool(image=image, label="cat") self.assertTrue(isinstance(result, Image.Image))