[SAM] Fixes pipeline and adds a dummy pipeline test (#23684)

* add a dummy pipeline test

* change test name
This commit is contained in:
Younes Belkada
2023-05-23 17:36:49 +02:00
committed by GitHub
parent 71a5ed3433
commit 42baa58f90
2 changed files with 8 additions and 2 deletions

View File

@@ -20,7 +20,7 @@ import unittest
import requests
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
@@ -751,3 +751,9 @@ class SamModelIntegrationTest(unittest.TestCase):
iou_scores = outputs.iou_scores.cpu()
self.assertTrue(iou_scores.shape == (1, 3, 3))
torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
def test_dummy_pipeline_generation(self):
generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
raw_image = prepare_image()
_ = generator(raw_image, points_per_batch=64)