The tests were not updated after the addition of torch.diag (#15890)
in the scoring (which is more correct)
This commit is contained in:
@@ -186,9 +186,9 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(output),
|
nested_simplify(output),
|
||||||
[
|
[
|
||||||
{"score": 0.941, "label": "cat"},
|
{"score": 0.511, "label": "remote"},
|
||||||
{"score": 0.055, "label": "remote"},
|
{"score": 0.485, "label": "cat"},
|
||||||
{"score": 0.003, "label": "plane"},
|
{"score": 0.004, "label": "plane"},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -197,9 +197,9 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe
|
|||||||
nested_simplify(output),
|
nested_simplify(output),
|
||||||
[
|
[
|
||||||
[
|
[
|
||||||
{"score": 0.941, "label": "cat"},
|
{"score": 0.511, "label": "remote"},
|
||||||
{"score": 0.055, "label": "remote"},
|
{"score": 0.485, "label": "cat"},
|
||||||
{"score": 0.003, "label": "plane"},
|
{"score": 0.004, "label": "plane"},
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
* 5,
|
* 5,
|
||||||
@@ -214,13 +214,12 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe
|
|||||||
# This is an image of 2 cats with remotes and no planes
|
# This is an image of 2 cats with remotes and no planes
|
||||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
output = image_classifier(image, candidate_labels=["cat", "plane", "remote"])
|
output = image_classifier(image, candidate_labels=["cat", "plane", "remote"])
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(output),
|
nested_simplify(output),
|
||||||
[
|
[
|
||||||
{"score": 0.941, "label": "cat"},
|
{"score": 0.511, "label": "remote"},
|
||||||
{"score": 0.055, "label": "remote"},
|
{"score": 0.485, "label": "cat"},
|
||||||
{"score": 0.003, "label": "plane"},
|
{"score": 0.004, "label": "plane"},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -229,9 +228,9 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe
|
|||||||
nested_simplify(output),
|
nested_simplify(output),
|
||||||
[
|
[
|
||||||
[
|
[
|
||||||
{"score": 0.941, "label": "cat"},
|
{"score": 0.511, "label": "remote"},
|
||||||
{"score": 0.055, "label": "remote"},
|
{"score": 0.485, "label": "cat"},
|
||||||
{"score": 0.003, "label": "plane"},
|
{"score": 0.004, "label": "plane"},
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
* 5,
|
* 5,
|
||||||
|
|||||||
Reference in New Issue
Block a user