Fix default behaviour in TextClassificationPipeline for regression problem type (#34066)
* update code * update docstrings * update tests
This commit is contained in:
@@ -40,7 +40,8 @@ class ClassificationFunction(ExplicitEnum):
|
|||||||
The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:
|
The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:
|
||||||
|
|
||||||
- `"default"`: if the model has a single label, will apply the sigmoid function on the output. If the model
|
- `"default"`: if the model has a single label, will apply the sigmoid function on the output. If the model
|
||||||
has several labels, will apply the softmax function on the output.
|
has several labels, will apply the softmax function on the output. In case of regression tasks, will not
|
||||||
|
apply any function on the output.
|
||||||
- `"sigmoid"`: Applies the sigmoid function on the output.
|
- `"sigmoid"`: Applies the sigmoid function on the output.
|
||||||
- `"softmax"`: Applies the softmax function on the output.
|
- `"softmax"`: Applies the softmax function on the output.
|
||||||
- `"none"`: Does not apply any function on the output.""",
|
- `"none"`: Does not apply any function on the output.""",
|
||||||
@@ -69,7 +70,8 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
`"sentiment-analysis"` (for classifying sequences according to positive or negative sentiments).
|
`"sentiment-analysis"` (for classifying sequences according to positive or negative sentiments).
|
||||||
|
|
||||||
If multiple classification labels are available (`model.config.num_labels >= 2`), the pipeline will run a softmax
|
If multiple classification labels are available (`model.config.num_labels >= 2`), the pipeline will run a softmax
|
||||||
over the results. If there is a single label, the pipeline will run a sigmoid over the result.
|
over the results. If there is a single label, the pipeline will run a sigmoid over the result. In case of regression
|
||||||
|
tasks (`model.config.problem_type == "regression"`), will not apply any function on the output.
|
||||||
|
|
||||||
The models that this pipeline can use are models that have been fine-tuned on a sequence classification task. See
|
The models that this pipeline can use are models that have been fine-tuned on a sequence classification task. See
|
||||||
the up-to-date list of available models on
|
the up-to-date list of available models on
|
||||||
@@ -135,6 +137,7 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
If this argument is not specified, then it will apply the following functions according to the number
|
If this argument is not specified, then it will apply the following functions according to the number
|
||||||
of labels:
|
of labels:
|
||||||
|
|
||||||
|
- If problem type is regression, will not apply any function on the output.
|
||||||
- If the model has a single label, will apply the sigmoid function on the output.
|
- If the model has a single label, will apply the sigmoid function on the output.
|
||||||
- If the model has several labels, will apply the softmax function on the output.
|
- If the model has several labels, will apply the softmax function on the output.
|
||||||
|
|
||||||
@@ -192,7 +195,9 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
# the more natural result containing the list.
|
# the more natural result containing the list.
|
||||||
# Default value before `set_parameters`
|
# Default value before `set_parameters`
|
||||||
if function_to_apply is None:
|
if function_to_apply is None:
|
||||||
if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
|
if self.model.config.problem_type == "regression":
|
||||||
|
function_to_apply = ClassificationFunction.NONE
|
||||||
|
elif self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
|
||||||
function_to_apply = ClassificationFunction.SIGMOID
|
function_to_apply = ClassificationFunction.SIGMOID
|
||||||
elif self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels > 1:
|
elif self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels > 1:
|
||||||
function_to_apply = ClassificationFunction.SOFTMAX
|
function_to_apply = ClassificationFunction.SOFTMAX
|
||||||
|
|||||||
@@ -108,6 +108,12 @@ class TextClassificationPipelineTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Do not apply any function to output for regression tasks
|
||||||
|
# hack: changing problem_type artifically (so keep this test at last)
|
||||||
|
text_classifier.model.config.problem_type = "regression"
|
||||||
|
outputs = text_classifier("This is great !")
|
||||||
|
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.01}])
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_accepts_torch_device(self):
|
def test_accepts_torch_device(self):
|
||||||
text_classifier = pipeline(
|
text_classifier = pipeline(
|
||||||
|
|||||||
Reference in New Issue
Block a user