From 0eabe49204e4fb1cf393283ee7c16706b42a9abb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 24 Sep 2021 13:38:17 +0200 Subject: [PATCH] Fixing zero-shot backward compatiblity (#13725) Fixes #13697 --- src/transformers/pipelines/zero_shot_classification.py | 8 ++++++++ tests/test_pipelines_zero_shot.py | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index d374b952c7..1e2b88d7be 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -150,6 +150,7 @@ class ZeroShotClassificationPipeline(Pipeline): def __call__( self, sequences: Union[str, List[str]], + *args, **kwargs, ): """ @@ -183,6 +184,13 @@ class ZeroShotClassificationPipeline(Pipeline): - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels. """ + if len(args) == 0: + pass + elif len(args) == 1 and "candidate_labels" not in kwargs: + kwargs["candidate_labels"] = args[0] + else: + raise ValueError(f"Unable to understand extra arguments {args}") + result = super().__call__(sequences, **kwargs) if len(result) == 1: return result[0] diff --git a/tests/test_pipelines_zero_shot.py b/tests/test_pipelines_zero_shot.py index 69fd65f71d..d22ce68621 100644 --- a/tests/test_pipelines_zero_shot.py +++ b/tests/test_pipelines_zero_shot.py @@ -37,6 +37,10 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics") self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]}) + # No kwarg + outputs = classifier("Who are you voting for in 2020?", ["politics"]) + self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]}) + outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"]) self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})