clean pipeline question_answering. (#36986)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
This commit is contained in:
@@ -340,7 +340,6 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
if max_answer_len is not None:
|
if max_answer_len is not None:
|
||||||
if max_answer_len < 1:
|
if max_answer_len < 1:
|
||||||
raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len}")
|
raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len}")
|
||||||
if max_answer_len is not None:
|
|
||||||
postprocess_params["max_answer_len"] = max_answer_len
|
postprocess_params["max_answer_len"] = max_answer_len
|
||||||
if handle_impossible_answer is not None:
|
if handle_impossible_answer is not None:
|
||||||
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
|
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
|
||||||
@@ -542,11 +541,9 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
for output in model_outputs:
|
for output in model_outputs:
|
||||||
if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
|
if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
|
||||||
start_ = output["start"].to(torch.float32)
|
start_ = output["start"].to(torch.float32)
|
||||||
else:
|
|
||||||
start_ = output["start"]
|
|
||||||
if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
|
|
||||||
end_ = output["end"].to(torch.float32)
|
end_ = output["end"].to(torch.float32)
|
||||||
else:
|
else:
|
||||||
|
start_ = output["start"]
|
||||||
end_ = output["end"]
|
end_ = output["end"]
|
||||||
example = output["example"]
|
example = output["example"]
|
||||||
p_mask = output["p_mask"]
|
p_mask = output["p_mask"]
|
||||||
|
|||||||
Reference in New Issue
Block a user