Fixing a regression with `return_all_scores` introduced in #17606 - The legacy test actually tested `return_all_scores=False` (the actual default) instead of `return_all_scores=True` (the actual weird case). This commit adds the correct legacy test and fixes it. Tmp legacy tests. Actually fix the regression (also contains lists) Less diffed code.
This commit is contained in:
@@ -136,7 +136,9 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
If `top_k` is used, one such dictionary is returned per label.
|
If `top_k` is used, one such dictionary is returned per label.
|
||||||
"""
|
"""
|
||||||
result = super().__call__(*args, **kwargs)
|
result = super().__call__(*args, **kwargs)
|
||||||
if isinstance(args[0], str) and isinstance(result, dict):
|
# TODO try and retrieve it in a nicer way from _sanitize_parameters.
|
||||||
|
_legacy = "top_k" not in kwargs
|
||||||
|
if isinstance(args[0], str) and _legacy:
|
||||||
# This pipeline is odd, and return a list when single item is run
|
# This pipeline is odd, and return a list when single item is run
|
||||||
return [result]
|
return [result]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -60,6 +60,29 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
|
|||||||
outputs = text_classifier("This is great !", return_all_scores=False)
|
outputs = text_classifier("This is great !", return_all_scores=False)
|
||||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||||
|
|
||||||
|
outputs = text_classifier("This is great !", return_all_scores=True)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs), [[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]]
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = text_classifier(["This is great !", "Something else"], return_all_scores=True)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
[
|
||||||
|
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
|
||||||
|
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = text_classifier(["This is great !", "Something else"], return_all_scores=False)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
[
|
||||||
|
{"label": "LABEL_0", "score": 0.504},
|
||||||
|
{"label": "LABEL_0", "score": 0.504},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_accepts_torch_device(self):
|
def test_accepts_torch_device(self):
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
Reference in New Issue
Block a user