From b693cbf99c5a180dde8b32ded2fb82ea735aab15 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 3 Mar 2022 15:33:49 +0100 Subject: [PATCH] The tests were not updated after the addition of `torch.diag` (#15890) in the scoring (which is more correct) --- ...ipelines_zero_shot_image_classification.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/pipelines/test_pipelines_zero_shot_image_classification.py b/tests/pipelines/test_pipelines_zero_shot_image_classification.py index c314b92a0b..a5aef5c35b 100644 --- a/tests/pipelines/test_pipelines_zero_shot_image_classification.py +++ b/tests/pipelines/test_pipelines_zero_shot_image_classification.py @@ -186,9 +186,9 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe self.assertEqual( nested_simplify(output), [ - {"score": 0.941, "label": "cat"}, - {"score": 0.055, "label": "remote"}, - {"score": 0.003, "label": "plane"}, + {"score": 0.511, "label": "remote"}, + {"score": 0.485, "label": "cat"}, + {"score": 0.004, "label": "plane"}, ], ) @@ -197,9 +197,9 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe nested_simplify(output), [ [ - {"score": 0.941, "label": "cat"}, - {"score": 0.055, "label": "remote"}, - {"score": 0.003, "label": "plane"}, + {"score": 0.511, "label": "remote"}, + {"score": 0.485, "label": "cat"}, + {"score": 0.004, "label": "plane"}, ], ] * 5, @@ -214,13 +214,12 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe # This is an image of 2 cats with remotes and no planes image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") output = image_classifier(image, candidate_labels=["cat", "plane", "remote"]) - self.assertEqual( nested_simplify(output), [ - {"score": 0.941, "label": "cat"}, - {"score": 0.055, "label": "remote"}, - {"score": 0.003, "label": "plane"}, + {"score": 0.511, "label": "remote"}, + {"score": 0.485, "label": "cat"}, + {"score": 0.004, "label": "plane"}, ], ) @@ -229,9 +228,9 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe nested_simplify(output), [ [ - {"score": 0.941, "label": "cat"}, - {"score": 0.055, "label": "remote"}, - {"score": 0.003, "label": "plane"}, + {"score": 0.511, "label": "remote"}, + {"score": 0.485, "label": "cat"}, + {"score": 0.004, "label": "plane"}, ], ] * 5,