From f8a260e2a44fbc707878277cb8cb5e53619f8b74 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 10 Oct 2024 13:38:14 +0100 Subject: [PATCH] Sync QuestionAnsweringPipeline (#34039) * Sync QuestionAnsweringPipeline * typo fixes * Update deprecation warnings --- .../pipelines/question_answering.py | 27 +++++++++++-------- .../test_pipelines_question_answering.py | 5 ++++ tests/test_pipeline_mixin.py | 3 +++ 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index 4ac5d252b1..6039e5ad1e 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -183,8 +183,16 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler): # Generic compatibility with sklearn and Keras # Batched data elif "X" in kwargs: + warnings.warn( + "Passing the `X` argument to the pipeline is deprecated and will be removed in v5. Inputs should be passed using the `question` and `context` keyword arguments instead.", + FutureWarning, + ) inputs = kwargs["X"] elif "data" in kwargs: + warnings.warn( + "Passing the `data` argument to the pipeline is deprecated and will be removed in v5. Inputs should be passed using the `question` and `context` keyword arguments instead.", + FutureWarning, + ) inputs = kwargs["data"] elif "question" in kwargs and "context" in kwargs: if isinstance(kwargs["question"], list) and isinstance(kwargs["context"], str): @@ -345,22 +353,14 @@ class QuestionAnsweringPipeline(ChunkPipeline): Answer the question(s) given as inputs by using the context(s). Args: - args ([`SquadExample`] or a list of [`SquadExample`]): - One or several [`SquadExample`] containing the question and context. - X ([`SquadExample`] or a list of [`SquadExample`], *optional*): - One or several [`SquadExample`] containing the question and context (will be treated the same way as if - passed as the first positional argument). - data ([`SquadExample`] or a list of [`SquadExample`], *optional*): - One or several [`SquadExample`] containing the question and context (will be treated the same way as if - passed as the first positional argument). question (`str` or `List[str]`): One or several question(s) (must be used in conjunction with the `context` argument). context (`str` or `List[str]`): One or several context(s) associated with the question(s) (must be used in conjunction with the `question` argument). - topk (`int`, *optional*, defaults to 1): + top_k (`int`, *optional*, defaults to 1): The number of answers to return (will be chosen by order of likelihood). Note that we return less than - topk answers if there are not enough options available within the context. + top_k answers if there are not enough options available within the context. doc_stride (`int`, *optional*, defaults to 128): If the context is too long to fit with the question for the model, it will be split in several chunks with some overlap. This argument controls the size of that overlap. @@ -374,7 +374,7 @@ class QuestionAnsweringPipeline(ChunkPipeline): handle_impossible_answer (`bool`, *optional*, defaults to `False`): Whether or not we accept impossible as an answer. align_to_words (`bool`, *optional*, defaults to `True`): - Attempts to align the answer to real words. Improves quality on space separated langages. Might hurt on + Attempts to align the answer to real words. Improves quality on space separated languages. Might hurt on non-space-separated languages (like Japanese or Chinese) Return: @@ -387,6 +387,11 @@ class QuestionAnsweringPipeline(ChunkPipeline): """ # Convert inputs to features + if args: + warnings.warn( + "Passing a list of SQuAD examples to the pipeline is deprecated and will be removed in v5. Inputs should be passed using the `question` and `context` keyword arguments instead.", + FutureWarning, + ) examples = self._args_parser(*args, **kwargs) if isinstance(examples, (list, tuple)) and len(examples) == 1: diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py index d051a1435b..d06f88d1f0 100644 --- a/tests/pipelines/test_pipelines_question_answering.py +++ b/tests/pipelines/test_pipelines_question_answering.py @@ -14,6 +14,8 @@ import unittest +from huggingface_hub import QuestionAnsweringOutputElement + from transformers import ( MODEL_FOR_QUESTION_ANSWERING_MAPPING, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, @@ -23,6 +25,7 @@ from transformers import ( from transformers.data.processors.squad import SquadExample from transformers.pipelines import QuestionAnsweringArgumentHandler, pipeline from transformers.testing_utils import ( + compare_pipeline_output_to_hub_spec, is_pipeline_test, nested_simplify, require_tf, @@ -132,6 +135,8 @@ class QAPipelineTests(unittest.TestCase): self.assertEqual( outputs, [{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)} for i in range(20)] ) + for single_output in outputs: + compare_pipeline_output_to_hub_spec(single_output, QuestionAnsweringOutputElement) # Very long context require multiple features outputs = question_answerer( diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py index 19f503dc6e..cae285f5f1 100644 --- a/tests/test_pipeline_mixin.py +++ b/tests/test_pipeline_mixin.py @@ -33,6 +33,7 @@ from huggingface_hub import ( ImageSegmentationInput, ImageToTextInput, ObjectDetectionInput, + QuestionAnsweringInput, ZeroShotImageClassificationInput, ) @@ -45,6 +46,7 @@ from transformers.pipelines import ( ImageSegmentationPipeline, ImageToTextPipeline, ObjectDetectionPipeline, + QuestionAnsweringPipeline, ZeroShotImageClassificationPipeline, ) from transformers.testing_utils import ( @@ -129,6 +131,7 @@ task_to_pipeline_and_spec_mapping = { "image-segmentation": (ImageSegmentationPipeline, ImageSegmentationInput), "image-to-text": (ImageToTextPipeline, ImageToTextInput), "object-detection": (ObjectDetectionPipeline, ObjectDetectionInput), + "question-answering": (QuestionAnsweringPipeline, QuestionAnsweringInput), "zero-shot-image-classification": (ZeroShotImageClassificationPipeline, ZeroShotImageClassificationInput), }