Fix pipeline predict transform methods (#19657)
* Remove key word argument X from pipeline predict and transform methods As __call__ of pipeline clasees require one positional argument, passing the input as a keyword argument inside predict, transform methods, causing __call__ to fail. Hence in this commit the keyword argument is modified into positional argument. * Implement basic tests for scikitcompat pipeline interface * Seperate tests instead of running with parameterized based on framework as both frameworks will not be active at the same time
This commit is contained in:
@@ -836,13 +836,13 @@ class Pipeline(_ScikitCompat):
|
|||||||
"""
|
"""
|
||||||
Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
|
Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
|
||||||
"""
|
"""
|
||||||
return self(X=X)
|
return self(X)
|
||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
"""
|
"""
|
||||||
Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
|
Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
|
||||||
"""
|
"""
|
||||||
return self(X=X)
|
return self(X)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def device_placement(self):
|
def device_placement(self):
|
||||||
|
|||||||
@@ -423,6 +423,56 @@ class CommonPipelineTest(unittest.TestCase):
|
|||||||
self.assertEqual(len(outputs), 20)
|
self.assertEqual(len(outputs), 20)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineScikitCompatTest(unittest.TestCase):
|
||||||
|
@require_torch
|
||||||
|
def test_pipeline_predict_pt(self):
|
||||||
|
data = ["This is a test"]
|
||||||
|
|
||||||
|
text_classifier = pipeline(
|
||||||
|
task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_output = [{"label": ANY(str), "score": ANY(float)}]
|
||||||
|
actual_output = text_classifier.predict(data)
|
||||||
|
self.assertEqual(expected_output, actual_output)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_pipeline_predict_tf(self):
|
||||||
|
data = ["This is a test"]
|
||||||
|
|
||||||
|
text_classifier = pipeline(
|
||||||
|
task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="tf"
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_output = [{"label": ANY(str), "score": ANY(float)}]
|
||||||
|
actual_output = text_classifier.predict(data)
|
||||||
|
self.assertEqual(expected_output, actual_output)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_pipeline_transform_pt(self):
|
||||||
|
data = ["This is a test"]
|
||||||
|
|
||||||
|
text_classifier = pipeline(
|
||||||
|
task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_output = [{"label": ANY(str), "score": ANY(float)}]
|
||||||
|
actual_output = text_classifier.transform(data)
|
||||||
|
self.assertEqual(expected_output, actual_output)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_pipeline_transform_tf(self):
|
||||||
|
data = ["This is a test"]
|
||||||
|
|
||||||
|
text_classifier = pipeline(
|
||||||
|
task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="tf"
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_output = [{"label": ANY(str), "score": ANY(float)}]
|
||||||
|
actual_output = text_classifier.transform(data)
|
||||||
|
self.assertEqual(expected_output, actual_output)
|
||||||
|
|
||||||
|
|
||||||
class PipelinePadTest(unittest.TestCase):
|
class PipelinePadTest(unittest.TestCase):
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pipeline_padding(self):
|
def test_pipeline_padding(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user