Adding top_k argument to text-classification pipeline. (#17606)

* Adding `top_k` and `sort` arguments to `text-classification` pipeline.

- Deprecate `return_all_scores` as `top_k` is more uniform with other
  pipelines, and a superset of what `return_all_scores` can do.
  BC is maintained though.
  `return_all_scores=True` -> `top_k=None`
  `return_all_scores=False` -> `top_k=1`

- Using `top_k` will imply sorting the results, but using no argument
  will keep the results unsorted for backward compatibility.

* Remove `sort`.

* Fixing the test.

* Remove bad doc.
This commit is contained in:
Nicolas Patry
2022-06-09 18:33:10 +02:00
committed by GitHub
parent 29080643eb
commit 2351729f7d
2 changed files with 63 additions and 11 deletions

View File

@@ -39,6 +39,27 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
outputs = text_classifier("This is great !", top_k=2)
self.assertEqual(
nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]
)
outputs = text_classifier(["This is great !", "This is bad"], top_k=2)
self.assertEqual(
nested_simplify(outputs),
[
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
],
)
outputs = text_classifier("This is great !", top_k=1)
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
# Legacy behavior
outputs = text_classifier("This is great !", return_all_scores=False)
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_torch
def test_accepts_torch_device(self):
import torch
@@ -108,6 +129,15 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
self.assertTrue(outputs[1]["label"] in model.config.id2label.values())
# Forcing to get all results with `top_k=None`
# This is NOT the legacy format
outputs = text_classifier(valid_inputs, top_k=None)
N = len(model.config.id2label.values())
self.assertEqual(
nested_simplify(outputs),
[[{"label": ANY(str), "score": ANY(float)}] * N, [{"label": ANY(str), "score": ANY(float)}] * N],
)
valid_inputs = {"text": "HuggingFace is in ", "text_pair": "Paris is in France"}
outputs = text_classifier(valid_inputs)
self.assertEqual(