Sync QuestionAnsweringPipeline (#34039)

* Sync QuestionAnsweringPipeline

* typo fixes

* Update deprecation warnings
This commit is contained in:
Matt
2024-10-10 13:38:14 +01:00
committed by GitHub
parent c9afee5392
commit f8a260e2a4
3 changed files with 24 additions and 11 deletions

View File

@@ -14,6 +14,8 @@
import unittest
from huggingface_hub import QuestionAnsweringOutputElement
from transformers import (
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -23,6 +25,7 @@ from transformers import (
from transformers.data.processors.squad import SquadExample
from transformers.pipelines import QuestionAnsweringArgumentHandler, pipeline
from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
is_pipeline_test,
nested_simplify,
require_tf,
@@ -132,6 +135,8 @@ class QAPipelineTests(unittest.TestCase):
self.assertEqual(
outputs, [{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)} for i in range(20)]
)
for single_output in outputs:
compare_pipeline_output_to_hub_spec(single_output, QuestionAnsweringOutputElement)
# Very long context require multiple features
outputs = question_answerer(

View File

@@ -33,6 +33,7 @@ from huggingface_hub import (
ImageSegmentationInput,
ImageToTextInput,
ObjectDetectionInput,
QuestionAnsweringInput,
ZeroShotImageClassificationInput,
)
@@ -45,6 +46,7 @@ from transformers.pipelines import (
ImageSegmentationPipeline,
ImageToTextPipeline,
ObjectDetectionPipeline,
QuestionAnsweringPipeline,
ZeroShotImageClassificationPipeline,
)
from transformers.testing_utils import (
@@ -129,6 +131,7 @@ task_to_pipeline_and_spec_mapping = {
"image-segmentation": (ImageSegmentationPipeline, ImageSegmentationInput),
"image-to-text": (ImageToTextPipeline, ImageToTextInput),
"object-detection": (ObjectDetectionPipeline, ObjectDetectionInput),
"question-answering": (QuestionAnsweringPipeline, QuestionAnsweringInput),
"zero-shot-image-classification": (ZeroShotImageClassificationPipeline, ZeroShotImageClassificationInput),
}