From 3d3c7d4213e08d69254edb9c04ac28b3dfbd40f4 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 17 May 2023 14:27:43 +0200 Subject: [PATCH] [`SAM`] fix sam slow test (#23376) * fix sam slow test * oops * fix error message --- src/transformers/models/sam/processing_sam.py | 4 ++-- tests/models/sam/test_modeling_sam.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py index b5ae51d7db..919ed685a8 100644 --- a/src/transformers/models/sam/processing_sam.py +++ b/src/transformers/models/sam/processing_sam.py @@ -208,7 +208,7 @@ class SamProcessor(ProcessorMixin): input_points = input_points.numpy().tolist() if not isinstance(input_points, list) or not isinstance(input_points[0], list): - raise ValueError("Input points must be a list of list of floating integers.") + raise ValueError("Input points must be a list of list of floating points.") input_points = [np.array(input_point) for input_point in input_points] else: input_points = None @@ -232,7 +232,7 @@ class SamProcessor(ProcessorMixin): or not isinstance(input_boxes[0], list) or not isinstance(input_boxes[0][0], list) ): - raise ValueError("Input boxes must be a list of list of list of floating integers.") + raise ValueError("Input boxes must be a list of list of list of floating points.") input_boxes = [np.array(box).astype(np.float32) for box in input_boxes] else: input_boxes = None diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index e51eb07dd3..a45e5c4bf5 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -481,7 +481,7 @@ class SamModelIntegrationTest(unittest.TestCase): model.eval() raw_image = prepare_image() - input_boxes = [[650, 900, 1000, 1250]] + input_boxes = [[[650, 900, 1000, 1250]]] input_points = [[[820, 1080]]] inputs = processor( @@ -541,7 +541,7 @@ class SamModelIntegrationTest(unittest.TestCase): model.eval() raw_image = prepare_image() - input_boxes = [[620, 900, 1000, 1255]] + input_boxes = [[[620, 900, 1000, 1255]]] input_points = [[[820, 1080]]] labels = [[0]]