clean pipeline question_answering. (#36986)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
This commit is contained in:
@@ -340,7 +340,6 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
||||
if max_answer_len is not None:
|
||||
if max_answer_len < 1:
|
||||
raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len}")
|
||||
if max_answer_len is not None:
|
||||
postprocess_params["max_answer_len"] = max_answer_len
|
||||
if handle_impossible_answer is not None:
|
||||
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
|
||||
@@ -542,11 +541,9 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
||||
for output in model_outputs:
|
||||
if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
|
||||
start_ = output["start"].to(torch.float32)
|
||||
else:
|
||||
start_ = output["start"]
|
||||
if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
|
||||
end_ = output["end"].to(torch.float32)
|
||||
else:
|
||||
start_ = output["start"]
|
||||
end_ = output["end"]
|
||||
example = output["example"]
|
||||
p_mask = output["p_mask"]
|
||||
|
||||
Reference in New Issue
Block a user