Rewritten batch support in pipelines. (#4154)
* Rewritten batch support in pipelines. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Fix imports sorting 🔧 Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Set pad_to_max_length=True by default on Pipeline. * Set pad_to_max_length=False for generation pipelines. Most of generation models doesn't have padding token. * Address @joeddav review comment: Uniformized *args. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Address @joeddav review comment: Uniformized *args (second). Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
@@ -22,8 +22,9 @@ import pickle
|
|||||||
import sys
|
import sys
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from itertools import chain
|
||||||
from os.path import abspath, exists
|
from os.path import abspath, exists
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -96,19 +97,50 @@ class DefaultArgumentHandler(ArgumentHandler):
|
|||||||
Default varargs argument parser handling parameters for each Pipeline
|
Default varargs argument parser handling parameters for each Pipeline
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
@staticmethod
|
||||||
if "X" in kwargs:
|
def handle_kwargs(kwargs: Dict) -> List:
|
||||||
return kwargs["X"]
|
if len(kwargs) == 1:
|
||||||
elif "data" in kwargs:
|
output = list(kwargs.values())
|
||||||
return kwargs["data"]
|
|
||||||
elif len(args) == 1:
|
|
||||||
if isinstance(args[0], list):
|
|
||||||
return args[0]
|
|
||||||
else:
|
else:
|
||||||
|
output = list(chain(kwargs.values()))
|
||||||
|
|
||||||
|
return DefaultArgumentHandler.handle_args(output)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def handle_args(args: Sequence[Any]) -> List[str]:
|
||||||
|
|
||||||
|
# Only one argument, let's do case by case
|
||||||
|
if len(args) == 1:
|
||||||
|
if isinstance(args[0], str):
|
||||||
return [args[0]]
|
return [args[0]]
|
||||||
elif len(args) > 1:
|
elif not isinstance(args[0], list):
|
||||||
return list(args)
|
return list(args)
|
||||||
raise ValueError("Unable to infer the format of the provided data (X=, data=, ...)")
|
else:
|
||||||
|
return args[0]
|
||||||
|
|
||||||
|
# Multiple arguments (x1, x2, ...)
|
||||||
|
elif len(args) > 1:
|
||||||
|
if all([isinstance(arg, str) for arg in args]):
|
||||||
|
return list(args)
|
||||||
|
|
||||||
|
# If not instance of list, then it should instance of iterable
|
||||||
|
elif isinstance(args, Iterable):
|
||||||
|
return list(chain.from_iterable(chain(args)))
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(args))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
if len(kwargs) > 0 and len(args) > 0:
|
||||||
|
raise ValueError("Pipeline cannot handle mixed args and kwargs")
|
||||||
|
|
||||||
|
if len(kwargs) > 0:
|
||||||
|
return DefaultArgumentHandler.handle_kwargs(kwargs)
|
||||||
|
else:
|
||||||
|
return DefaultArgumentHandler.handle_args(args)
|
||||||
|
|
||||||
|
|
||||||
class PipelineDataFormat:
|
class PipelineDataFormat:
|
||||||
@@ -418,20 +450,20 @@ class Pipeline(_ScikitCompat):
|
|||||||
"""
|
"""
|
||||||
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
|
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
|
||||||
|
|
||||||
def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
|
def _parse_and_tokenize(self, *args, pad_to_max_length=True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Parse arguments and tokenize
|
Parse arguments and tokenize
|
||||||
"""
|
"""
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
inputs = self._args_parser(*texts, **kwargs)
|
inputs = self._args_parser(*args, **kwargs)
|
||||||
inputs = self.tokenizer.batch_encode_plus(
|
inputs = self.tokenizer.batch_encode_plus(
|
||||||
inputs, add_special_tokens=True, return_tensors=self.framework, pad_to_max_length=pad_to_max_length,
|
inputs, add_special_tokens=True, return_tensors=self.framework, pad_to_max_length=pad_to_max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def __call__(self, *texts, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
inputs = self._parse_and_tokenize(*texts, **kwargs)
|
inputs = self._parse_and_tokenize(*args, **kwargs)
|
||||||
return self._forward(inputs)
|
return self._forward(inputs)
|
||||||
|
|
||||||
def _forward(self, inputs, return_tensors=False):
|
def _forward(self, inputs, return_tensors=False):
|
||||||
@@ -550,18 +582,18 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
||||||
):
|
):
|
||||||
text_inputs = self._args_parser(*texts)
|
text_inputs = self._args_parser(*args)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for prompt_text in text_inputs:
|
for prompt_text in text_inputs:
|
||||||
# Manage correct placement of the tensors
|
# Manage correct placement of the tensors
|
||||||
with self.device_placement():
|
with self.device_placement():
|
||||||
if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
|
if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
|
||||||
inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text)
|
inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text, pad_to_max_length=False)
|
||||||
else:
|
else:
|
||||||
inputs = self._parse_and_tokenize(prompt_text)
|
inputs = self._parse_and_tokenize(prompt_text, pad_to_max_length=False)
|
||||||
|
|
||||||
# set input_ids to None to allow empty prompt
|
# set input_ids to None to allow empty prompt
|
||||||
if inputs["input_ids"].shape[-1] == 0:
|
if inputs["input_ids"].shape[-1] == 0:
|
||||||
@@ -825,8 +857,8 @@ class NerPipeline(Pipeline):
|
|||||||
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
|
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
|
||||||
self.ignore_labels = ignore_labels
|
self.ignore_labels = ignore_labels
|
||||||
|
|
||||||
def __call__(self, *texts, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
inputs = self._args_parser(*texts, **kwargs)
|
inputs = self._args_parser(*args, **kwargs)
|
||||||
answers = []
|
answers = []
|
||||||
for sentence in inputs:
|
for sentence in inputs:
|
||||||
|
|
||||||
@@ -1016,7 +1048,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
else:
|
else:
|
||||||
return SquadExample(None, question, context, None, None, None)
|
return SquadExample(None, question, context, None, None, None)
|
||||||
|
|
||||||
def __call__(self, *texts, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
We support multiple use-cases, the following are exclusive:
|
We support multiple use-cases, the following are exclusive:
|
||||||
@@ -1046,7 +1078,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"]))
|
raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"]))
|
||||||
|
|
||||||
# Convert inputs to features
|
# Convert inputs to features
|
||||||
examples = self._args_parser(*texts, **kwargs)
|
examples = self._args_parser(*args, **kwargs)
|
||||||
features_list = [
|
features_list = [
|
||||||
squad_convert_examples_to_features(
|
squad_convert_examples_to_features(
|
||||||
[example],
|
[example],
|
||||||
@@ -1383,11 +1415,11 @@ class TranslationPipeline(Pipeline):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
*texts: (list of strings) texts to be translated
|
*args: (list of strings) texts to be translated
|
||||||
return_text: (bool, default=True) whether to add a decoded "translation_text" to each result
|
return_text: (bool, default=True) whether to add a decoded "translation_text" to each result
|
||||||
return_tensors: (bool, default=False) whether to return the raw "translation_token_ids" to each result
|
return_tensors: (bool, default=False) whether to return the raw "translation_token_ids" to each result
|
||||||
|
|
||||||
@@ -1402,25 +1434,25 @@ class TranslationPipeline(Pipeline):
|
|||||||
|
|
||||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||||
|
|
||||||
if isinstance(texts[0], list):
|
if isinstance(args[0], list):
|
||||||
assert (
|
assert (
|
||||||
self.tokenizer.pad_token_id is not None
|
self.tokenizer.pad_token_id is not None
|
||||||
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
||||||
texts = ([prefix + text for text in texts[0]],)
|
args = ([prefix + text for text in args[0]],)
|
||||||
pad_to_max_length = True
|
pad_to_max_length = True
|
||||||
|
|
||||||
elif isinstance(texts[0], str):
|
elif isinstance(args[0], str):
|
||||||
texts = (prefix + texts[0],)
|
args = (prefix + args[0],)
|
||||||
pad_to_max_length = False
|
pad_to_max_length = False
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
" `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
|
" `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
|
||||||
texts[0]
|
args[0]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.device_placement():
|
with self.device_placement():
|
||||||
inputs = self._parse_and_tokenize(*texts, pad_to_max_length=pad_to_max_length)
|
inputs = self._parse_and_tokenize(*args, pad_to_max_length=pad_to_max_length)
|
||||||
|
|
||||||
if self.framework == "pt":
|
if self.framework == "pt":
|
||||||
inputs = self.ensure_tensor_on_device(**inputs)
|
inputs = self.ensure_tensor_on_device(**inputs)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import unittest
|
|||||||
from typing import Iterable, List, Optional
|
from typing import Iterable, List, Optional
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
from transformers.pipelines import Pipeline
|
from transformers.pipelines import DefaultArgumentHandler, Pipeline
|
||||||
|
|
||||||
from .utils import require_tf, require_torch, slow
|
from .utils import require_tf, require_torch, slow
|
||||||
|
|
||||||
@@ -86,6 +86,78 @@ TRANSLATION_FINETUNED_MODELS = {
|
|||||||
TF_TRANSLATION_FINETUNED_MODELS = {("patrickvonplaten/t5-tiny-random", "t5-small", "translation_en_to_fr")}
|
TF_TRANSLATION_FINETUNED_MODELS = {("patrickvonplaten/t5-tiny-random", "t5-small", "translation_en_to_fr")}
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultArgumentHandlerTestCase(unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.handler = DefaultArgumentHandler()
|
||||||
|
|
||||||
|
def test_kwargs_x(self):
|
||||||
|
mono_data = {"X": "This is a sample input"}
|
||||||
|
mono_args = self.handler(**mono_data)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(mono_args, list))
|
||||||
|
self.assertEqual(len(mono_args), 1)
|
||||||
|
|
||||||
|
multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
|
||||||
|
multi_args = self.handler(**multi_data)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(multi_args, list))
|
||||||
|
self.assertEqual(len(multi_args), 2)
|
||||||
|
|
||||||
|
def test_kwargs_data(self):
|
||||||
|
mono_data = {"data": "This is a sample input"}
|
||||||
|
mono_args = self.handler(**mono_data)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(mono_args, list))
|
||||||
|
self.assertEqual(len(mono_args), 1)
|
||||||
|
|
||||||
|
multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
|
||||||
|
multi_args = self.handler(**multi_data)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(multi_args, list))
|
||||||
|
self.assertEqual(len(multi_args), 2)
|
||||||
|
|
||||||
|
def test_multi_kwargs(self):
|
||||||
|
mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
|
||||||
|
mono_args = self.handler(**mono_data)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(mono_args, list))
|
||||||
|
self.assertEqual(len(mono_args), 2)
|
||||||
|
|
||||||
|
multi_data = {
|
||||||
|
"data": ["This is a sample input", "This is a second sample input"],
|
||||||
|
"test": ["This is a sample input 2", "This is a second sample input 2"],
|
||||||
|
}
|
||||||
|
multi_args = self.handler(**multi_data)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(multi_args, list))
|
||||||
|
self.assertEqual(len(multi_args), 4)
|
||||||
|
|
||||||
|
def test_args(self):
|
||||||
|
mono_data = "This is a sample input"
|
||||||
|
mono_args = self.handler(mono_data)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(mono_args, list))
|
||||||
|
self.assertEqual(len(mono_args), 1)
|
||||||
|
|
||||||
|
mono_data = ["This is a sample input"]
|
||||||
|
mono_args = self.handler(mono_data)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(mono_args, list))
|
||||||
|
self.assertEqual(len(mono_args), 1)
|
||||||
|
|
||||||
|
multi_data = ["This is a sample input", "This is a second sample input"]
|
||||||
|
multi_args = self.handler(multi_data)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(multi_args, list))
|
||||||
|
self.assertEqual(len(multi_args), 2)
|
||||||
|
|
||||||
|
multi_data = ["This is a sample input", "This is a second sample input"]
|
||||||
|
multi_args = self.handler(*multi_data)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(multi_args, list))
|
||||||
|
self.assertEqual(len(multi_args), 2)
|
||||||
|
|
||||||
|
|
||||||
class MonoColumnInputTestCase(unittest.TestCase):
|
class MonoColumnInputTestCase(unittest.TestCase):
|
||||||
def _test_mono_column_pipeline(
|
def _test_mono_column_pipeline(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user