From 1cdd2ad2afb73f6af185aafecb7dd7941a90c4d1 Mon Sep 17 00:00:00 2001 From: Zhiyu Lin Date: Sat, 2 May 2020 11:20:30 -0400 Subject: [PATCH] Fix #2941 (#4109) * Fix of issue #2941 Reshaped score array to avoid `numpy` ValueError. * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py Co-authored-by: Julien Chaumond --- src/transformers/pipelines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 5871890b9f..582d3fffda 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -656,8 +656,8 @@ class TextClassificationPipeline(Pipeline): def __call__(self, *args, **kwargs): outputs = super().__call__(*args, **kwargs) - scores = np.exp(outputs) / np.exp(outputs).sum(-1) - return [{"label": self.model.config.id2label[item.argmax()], "score": item.max()} for item in scores] + scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True) + return [{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores] class FillMaskPipeline(Pipeline):