Fixing a regression with return_all_scores introduced in #17606 (#17906)

Fixing a regression with `return_all_scores` introduced in #17606

- The legacy test actually tested `return_all_scores=False` (the actual
  default) instead of `return_all_scores=True` (the actual weird case).

This commit adds the correct legacy test and fixes it.

Tmp legacy tests.

Actually fix the regression (also contains lists)

Less diffed code.
This commit is contained in:
Nicolas Patry
2022-06-28 17:24:45 -04:00
committed by GitHub
parent 5f1e67a566
commit 776855c752
2 changed files with 26 additions and 1 deletions

View File

@@ -136,7 +136,9 @@ class TextClassificationPipeline(Pipeline):
If `top_k` is used, one such dictionary is returned per label. If `top_k` is used, one such dictionary is returned per label.
""" """
result = super().__call__(*args, **kwargs) result = super().__call__(*args, **kwargs)
if isinstance(args[0], str) and isinstance(result, dict): # TODO try and retrieve it in a nicer way from _sanitize_parameters.
_legacy = "top_k" not in kwargs
if isinstance(args[0], str) and _legacy:
# This pipeline is odd, and return a list when single item is run # This pipeline is odd, and return a list when single item is run
return [result] return [result]
else: else:

View File

@@ -60,6 +60,29 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
outputs = text_classifier("This is great !", return_all_scores=False) outputs = text_classifier("This is great !", return_all_scores=False)
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
outputs = text_classifier("This is great !", return_all_scores=True)
self.assertEqual(
nested_simplify(outputs), [[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]]
)
outputs = text_classifier(["This is great !", "Something else"], return_all_scores=True)
self.assertEqual(
nested_simplify(outputs),
[
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
],
)
outputs = text_classifier(["This is great !", "Something else"], return_all_scores=False)
self.assertEqual(
nested_simplify(outputs),
[
{"label": "LABEL_0", "score": 0.504},
{"label": "LABEL_0", "score": 0.504},
],
)
@require_torch @require_torch
def test_accepts_torch_device(self): def test_accepts_torch_device(self):
import torch import torch