Adding top_k argument to text-classification pipeline. (#17606)
* Adding `top_k` and `sort` arguments to `text-classification` pipeline. - Deprecate `return_all_scores`, as `top_k` is more uniform with other pipelines and is a superset of what `return_all_scores` can do. Backward compatibility (BC) is maintained, though: `return_all_scores=True` -> `top_k=None`; `return_all_scores=False` -> `top_k=1`. - Using `top_k` implies sorting the results, but passing no argument keeps the results unsorted for backward compatibility. * Remove `sort`. * Fixing the test. * Remove bad doc.
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import warnings
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -72,15 +73,26 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
||||||
)
|
)
|
||||||
|
|
||||||
def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, **tokenizer_kwargs):
|
def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, top_k="", **tokenizer_kwargs):
|
||||||
|
# Using "" as default argument because we're going to use `top_k=None` in user code to declare
|
||||||
|
# "No top_k"
|
||||||
preprocess_params = tokenizer_kwargs
|
preprocess_params = tokenizer_kwargs
|
||||||
|
|
||||||
postprocess_params = {}
|
postprocess_params = {}
|
||||||
if hasattr(self.model.config, "return_all_scores") and return_all_scores is None:
|
if hasattr(self.model.config, "return_all_scores") and return_all_scores is None:
|
||||||
return_all_scores = self.model.config.return_all_scores
|
return_all_scores = self.model.config.return_all_scores
|
||||||
|
|
||||||
if return_all_scores is not None:
|
if isinstance(top_k, int) or top_k is None:
|
||||||
postprocess_params["return_all_scores"] = return_all_scores
|
postprocess_params["top_k"] = top_k
|
||||||
|
postprocess_params["_legacy"] = False
|
||||||
|
elif return_all_scores is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"`return_all_scores` is now deprecated, use `top_k=1` if you want similar functionnality", UserWarning
|
||||||
|
)
|
||||||
|
if return_all_scores:
|
||||||
|
postprocess_params["top_k"] = None
|
||||||
|
else:
|
||||||
|
postprocess_params["top_k"] = 1
|
||||||
|
|
||||||
if isinstance(function_to_apply, str):
|
if isinstance(function_to_apply, str):
|
||||||
function_to_apply = ClassificationFunction[function_to_apply.upper()]
|
function_to_apply = ClassificationFunction[function_to_apply.upper()]
|
||||||
@@ -97,8 +109,8 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`):
|
args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`):
|
||||||
One or several texts to classify. In order to use text pairs for your classification, you can send a
|
One or several texts to classify. In order to use text pairs for your classification, you can send a
|
||||||
dictionnary containing `{"text", "text_pair"}` keys, or a list of those.
|
dictionnary containing `{"text", "text_pair"}` keys, or a list of those.
|
||||||
return_all_scores (`bool`, *optional*, defaults to `False`):
|
top_k (`int`, *optional*, defaults to `1`):
|
||||||
Whether to return scores for all labels.
|
How many results to return.
|
||||||
function_to_apply (`str`, *optional*, defaults to `"default"`):
|
function_to_apply (`str`, *optional*, defaults to `"default"`):
|
||||||
The function to apply to the model outputs in order to retrieve the scores. Accepts four different
|
The function to apply to the model outputs in order to retrieve the scores. Accepts four different
|
||||||
values:
|
values:
|
||||||
@@ -121,10 +133,10 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
- **label** (`str`) -- The label predicted.
|
- **label** (`str`) -- The label predicted.
|
||||||
- **score** (`float`) -- The corresponding probability.
|
- **score** (`float`) -- The corresponding probability.
|
||||||
|
|
||||||
If `self.return_all_scores=True`, one such dictionary is returned per label.
|
If `top_k` is used, one such dictionary is returned per label.
|
||||||
"""
|
"""
|
||||||
result = super().__call__(*args, **kwargs)
|
result = super().__call__(*args, **kwargs)
|
||||||
if isinstance(args[0], str):
|
if isinstance(args[0], str) and isinstance(result, dict):
|
||||||
# This pipeline is odd, and return a list when single item is run
|
# This pipeline is odd, and return a list when single item is run
|
||||||
return [result]
|
return [result]
|
||||||
else:
|
else:
|
||||||
@@ -150,7 +162,10 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
def _forward(self, model_inputs):
|
def _forward(self, model_inputs):
|
||||||
return self.model(**model_inputs)
|
return self.model(**model_inputs)
|
||||||
|
|
||||||
def postprocess(self, model_outputs, function_to_apply=None, return_all_scores=False):
|
def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True):
|
||||||
|
# `_legacy` is used to determine if we're running the naked pipeline and in backward
|
||||||
|
# compatibility mode, or if running the pipeline with `pipeline(..., top_k=1)` we're running
|
||||||
|
# the more natural result containing the list.
|
||||||
# Default value before `set_parameters`
|
# Default value before `set_parameters`
|
||||||
if function_to_apply is None:
|
if function_to_apply is None:
|
||||||
if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
|
if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
|
||||||
@@ -174,7 +189,14 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")
|
raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")
|
||||||
|
|
||||||
if return_all_scores:
|
if top_k == 1 and _legacy:
|
||||||
return [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)]
|
|
||||||
else:
|
|
||||||
return {"label": self.model.config.id2label[scores.argmax().item()], "score": scores.max().item()}
|
return {"label": self.model.config.id2label[scores.argmax().item()], "score": scores.max().item()}
|
||||||
|
|
||||||
|
dict_scores = [
|
||||||
|
{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)
|
||||||
|
]
|
||||||
|
if not _legacy:
|
||||||
|
dict_scores.sort(key=lambda x: x["score"], reverse=True)
|
||||||
|
if top_k is not None:
|
||||||
|
dict_scores = dict_scores[:top_k]
|
||||||
|
return dict_scores
|
||||||
|
|||||||
@@ -39,6 +39,27 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
|
|||||||
outputs = text_classifier("This is great !")
|
outputs = text_classifier("This is great !")
|
||||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||||
|
|
||||||
|
outputs = text_classifier("This is great !", top_k=2)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = text_classifier(["This is great !", "This is bad"], top_k=2)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
[
|
||||||
|
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
|
||||||
|
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = text_classifier("This is great !", top_k=1)
|
||||||
|
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||||
|
|
||||||
|
# Legacy behavior
|
||||||
|
outputs = text_classifier("This is great !", return_all_scores=False)
|
||||||
|
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_accepts_torch_device(self):
|
def test_accepts_torch_device(self):
|
||||||
import torch
|
import torch
|
||||||
@@ -108,6 +129,15 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
|
|||||||
self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
|
self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
|
||||||
self.assertTrue(outputs[1]["label"] in model.config.id2label.values())
|
self.assertTrue(outputs[1]["label"] in model.config.id2label.values())
|
||||||
|
|
||||||
|
# Forcing to get all results with `top_k=None`
|
||||||
|
# This is NOT the legacy format
|
||||||
|
outputs = text_classifier(valid_inputs, top_k=None)
|
||||||
|
N = len(model.config.id2label.values())
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
[[{"label": ANY(str), "score": ANY(float)}] * N, [{"label": ANY(str), "score": ANY(float)}] * N],
|
||||||
|
)
|
||||||
|
|
||||||
valid_inputs = {"text": "HuggingFace is in ", "text_pair": "Paris is in France"}
|
valid_inputs = {"text": "HuggingFace is in ", "text_pair": "Paris is in France"}
|
||||||
outputs = text_classifier(valid_inputs)
|
outputs = text_classifier(valid_inputs)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|||||||
Reference in New Issue
Block a user