[SAM] Fixes pipeline and adds a dummy pipeline test (#23684)
* add a dummy pipeline test * change test name
This commit is contained in:
@@ -934,7 +934,7 @@ def _generate_crop_boxes(
|
||||
cropped_images, point_grid_per_crop = _generate_crop_images(
|
||||
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
|
||||
)
|
||||
|
||||
crop_boxes = np.array(crop_boxes)
|
||||
crop_boxes = crop_boxes.astype(np.float32)
|
||||
points_per_crop = np.array([point_grid_per_crop])
|
||||
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
|
||||
|
||||
@@ -20,7 +20,7 @@ import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
|
||||
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
@@ -751,3 +751,9 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
iou_scores = outputs.iou_scores.cpu()
|
||||
self.assertTrue(iou_scores.shape == (1, 3, 3))
|
||||
torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_dummy_pipeline_generation(self):
|
||||
generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
|
||||
raw_image = prepare_image()
|
||||
|
||||
_ = generator(raw_image, points_per_batch=64)
|
||||
|
||||
Reference in New Issue
Block a user