From 8aad4363d8963a6c6b32b8d2b3f0553ebd8d1b0a Mon Sep 17 00:00:00 2001 From: Sivaudha Date: Mon, 17 Oct 2022 15:06:20 +0200 Subject: [PATCH] Fix pipeline predict transform methods (#19657) * Remove key word argument X from pipeline predict and transform methods As __call__ of pipeline clasees require one positional argument, passing the input as a keyword argument inside predict, transform methods, causing __call__ to fail. Hence in this commit the keyword argument is modified into positional argument. * Implement basic tests for scikitcompat pipeline interface * Seperate tests instead of running with parameterized based on framework as both frameworks will not be active at the same time --- src/transformers/pipelines/base.py | 4 +- tests/pipelines/test_pipelines_common.py | 50 ++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index b5e7c9cb58..4205ef2eb2 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -836,13 +836,13 @@ class Pipeline(_ScikitCompat): """ Scikit / Keras interface to transformers' pipelines. This method will forward to __call__(). """ - return self(X=X) + return self(X) def predict(self, X): """ Scikit / Keras interface to transformers' pipelines. This method will forward to __call__(). """ - return self(X=X) + return self(X) @contextmanager def device_placement(self): diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 9449b38094..314b183663 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -423,6 +423,56 @@ class CommonPipelineTest(unittest.TestCase): self.assertEqual(len(outputs), 20) +class PipelineScikitCompatTest(unittest.TestCase): + @require_torch + def test_pipeline_predict_pt(self): + data = ["This is a test"] + + text_classifier = pipeline( + task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt" + ) + + expected_output = [{"label": ANY(str), "score": ANY(float)}] + actual_output = text_classifier.predict(data) + self.assertEqual(expected_output, actual_output) + + @require_tf + def test_pipeline_predict_tf(self): + data = ["This is a test"] + + text_classifier = pipeline( + task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="tf" + ) + + expected_output = [{"label": ANY(str), "score": ANY(float)}] + actual_output = text_classifier.predict(data) + self.assertEqual(expected_output, actual_output) + + @require_torch + def test_pipeline_transform_pt(self): + data = ["This is a test"] + + text_classifier = pipeline( + task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt" + ) + + expected_output = [{"label": ANY(str), "score": ANY(float)}] + actual_output = text_classifier.transform(data) + self.assertEqual(expected_output, actual_output) + + @require_tf + def test_pipeline_transform_tf(self): + data = ["This is a test"] + + text_classifier = pipeline( + task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="tf" + ) + + expected_output = [{"label": ANY(str), "score": ANY(float)}] + actual_output = text_classifier.transform(data) + self.assertEqual(expected_output, actual_output) + + class PipelinePadTest(unittest.TestCase): @require_torch def test_pipeline_padding(self):