Fixing Backward compatiblity for zero-shot (#13855)

Fixes #13846
This commit is contained in:
Nicolas Patry
2021-10-06 05:06:47 +02:00
committed by GitHub
parent 9f58becc8d
commit 013bdc6d65
2 changed files with 21 additions and 4 deletions

View File

@@ -61,6 +61,24 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
)
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
# https://github.com/huggingface/transformers/issues/13846
outputs = classifier(["I am happy"], ["positive", "negative"])
self.assertEqual(
outputs,
[
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
for i in range(1)
],
)
outputs = classifier(["I am happy", "I am sad"], ["positive", "negative"])
self.assertEqual(
outputs,
[
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
for i in range(2)
],
)
with self.assertRaises(ValueError):
classifier("", candidate_labels="politics")