🔥py38 + torch 2 🔥🔥🔥🚀 (#22204)

* py38 + torch 2

* increment cache versions

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-03-16 22:59:23 +01:00
committed by GitHub
parent fb366b9a2a
commit 5110e5748e
6 changed files with 22 additions and 14 deletions

View File

@@ -78,9 +78,14 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase):
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
output = image_classifier(image, candidate_labels=["a", "b", "c"])
self.assertEqual(
# The floating scores are so close, we enter floating error approximation and the order is not guaranteed across
# python and torch versions.
self.assertIn(
nested_simplify(output),
[{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "b"}, {"score": 0.333, "label": "c"}],
[
[{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "b"}, {"score": 0.333, "label": "c"}],
[{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}, {"score": 0.333, "label": "b"}],
],
)
output = image_classifier([image] * 5, candidate_labels=["A", "B", "C"], batch_size=2)