The tests were not updated after the addition of torch.diag (#15890)

in the scoring (which is more correct)
This commit is contained in:
Nicolas Patry
2022-03-03 15:33:49 +01:00
committed by GitHub
parent 3c4fbc616f
commit b693cbf99c

View File

@@ -186,9 +186,9 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe
self.assertEqual(
nested_simplify(output),
[
{"score": 0.941, "label": "cat"},
{"score": 0.055, "label": "remote"},
{"score": 0.003, "label": "plane"},
{"score": 0.511, "label": "remote"},
{"score": 0.485, "label": "cat"},
{"score": 0.004, "label": "plane"},
],
)
@@ -197,9 +197,9 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe
nested_simplify(output),
[
[
{"score": 0.941, "label": "cat"},
{"score": 0.055, "label": "remote"},
{"score": 0.003, "label": "plane"},
{"score": 0.511, "label": "remote"},
{"score": 0.485, "label": "cat"},
{"score": 0.004, "label": "plane"},
],
]
* 5,
@@ -214,13 +214,12 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe
# This is an image of 2 cats with remotes and no planes
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
output = image_classifier(image, candidate_labels=["cat", "plane", "remote"])
self.assertEqual(
nested_simplify(output),
[
{"score": 0.941, "label": "cat"},
{"score": 0.055, "label": "remote"},
{"score": 0.003, "label": "plane"},
{"score": 0.511, "label": "remote"},
{"score": 0.485, "label": "cat"},
{"score": 0.004, "label": "plane"},
],
)
@@ -229,9 +228,9 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe
nested_simplify(output),
[
[
{"score": 0.941, "label": "cat"},
{"score": 0.055, "label": "remote"},
{"score": 0.003, "label": "plane"},
{"score": 0.511, "label": "remote"},
{"score": 0.485, "label": "cat"},
{"score": 0.004, "label": "plane"},
],
]
* 5,