@@ -150,6 +150,7 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
def __call__(
|
||||
self,
|
||||
sequences: Union[str, List[str]],
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -183,6 +184,13 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
|
||||
"""
|
||||
|
||||
if len(args) == 0:
|
||||
pass
|
||||
elif len(args) == 1 and "candidate_labels" not in kwargs:
|
||||
kwargs["candidate_labels"] = args[0]
|
||||
else:
|
||||
raise ValueError(f"Unable to understand extra arguments {args}")
|
||||
|
||||
result = super().__call__(sequences, **kwargs)
|
||||
if len(result) == 1:
|
||||
return result[0]
|
||||
|
||||
@@ -37,6 +37,10 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
|
||||
outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics")
|
||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||
|
||||
# No kwarg
|
||||
outputs = classifier("Who are you voting for in 2020?", ["politics"])
|
||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||
|
||||
outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"])
|
||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user