Fixing a regression with `return_all_scores` introduced in #17606 - The legacy test actually tested `return_all_scores=False` (the actual default) instead of `return_all_scores=True` (the actual weird case). This commit adds the correct legacy test and fixes it. Tmp legacy tests. Actually fix the regression (also contains lists) Less diffed code.
This commit is contained in:
@@ -60,6 +60,29 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
|
||||
outputs = text_classifier("This is great !", return_all_scores=False)
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||
|
||||
outputs = text_classifier("This is great !", return_all_scores=True)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs), [[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]]
|
||||
)
|
||||
|
||||
outputs = text_classifier(["This is great !", "Something else"], return_all_scores=True)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
[
|
||||
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
|
||||
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
|
||||
],
|
||||
)
|
||||
|
||||
outputs = text_classifier(["This is great !", "Something else"], return_all_scores=False)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
[
|
||||
{"label": "LABEL_0", "score": 0.504},
|
||||
{"label": "LABEL_0", "score": 0.504},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_accepts_torch_device(self):
|
||||
import torch
|
||||
|
||||
Reference in New Issue
Block a user