From 42baa58f904b430ff1f93be207bf66e8daa096b5 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 23 May 2023 17:36:49 +0200 Subject: [PATCH] =?UTF-8?q?[`SAM`]=C2=A0Fixes=20pipeline=20and=20adds=20a?= =?UTF-8?q?=20dummy=20pipeline=20test=20(#23684)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add a dummy pipeline test * change test name --- src/transformers/models/sam/image_processing_sam.py | 2 +- tests/models/sam/test_modeling_sam.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 64f3bae222..821b43624d 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -934,7 +934,7 @@ def _generate_crop_boxes( cropped_images, point_grid_per_crop = _generate_crop_images( crop_boxes, image, points_grid, layer_idxs, target_size, original_size ) - + crop_boxes = np.array(crop_boxes) crop_boxes = crop_boxes.astype(np.float32) points_per_crop = np.array([point_grid_per_crop]) points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3)) diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index 2342e8010b..a701514522 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -20,7 +20,7 @@ import unittest import requests -from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig +from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline from transformers.testing_utils import require_torch, slow, torch_device from transformers.utils import is_torch_available, is_vision_available @@ -751,3 +751,9 @@ class SamModelIntegrationTest(unittest.TestCase): iou_scores = outputs.iou_scores.cpu() self.assertTrue(iou_scores.shape == (1, 3, 3)) torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) + + def test_dummy_pipeline_generation(self): + generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device) + raw_image = prepare_image() + + _ = generator(raw_image, points_per_batch=64)