Handle long answer needs to be updated. (#14279)
`start_` and `end_` tensors now contain a batch_size at this point.
This commit is contained in:
@@ -412,7 +412,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))
|
end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))
|
||||||
|
|
||||||
if handle_impossible_answer:
|
if handle_impossible_answer:
|
||||||
min_null_score = min(min_null_score, (start_[0] * end_[0]).item())
|
min_null_score = min(min_null_score, (start_[0, 0] * end_[0, 0]).item())
|
||||||
|
|
||||||
# Mask CLS
|
# Mask CLS
|
||||||
start_[0, 0] = end_[0, 0] = 0.0
|
start_[0, 0] = end_[0, 0] = 0.0
|
||||||
|
|||||||
@@ -50,6 +50,12 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
|||||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
||||||
)
|
)
|
||||||
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||||
|
outputs = question_answerer(
|
||||||
|
question="Where was HuggingFace founded ?",
|
||||||
|
context="HuggingFace was founded in Paris.",
|
||||||
|
handle_impossible_answer=True,
|
||||||
|
)
|
||||||
|
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||||
|
|
||||||
outputs = question_answerer(
|
outputs = question_answerer(
|
||||||
question=["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
|
question=["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
|
||||||
|
|||||||
Reference in New Issue
Block a user