From 62fe753325f29532a329bd6e0d4c6542ca9e4ef2 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 9 Jun 2023 16:22:09 +0200 Subject: [PATCH] [`SAM`] Fix sam slow test (#24140) * fix sam test * update pipeline typehint --- src/transformers/pipelines/base.py | 2 +- tests/models/sam/test_modeling_sam.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 466ce81999..510c07cf5f 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -760,7 +760,7 @@ class Pipeline(_ScikitCompat): framework: Optional[str] = None, task: str = "", args_parser: ArgumentHandler = None, - device: Union[int, str, "torch.device"] = None, + device: Union[int, "torch.device"] = None, torch_dtype: Optional[Union[str, "torch.dtype"]] = None, binary_output: bool = False, **kwargs, diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index c8bc30a298..a0f39a4013 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -760,7 +760,9 @@ class SamModelIntegrationTest(unittest.TestCase): torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) def test_dummy_pipeline_generation(self): - generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device) + generator = pipeline( + "mask-generation", model="facebook/sam-vit-base", device=0 if torch.cuda.is_available() else -1 + ) raw_image = prepare_image() _ = generator(raw_image, points_per_batch=64)