Adding top_k argument to text-classification pipeline. (#17606)
* Adding `top_k` and `sort` arguments to `text-classification` pipeline. - Deprecate `return_all_scores` as `top_k` is more uniform with other pipelines, and a superset of what `return_all_scores` can do. BC is maintained though. `return_all_scores=True` -> `top_k=None` `return_all_scores=False` -> `top_k=1` - Using `top_k` will imply sorting the results, but using no argument will keep the results unsorted for backward compatibility. * Remove `sort`. * Fixing the test. * Remove bad doc.
This commit is contained in:
@@ -39,6 +39,27 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
|
||||
outputs = text_classifier("This is great !")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||
|
||||
outputs = text_classifier("This is great !", top_k=2)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]
|
||||
)
|
||||
|
||||
outputs = text_classifier(["This is great !", "This is bad"], top_k=2)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
[
|
||||
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
|
||||
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
|
||||
],
|
||||
)
|
||||
|
||||
outputs = text_classifier("This is great !", top_k=1)
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||
|
||||
# Legacy behavior
|
||||
outputs = text_classifier("This is great !", return_all_scores=False)
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||
|
||||
@require_torch
|
||||
def test_accepts_torch_device(self):
|
||||
import torch
|
||||
@@ -108,6 +129,15 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
|
||||
self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
|
||||
self.assertTrue(outputs[1]["label"] in model.config.id2label.values())
|
||||
|
||||
# Forcing to get all results with `top_k=None`
|
||||
# This is NOT the legacy format
|
||||
outputs = text_classifier(valid_inputs, top_k=None)
|
||||
N = len(model.config.id2label.values())
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
[[{"label": ANY(str), "score": ANY(float)}] * N, [{"label": ANY(str), "score": ANY(float)}] * N],
|
||||
)
|
||||
|
||||
valid_inputs = {"text": "HuggingFace is in ", "text_pair": "Paris is in France"}
|
||||
outputs = text_classifier(valid_inputs)
|
||||
self.assertEqual(
|
||||
|
||||
Reference in New Issue
Block a user