[SAM] Fixes pipeline and adds a dummy pipeline test (#23684)
* add a dummy pipeline test * change test name
This commit is contained in:
@@ -934,7 +934,7 @@ def _generate_crop_boxes(
|
|||||||
cropped_images, point_grid_per_crop = _generate_crop_images(
|
cropped_images, point_grid_per_crop = _generate_crop_images(
|
||||||
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
|
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
|
||||||
)
|
)
|
||||||
|
crop_boxes = np.array(crop_boxes)
|
||||||
crop_boxes = crop_boxes.astype(np.float32)
|
crop_boxes = crop_boxes.astype(np.float32)
|
||||||
points_per_crop = np.array([point_grid_per_crop])
|
points_per_crop = np.array([point_grid_per_crop])
|
||||||
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
|
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import unittest
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
|
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
@@ -751,3 +751,9 @@ class SamModelIntegrationTest(unittest.TestCase):
|
|||||||
iou_scores = outputs.iou_scores.cpu()
|
iou_scores = outputs.iou_scores.cpu()
|
||||||
self.assertTrue(iou_scores.shape == (1, 3, 3))
|
self.assertTrue(iou_scores.shape == (1, 3, 3))
|
||||||
torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
|
torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
|
||||||
|
|
||||||
|
def test_dummy_pipeline_generation(self):
|
||||||
|
generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
|
||||||
|
raw_image = prepare_image()
|
||||||
|
|
||||||
|
_ = generator(raw_image, points_per_batch=64)
|
||||||
|
|||||||
Reference in New Issue
Block a user