@@ -61,6 +61,24 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
|
||||
)
|
||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||
|
||||
# https://github.com/huggingface/transformers/issues/13846
|
||||
outputs = classifier(["I am happy"], ["positive", "negative"])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
|
||||
for i in range(1)
|
||||
],
|
||||
)
|
||||
outputs = classifier(["I am happy", "I am sad"], ["positive", "negative"])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
|
||||
for i in range(2)
|
||||
],
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
classifier("", candidate_labels="politics")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user