Fix the behaviour of DefaultArgumentHandler (removing it). (#8180)
* Some work to fix the behaviour of DefaultArgumentHandler by removing it. * Fixing specific pipelines argument checking.
This commit is contained in:
@@ -23,9 +23,8 @@ import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from itertools import chain
|
||||
from os.path import abspath, exists
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
@@ -185,57 +184,6 @@ class ArgumentHandler(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DefaultArgumentHandler(ArgumentHandler):
|
||||
"""
|
||||
Default argument parser handling parameters for each :class:`~transformers.pipelines.Pipeline`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def handle_kwargs(kwargs: Dict) -> List:
|
||||
if len(kwargs) == 1:
|
||||
output = list(kwargs.values())
|
||||
else:
|
||||
output = list(chain(kwargs.values()))
|
||||
|
||||
return DefaultArgumentHandler.handle_args(output)
|
||||
|
||||
@staticmethod
|
||||
def handle_args(args: Sequence[Any]) -> List[str]:
|
||||
|
||||
# Only one argument, let's do case by case
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], str):
|
||||
return [args[0]]
|
||||
elif not isinstance(args[0], list):
|
||||
return list(args)
|
||||
else:
|
||||
return args[0]
|
||||
|
||||
# Multiple arguments (x1, x2, ...)
|
||||
elif len(args) > 1:
|
||||
if all([isinstance(arg, str) for arg in args]):
|
||||
return list(args)
|
||||
|
||||
# If not instance of list, then it should instance of iterable
|
||||
elif isinstance(args, Iterable):
|
||||
return list(chain.from_iterable(chain(args)))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(args))
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if len(kwargs) > 0 and len(args) > 0:
|
||||
raise ValueError("Pipeline cannot handle mixed args and kwargs")
|
||||
|
||||
if len(kwargs) > 0:
|
||||
return DefaultArgumentHandler.handle_kwargs(kwargs)
|
||||
else:
|
||||
return DefaultArgumentHandler.handle_args(args)
|
||||
|
||||
|
||||
class PipelineDataFormat:
|
||||
"""
|
||||
Base class for all the pipeline supported data format both for reading and writing. Supported data formats
|
||||
@@ -574,7 +522,6 @@ class Pipeline(_ScikitCompat):
|
||||
self.framework = framework
|
||||
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
|
||||
self.binary_output = binary_output
|
||||
self._args_parser = args_parser or DefaultArgumentHandler()
|
||||
|
||||
# Special handling
|
||||
if self.framework == "pt" and self.device.type == "cuda":
|
||||
@@ -669,12 +616,11 @@ class Pipeline(_ScikitCompat):
|
||||
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
|
||||
)
|
||||
|
||||
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
|
||||
def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
|
||||
"""
|
||||
Parse arguments and tokenize
|
||||
"""
|
||||
# Parse arguments
|
||||
inputs = self._args_parser(*args, **kwargs)
|
||||
inputs = self.tokenizer(
|
||||
inputs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
@@ -836,7 +782,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
|
||||
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
|
||||
|
||||
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
|
||||
def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
|
||||
"""
|
||||
Parse arguments and tokenize
|
||||
"""
|
||||
@@ -845,7 +791,6 @@ class TextGenerationPipeline(Pipeline):
|
||||
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
|
||||
else:
|
||||
tokenizer_kwargs = {}
|
||||
inputs = self._args_parser(*args, **kwargs)
|
||||
inputs = self.tokenizer(
|
||||
inputs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
@@ -858,7 +803,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*args,
|
||||
text_inputs,
|
||||
return_tensors=False,
|
||||
return_text=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
@@ -890,7 +835,6 @@ class TextGenerationPipeline(Pipeline):
|
||||
- **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
|
||||
-- The token ids of the generated text.
|
||||
"""
|
||||
text_inputs = self._args_parser(*args)
|
||||
|
||||
results = []
|
||||
for prompt_text in text_inputs:
|
||||
@@ -1094,7 +1038,8 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
"""
|
||||
|
||||
def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
|
||||
super().__init__(*args, args_parser=args_parser, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
self._args_parser = args_parser
|
||||
if self.entailment_id == -1:
|
||||
logger.warning(
|
||||
"Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
|
||||
@@ -1108,13 +1053,15 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
return ind
|
||||
return -1
|
||||
|
||||
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
|
||||
def _parse_and_tokenize(
|
||||
self, sequences, candidal_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
|
||||
):
|
||||
"""
|
||||
Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
|
||||
"""
|
||||
inputs = self._args_parser(*args, **kwargs)
|
||||
sequence_pairs = self._args_parser(sequences, candidal_labels, hypothesis_template)
|
||||
inputs = self.tokenizer(
|
||||
inputs,
|
||||
sequence_pairs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=self.framework,
|
||||
padding=padding,
|
||||
@@ -1123,7 +1070,13 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
|
||||
return inputs
|
||||
|
||||
def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
|
||||
def __call__(
|
||||
self,
|
||||
sequences: Union[str, List[str]],
|
||||
candidate_labels,
|
||||
hypothesis_template="This example is {}.",
|
||||
multi_class=False,
|
||||
):
|
||||
"""
|
||||
Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline`
|
||||
documentation for more information.
|
||||
@@ -1154,8 +1107,11 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
- **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
|
||||
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
|
||||
"""
|
||||
if sequences and isinstance(sequences, str):
|
||||
sequences = [sequences]
|
||||
|
||||
outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
|
||||
num_sequences = 1 if isinstance(sequences, str) else len(sequences)
|
||||
num_sequences = len(sequences)
|
||||
candidate_labels = self._args_parser._parse_labels(candidate_labels)
|
||||
reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))
|
||||
|
||||
@@ -1425,12 +1381,12 @@ class TokenClassificationPipeline(Pipeline):
|
||||
self.ignore_labels = ignore_labels
|
||||
self.grouped_entities = grouped_entities
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, inputs: Union[str, List[str]], **kwargs):
|
||||
"""
|
||||
Classify each token of the text(s) given as inputs.
|
||||
|
||||
Args:
|
||||
args (:obj:`str` or :obj:`List[str]`):
|
||||
inputs (:obj:`str` or :obj:`List[str]`):
|
||||
One or several texts (or one list of texts) for token classification.
|
||||
|
||||
Return:
|
||||
@@ -1444,7 +1400,8 @@ class TokenClassificationPipeline(Pipeline):
|
||||
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
|
||||
corresponding token in the sentence.
|
||||
"""
|
||||
inputs = self._args_parser(*args, **kwargs)
|
||||
if isinstance(inputs, str):
|
||||
inputs = [inputs]
|
||||
answers = []
|
||||
for sentence in inputs:
|
||||
|
||||
@@ -1659,12 +1616,12 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
tokenizer=tokenizer,
|
||||
modelcard=modelcard,
|
||||
framework=framework,
|
||||
args_parser=QuestionAnsweringArgumentHandler(),
|
||||
device=device,
|
||||
task=task,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._args_parser = QuestionAnsweringArgumentHandler()
|
||||
self.check_model_type(
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
)
|
||||
@@ -2489,12 +2446,11 @@ class ConversationalPipeline(Pipeline):
|
||||
else:
|
||||
return output
|
||||
|
||||
def _parse_and_tokenize(self, *args, **kwargs):
|
||||
def _parse_and_tokenize(self, inputs, **kwargs):
|
||||
"""
|
||||
Parse arguments and tokenize, adding an EOS token at the end of the user input
|
||||
"""
|
||||
# Parse arguments
|
||||
inputs = self._args_parser(*args, **kwargs)
|
||||
inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
|
||||
for input in inputs:
|
||||
input.append(self.tokenizer.eos_token_id)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import unittest
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import is_tf_available, is_torch_available, pipeline
|
||||
from transformers.pipelines import DefaultArgumentHandler, Pipeline
|
||||
|
||||
# from transformers.pipelines import DefaultArgumentHandler, Pipeline
|
||||
from transformers.pipelines import Pipeline
|
||||
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
|
||||
|
||||
|
||||
@@ -200,74 +201,74 @@ class MonoInputPipelineCommonMixin:
|
||||
self.assertRaises(Exception, nlp, self.invalid_inputs)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class DefaultArgumentHandlerTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.handler = DefaultArgumentHandler()
|
||||
|
||||
def test_kwargs_x(self):
|
||||
mono_data = {"X": "This is a sample input"}
|
||||
mono_args = self.handler(**mono_data)
|
||||
|
||||
self.assertTrue(isinstance(mono_args, list))
|
||||
self.assertEqual(len(mono_args), 1)
|
||||
|
||||
multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
|
||||
multi_args = self.handler(**multi_data)
|
||||
|
||||
self.assertTrue(isinstance(multi_args, list))
|
||||
self.assertEqual(len(multi_args), 2)
|
||||
|
||||
def test_kwargs_data(self):
|
||||
mono_data = {"data": "This is a sample input"}
|
||||
mono_args = self.handler(**mono_data)
|
||||
|
||||
self.assertTrue(isinstance(mono_args, list))
|
||||
self.assertEqual(len(mono_args), 1)
|
||||
|
||||
multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
|
||||
multi_args = self.handler(**multi_data)
|
||||
|
||||
self.assertTrue(isinstance(multi_args, list))
|
||||
self.assertEqual(len(multi_args), 2)
|
||||
|
||||
def test_multi_kwargs(self):
|
||||
mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
|
||||
mono_args = self.handler(**mono_data)
|
||||
|
||||
self.assertTrue(isinstance(mono_args, list))
|
||||
self.assertEqual(len(mono_args), 2)
|
||||
|
||||
multi_data = {
|
||||
"data": ["This is a sample input", "This is a second sample input"],
|
||||
"test": ["This is a sample input 2", "This is a second sample input 2"],
|
||||
}
|
||||
multi_args = self.handler(**multi_data)
|
||||
|
||||
self.assertTrue(isinstance(multi_args, list))
|
||||
self.assertEqual(len(multi_args), 4)
|
||||
|
||||
def test_args(self):
|
||||
mono_data = "This is a sample input"
|
||||
mono_args = self.handler(mono_data)
|
||||
|
||||
self.assertTrue(isinstance(mono_args, list))
|
||||
self.assertEqual(len(mono_args), 1)
|
||||
|
||||
mono_data = ["This is a sample input"]
|
||||
mono_args = self.handler(mono_data)
|
||||
|
||||
self.assertTrue(isinstance(mono_args, list))
|
||||
self.assertEqual(len(mono_args), 1)
|
||||
|
||||
multi_data = ["This is a sample input", "This is a second sample input"]
|
||||
multi_args = self.handler(multi_data)
|
||||
|
||||
self.assertTrue(isinstance(multi_args, list))
|
||||
self.assertEqual(len(multi_args), 2)
|
||||
|
||||
multi_data = ["This is a sample input", "This is a second sample input"]
|
||||
multi_args = self.handler(*multi_data)
|
||||
|
||||
self.assertTrue(isinstance(multi_args, list))
|
||||
self.assertEqual(len(multi_args), 2)
|
||||
# @is_pipeline_test
|
||||
# class DefaultArgumentHandlerTestCase(unittest.TestCase):
|
||||
# def setUp(self) -> None:
|
||||
# self.handler = DefaultArgumentHandler()
|
||||
#
|
||||
# def test_kwargs_x(self):
|
||||
# mono_data = {"X": "This is a sample input"}
|
||||
# mono_args = self.handler(**mono_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(mono_args, list))
|
||||
# self.assertEqual(len(mono_args), 1)
|
||||
#
|
||||
# multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
|
||||
# multi_args = self.handler(**multi_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(multi_args, list))
|
||||
# self.assertEqual(len(multi_args), 2)
|
||||
#
|
||||
# def test_kwargs_data(self):
|
||||
# mono_data = {"data": "This is a sample input"}
|
||||
# mono_args = self.handler(**mono_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(mono_args, list))
|
||||
# self.assertEqual(len(mono_args), 1)
|
||||
#
|
||||
# multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
|
||||
# multi_args = self.handler(**multi_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(multi_args, list))
|
||||
# self.assertEqual(len(multi_args), 2)
|
||||
#
|
||||
# def test_multi_kwargs(self):
|
||||
# mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
|
||||
# mono_args = self.handler(**mono_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(mono_args, list))
|
||||
# self.assertEqual(len(mono_args), 2)
|
||||
#
|
||||
# multi_data = {
|
||||
# "data": ["This is a sample input", "This is a second sample input"],
|
||||
# "test": ["This is a sample input 2", "This is a second sample input 2"],
|
||||
# }
|
||||
# multi_args = self.handler(**multi_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(multi_args, list))
|
||||
# self.assertEqual(len(multi_args), 4)
|
||||
#
|
||||
# def test_args(self):
|
||||
# mono_data = "This is a sample input"
|
||||
# mono_args = self.handler(mono_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(mono_args, list))
|
||||
# self.assertEqual(len(mono_args), 1)
|
||||
#
|
||||
# mono_data = ["This is a sample input"]
|
||||
# mono_args = self.handler(mono_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(mono_args, list))
|
||||
# self.assertEqual(len(mono_args), 1)
|
||||
#
|
||||
# multi_data = ["This is a sample input", "This is a second sample input"]
|
||||
# multi_args = self.handler(multi_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(multi_args, list))
|
||||
# self.assertEqual(len(multi_args), 2)
|
||||
#
|
||||
# multi_data = ["This is a sample input", "This is a second sample input"]
|
||||
# multi_args = self.handler(*multi_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(multi_args, list))
|
||||
# self.assertEqual(len(multi_args), 2)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.testing_utils import require_tf, require_torch, slow
|
||||
|
||||
@@ -37,7 +39,7 @@ EXPECTED_FILL_MASK_TARGET_RESULT = [
|
||||
|
||||
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "fill-mask"
|
||||
pipeline_loading_kwargs = {"topk": 2}
|
||||
pipeline_loading_kwargs = {"top_k": 2}
|
||||
small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator
|
||||
large_models = ["distilroberta-base"] # Models tested with the @slow decorator
|
||||
mandatory_keys = {"sequence", "score", "token"}
|
||||
@@ -51,6 +53,28 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
]
|
||||
expected_check_keys = ["sequence"]
|
||||
|
||||
@require_torch
|
||||
def test_torch_topk_deprecation(self):
|
||||
# At pipeline initialization only it was not enabled at pipeline
|
||||
# call site before
|
||||
with pytest.warns(FutureWarning, match=r".*use `top_k`.*"):
|
||||
pipeline(task="fill-mask", model=self.small_models[0], topk=1)
|
||||
|
||||
@require_torch
|
||||
def test_torch_fill_mask(self):
|
||||
valid_inputs = "My name is <mask>"
|
||||
nlp = pipeline(task="fill-mask", model=self.small_models[0])
|
||||
outputs = nlp(valid_inputs)
|
||||
self.assertIsInstance(outputs, list)
|
||||
|
||||
# This passes
|
||||
outputs = nlp(valid_inputs, targets=[" Patrick", " Clara"])
|
||||
self.assertIsInstance(outputs, list)
|
||||
|
||||
# This used to fail with `cannot mix args and kwargs`
|
||||
outputs = nlp(valid_inputs, something=False)
|
||||
self.assertIsInstance(outputs, list)
|
||||
|
||||
@require_torch
|
||||
def test_torch_fill_mask_with_targets(self):
|
||||
valid_inputs = ["My name is <mask>"]
|
||||
@@ -94,7 +118,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
topk=2,
|
||||
top_k=2,
|
||||
)
|
||||
|
||||
mono_result = nlp(valid_inputs[0], targets=valid_targets)
|
||||
|
||||
@@ -17,7 +17,7 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
|
||||
sum = 0.0
|
||||
for score in result["scores"]:
|
||||
sum += score
|
||||
self.assertAlmostEqual(sum, 1.0)
|
||||
self.assertAlmostEqual(sum, 1.0, places=5)
|
||||
|
||||
def _test_entailment_id(self, nlp: Pipeline):
|
||||
config = nlp.model.config
|
||||
|
||||
Reference in New Issue
Block a user