Fix pipeline predict transform methods (#19657)
* Remove key word argument X from pipeline predict and transform methods As __call__ of pipeline clasees require one positional argument, passing the input as a keyword argument inside predict, transform methods, causing __call__ to fail. Hence in this commit the keyword argument is modified into positional argument. * Implement basic tests for scikitcompat pipeline interface * Seperate tests instead of running with parameterized based on framework as both frameworks will not be active at the same time
This commit is contained in:
@@ -836,13 +836,13 @@ class Pipeline(_ScikitCompat):
|
||||
"""
|
||||
Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
|
||||
"""
|
||||
return self(X=X)
|
||||
return self(X)
|
||||
|
||||
def predict(self, X):
|
||||
"""
|
||||
Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
|
||||
"""
|
||||
return self(X=X)
|
||||
return self(X)
|
||||
|
||||
@contextmanager
|
||||
def device_placement(self):
|
||||
|
||||
@@ -423,6 +423,56 @@ class CommonPipelineTest(unittest.TestCase):
|
||||
self.assertEqual(len(outputs), 20)
|
||||
|
||||
|
||||
class PipelineScikitCompatTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_pipeline_predict_pt(self):
|
||||
data = ["This is a test"]
|
||||
|
||||
text_classifier = pipeline(
|
||||
task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt"
|
||||
)
|
||||
|
||||
expected_output = [{"label": ANY(str), "score": ANY(float)}]
|
||||
actual_output = text_classifier.predict(data)
|
||||
self.assertEqual(expected_output, actual_output)
|
||||
|
||||
@require_tf
|
||||
def test_pipeline_predict_tf(self):
|
||||
data = ["This is a test"]
|
||||
|
||||
text_classifier = pipeline(
|
||||
task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="tf"
|
||||
)
|
||||
|
||||
expected_output = [{"label": ANY(str), "score": ANY(float)}]
|
||||
actual_output = text_classifier.predict(data)
|
||||
self.assertEqual(expected_output, actual_output)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_transform_pt(self):
|
||||
data = ["This is a test"]
|
||||
|
||||
text_classifier = pipeline(
|
||||
task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="pt"
|
||||
)
|
||||
|
||||
expected_output = [{"label": ANY(str), "score": ANY(float)}]
|
||||
actual_output = text_classifier.transform(data)
|
||||
self.assertEqual(expected_output, actual_output)
|
||||
|
||||
@require_tf
|
||||
def test_pipeline_transform_tf(self):
|
||||
data = ["This is a test"]
|
||||
|
||||
text_classifier = pipeline(
|
||||
task="text-classification", model="hf-internal-testing/tiny-random-distilbert", framework="tf"
|
||||
)
|
||||
|
||||
expected_output = [{"label": ANY(str), "score": ANY(float)}]
|
||||
actual_output = text_classifier.transform(data)
|
||||
self.assertEqual(expected_output, actual_output)
|
||||
|
||||
|
||||
class PipelinePadTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_pipeline_padding(self):
|
||||
|
||||
Reference in New Issue
Block a user