Fix default behaviour in TextClassificationPipeline for regression problem type (#34066)

* update code

* update docstrings

* update tests
This commit is contained in:
Subhalingam D
2024-10-15 17:36:20 +05:30
committed by GitHub
parent 4de1bdbf63
commit 5ee9e786d1
2 changed files with 14 additions and 3 deletions

View File

@@ -108,6 +108,12 @@ class TextClassificationPipelineTests(unittest.TestCase):
],
)
# Do not apply any function to output for regression tasks
# hack: changing problem_type artifically (so keep this test at last)
text_classifier.model.config.problem_type = "regression"
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.01}])
@require_torch
def test_accepts_torch_device(self):
text_classifier = pipeline(