@@ -540,7 +540,13 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
min_null_score = 1000000 # large and positive
|
min_null_score = 1000000 # large and positive
|
||||||
answers = []
|
answers = []
|
||||||
for output in model_outputs:
|
for output in model_outputs:
|
||||||
|
if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
|
||||||
|
start_ = output["start"].to(torch.float32)
|
||||||
|
else:
|
||||||
start_ = output["start"]
|
start_ = output["start"]
|
||||||
|
if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
|
||||||
|
end_ = output["end"].to(torch.float32)
|
||||||
|
else:
|
||||||
end_ = output["end"]
|
end_ = output["end"]
|
||||||
example = output["example"]
|
example = output["example"]
|
||||||
p_mask = output["p_mask"]
|
p_mask = output["p_mask"]
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from transformers.pipelines import QuestionAnsweringArgumentHandler, pipeline
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
compare_pipeline_output_to_hub_spec,
|
compare_pipeline_output_to_hub_spec,
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
|
is_torch_available,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -34,6 +35,10 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
from .test_pipelines_common import ANY
|
from .test_pipelines_common import ANY
|
||||||
|
|
||||||
|
|
||||||
@@ -165,6 +170,34 @@ class QAPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_fp16(self):
|
||||||
|
question_answerer = pipeline(
|
||||||
|
"question-answering",
|
||||||
|
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = question_answerer(
|
||||||
|
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_bf16(self):
|
||||||
|
question_answerer = pipeline(
|
||||||
|
"question-answering",
|
||||||
|
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = question_answerer(
|
||||||
|
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt_iterator(self):
|
def test_small_model_pt_iterator(self):
|
||||||
# https://github.com/huggingface/transformers/issues/18510
|
# https://github.com/huggingface/transformers/issues/18510
|
||||||
|
|||||||
Reference in New Issue
Block a user