From 471cf1de633b935bb01d2c9d02ae1bdb50f86876 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B9=9B=E9=9C=B2=E5=85=88=E7=94=9F?= Date: Thu, 27 Mar 2025 22:35:33 +0800 Subject: [PATCH] clean pipeline question_answering. (#36986) Signed-off-by: zhanluxianshen --- src/transformers/pipelines/question_answering.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index eee05b9f2c..85d2e7c15d 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -340,7 +340,6 @@ class QuestionAnsweringPipeline(ChunkPipeline): if max_answer_len is not None: if max_answer_len < 1: raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len}") - if max_answer_len is not None: postprocess_params["max_answer_len"] = max_answer_len if handle_impossible_answer is not None: postprocess_params["handle_impossible_answer"] = handle_impossible_answer @@ -542,11 +541,9 @@ class QuestionAnsweringPipeline(ChunkPipeline): for output in model_outputs: if self.framework == "pt" and output["start"].dtype == torch.bfloat16: start_ = output["start"].to(torch.float32) - else: - start_ = output["start"] - if self.framework == "pt" and output["start"].dtype == torch.bfloat16: end_ = output["end"].to(torch.float32) else: + start_ = output["start"] end_ = output["end"] example = output["example"] p_mask = output["p_mask"]