@@ -191,10 +191,7 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unable to understand extra arguments {args}")
|
raise ValueError(f"Unable to understand extra arguments {args}")
|
||||||
|
|
||||||
result = super().__call__(sequences, **kwargs)
|
return super().__call__(sequences, **kwargs)
|
||||||
if len(result) == 1:
|
|
||||||
return result[0]
|
|
||||||
return result
|
|
||||||
|
|
||||||
def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
|
def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
|
||||||
sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
|
sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
|
||||||
@@ -264,4 +261,6 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
"scores": scores[iseq, top_inds].tolist(),
|
"scores": scores[iseq, top_inds].tolist(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if len(result) == 1:
|
||||||
|
return result[0]
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -61,6 +61,24 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
|
|||||||
)
|
)
|
||||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||||
|
|
||||||
|
# https://github.com/huggingface/transformers/issues/13846
|
||||||
|
outputs = classifier(["I am happy"], ["positive", "negative"])
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
|
||||||
|
for i in range(1)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
outputs = classifier(["I am happy", "I am sad"], ["positive", "negative"])
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
|
||||||
|
for i in range(2)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
classifier("", candidate_labels="politics")
|
classifier("", candidate_labels="politics")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user