Fixing zero-shot backward compatiblity (#13725)

Fixes #13697
This commit is contained in:
Nicolas Patry
2021-09-24 13:38:17 +02:00
committed by GitHub
parent a2ef9c5446
commit 0eabe49204
2 changed files with 12 additions and 0 deletions

View File

@@ -150,6 +150,7 @@ class ZeroShotClassificationPipeline(Pipeline):
def __call__( def __call__(
self, self,
sequences: Union[str, List[str]], sequences: Union[str, List[str]],
*args,
**kwargs, **kwargs,
): ):
""" """
@@ -183,6 +184,13 @@ class ZeroShotClassificationPipeline(Pipeline):
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels. - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
""" """
if len(args) == 0:
pass
elif len(args) == 1 and "candidate_labels" not in kwargs:
kwargs["candidate_labels"] = args[0]
else:
raise ValueError(f"Unable to understand extra arguments {args}")
result = super().__call__(sequences, **kwargs) result = super().__call__(sequences, **kwargs)
if len(result) == 1: if len(result) == 1:
return result[0] return result[0]

View File

@@ -37,6 +37,10 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics") outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics")
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]}) self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
# No kwarg
outputs = classifier("Who are you voting for in 2020?", ["politics"])
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"]) outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"])
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]}) self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})