@@ -150,6 +150,7 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
sequences: Union[str, List[str]],
|
sequences: Union[str, List[str]],
|
||||||
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -183,6 +184,13 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
|
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if len(args) == 0:
|
||||||
|
pass
|
||||||
|
elif len(args) == 1 and "candidate_labels" not in kwargs:
|
||||||
|
kwargs["candidate_labels"] = args[0]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unable to understand extra arguments {args}")
|
||||||
|
|
||||||
result = super().__call__(sequences, **kwargs)
|
result = super().__call__(sequences, **kwargs)
|
||||||
if len(result) == 1:
|
if len(result) == 1:
|
||||||
return result[0]
|
return result[0]
|
||||||
|
|||||||
@@ -37,6 +37,10 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
|
|||||||
outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics")
|
outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics")
|
||||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||||
|
|
||||||
|
# No kwarg
|
||||||
|
outputs = classifier("Who are you voting for in 2020?", ["politics"])
|
||||||
|
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||||
|
|
||||||
outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"])
|
outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"])
|
||||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user