From 2351729f7d1c606624dcd2d1ad5dc8e627e17320 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 9 Jun 2022 18:33:10 +0200 Subject: [PATCH] Adding `top_k` argument to `text-classification` pipeline. (#17606) * Adding `top_k` and `sort` arguments to `text-classification` pipeline. - Deprecate `return_all_scores` as `top_k` is more uniform with other pipelines, and a superset of what `return_all_scores` can do. BC is maintained though. `return_all_scores=True` -> `top_k=None` `return_all_scores=False` -> `top_k=1` - Using `top_k` will imply sorting the results, but using no argument will keep the results unsorted for backward compatibility. * Remove `sort`. * Fixing the test. * Remove bad doc. --- .../pipelines/text_classification.py | 44 ++++++++++++++----- .../test_pipelines_text_classification.py | 30 +++++++++++++ 2 files changed, 63 insertions(+), 11 deletions(-) diff --git a/src/transformers/pipelines/text_classification.py b/src/transformers/pipelines/text_classification.py index bb705a9b40..590c87c022 100644 --- a/src/transformers/pipelines/text_classification.py +++ b/src/transformers/pipelines/text_classification.py @@ -1,3 +1,4 @@ +import warnings from typing import Dict import numpy as np @@ -72,15 +73,26 @@ class TextClassificationPipeline(Pipeline): else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING ) - def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, **tokenizer_kwargs): + def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, top_k="", **tokenizer_kwargs): + # Using "" as default argument because we're going to use `top_k=None` in user code to declare + # "No top_k" preprocess_params = tokenizer_kwargs postprocess_params = {} if hasattr(self.model.config, "return_all_scores") and return_all_scores is None: return_all_scores = self.model.config.return_all_scores - if return_all_scores is not None: - postprocess_params["return_all_scores"] = return_all_scores + if isinstance(top_k, int) or top_k is None: + postprocess_params["top_k"] = top_k + postprocess_params["_legacy"] = False + elif return_all_scores is not None: + warnings.warn( + "`return_all_scores` is now deprecated, use `top_k=1` if you want similar functionnality", UserWarning + ) + if return_all_scores: + postprocess_params["top_k"] = None + else: + postprocess_params["top_k"] = 1 if isinstance(function_to_apply, str): function_to_apply = ClassificationFunction[function_to_apply.upper()] @@ -97,8 +109,8 @@ class TextClassificationPipeline(Pipeline): args (`str` or `List[str]` or `Dict[str]`, or `List[Dict[str]]`): One or several texts to classify. In order to use text pairs for your classification, you can send a dictionnary containing `{"text", "text_pair"}` keys, or a list of those. - return_all_scores (`bool`, *optional*, defaults to `False`): - Whether to return scores for all labels. + top_k (`int`, *optional*, defaults to `1`): + How many results to return. function_to_apply (`str`, *optional*, defaults to `"default"`): The function to apply to the model outputs in order to retrieve the scores. Accepts four different values: @@ -121,10 +133,10 @@ class TextClassificationPipeline(Pipeline): - **label** (`str`) -- The label predicted. - **score** (`float`) -- The corresponding probability. - If `self.return_all_scores=True`, one such dictionary is returned per label. + If `top_k` is used, one such dictionary is returned per label. """ result = super().__call__(*args, **kwargs) - if isinstance(args[0], str): + if isinstance(args[0], str) and isinstance(result, dict): # This pipeline is odd, and return a list when single item is run return [result] else: @@ -150,7 +162,10 @@ class TextClassificationPipeline(Pipeline): def _forward(self, model_inputs): return self.model(**model_inputs) - def postprocess(self, model_outputs, function_to_apply=None, return_all_scores=False): + def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True): + # `_legacy` is used to determine if we're running the naked pipeline and in backward + # compatibility mode, or if running the pipeline with `pipeline(..., top_k=1)` we're running + # the more natural result containing the list. # Default value before `set_parameters` if function_to_apply is None: if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1: @@ -174,7 +189,14 @@ class TextClassificationPipeline(Pipeline): else: raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}") - if return_all_scores: - return [{"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)] - else: + if top_k == 1 and _legacy: return {"label": self.model.config.id2label[scores.argmax().item()], "score": scores.max().item()} + + dict_scores = [ + {"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores) + ] + if not _legacy: + dict_scores.sort(key=lambda x: x["score"], reverse=True) + if top_k is not None: + dict_scores = dict_scores[:top_k] + return dict_scores diff --git a/tests/pipelines/test_pipelines_text_classification.py b/tests/pipelines/test_pipelines_text_classification.py index 2e62232957..9251b29922 100644 --- a/tests/pipelines/test_pipelines_text_classification.py +++ b/tests/pipelines/test_pipelines_text_classification.py @@ -39,6 +39,27 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC outputs = text_classifier("This is great !") self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + outputs = text_classifier("This is great !", top_k=2) + self.assertEqual( + nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}] + ) + + outputs = text_classifier(["This is great !", "This is bad"], top_k=2) + self.assertEqual( + nested_simplify(outputs), + [ + [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}], + [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}], + ], + ) + + outputs = text_classifier("This is great !", top_k=1) + self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + + # Legacy behavior + outputs = text_classifier("This is great !", return_all_scores=False) + self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + @require_torch def test_accepts_torch_device(self): import torch @@ -108,6 +129,15 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC self.assertTrue(outputs[0]["label"] in model.config.id2label.values()) self.assertTrue(outputs[1]["label"] in model.config.id2label.values()) + # Forcing to get all results with `top_k=None` + # This is NOT the legacy format + outputs = text_classifier(valid_inputs, top_k=None) + N = len(model.config.id2label.values()) + self.assertEqual( + nested_simplify(outputs), + [[{"label": ANY(str), "score": ANY(float)}] * N, [{"label": ANY(str), "score": ANY(float)}] * N], + ) + valid_inputs = {"text": "HuggingFace is in ", "text_pair": "Paris is in France"} outputs = text_classifier(valid_inputs) self.assertEqual(