From 0a6cbea0a5e130ea6f935588f2a66b89b8aa9684 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Thu, 7 May 2020 13:52:40 +0000 Subject: [PATCH] Rewritten batch support in pipelines. (#4154) * Rewritten batch support in pipelines. Signed-off-by: Morgan Funtowicz * Fix imports sorting :wrench: Signed-off-by: Morgan Funtowicz * Set pad_to_max_length=True by default on Pipeline. * Set pad_to_max_length=False for generation pipelines. Most generation models don't have a padding token. * Address @joeddav review comment: Uniformized *args. Signed-off-by: Morgan Funtowicz * Address @joeddav review comment: Uniformized *args (second). Signed-off-by: Morgan Funtowicz --- src/transformers/pipelines.py | 96 +++++++++++++++++++++++------------ tests/test_pipelines.py | 74 ++++++++++++++++++++++++++- 2 files changed, 137 insertions(+), 33 deletions(-) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 582d3fffda..58907a38ab 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -22,8 +22,9 @@ import pickle import sys from abc import ABC, abstractmethod from contextlib import contextmanager +from itertools import chain from os.path import abspath, exists -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np @@ -96,19 +97,50 @@ class DefaultArgumentHandler(ArgumentHandler): Default varargs argument parser handling parameters for each Pipeline """ - def __call__(self, *args, **kwargs): - if "X" in kwargs: - return kwargs["X"] - elif "data" in kwargs: - return kwargs["data"] - elif len(args) == 1: - if isinstance(args[0], list): - return args[0] - else: + @staticmethod + def handle_kwargs(kwargs: Dict) -> List: + if len(kwargs) == 1: + output = list(kwargs.values()) + else: + output = list(chain(kwargs.values())) + + return DefaultArgumentHandler.handle_args(output) + + @staticmethod + def handle_args(args: Sequence[Any]) -> 
List[str]: + + # Only one argument, let's do case by case + if len(args) == 1: + if isinstance(args[0], str): return [args[0]] + elif not isinstance(args[0], list): + return list(args) + else: + return args[0] + + # Multiple arguments (x1, x2, ...) elif len(args) > 1: - return list(args) - raise ValueError("Unable to infer the format of the provided data (X=, data=, ...)") + if all([isinstance(arg, str) for arg in args]): + return list(args) + + # If not instance of list, then it should instance of iterable + elif isinstance(args, Iterable): + return list(chain.from_iterable(chain(args))) + else: + raise ValueError( + "Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(args)) + ) + else: + return [] + + def __call__(self, *args, **kwargs): + if len(kwargs) > 0 and len(args) > 0: + raise ValueError("Pipeline cannot handle mixed args and kwargs") + + if len(kwargs) > 0: + return DefaultArgumentHandler.handle_kwargs(kwargs) + else: + return DefaultArgumentHandler.handle_args(args) class PipelineDataFormat: @@ -418,20 +450,20 @@ class Pipeline(_ScikitCompat): """ return {name: tensor.to(self.device) for name, tensor in inputs.items()} - def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs): + def _parse_and_tokenize(self, *args, pad_to_max_length=True, **kwargs): """ Parse arguments and tokenize """ # Parse arguments - inputs = self._args_parser(*texts, **kwargs) + inputs = self._args_parser(*args, **kwargs) inputs = self.tokenizer.batch_encode_plus( inputs, add_special_tokens=True, return_tensors=self.framework, pad_to_max_length=pad_to_max_length, ) return inputs - def __call__(self, *texts, **kwargs): - inputs = self._parse_and_tokenize(*texts, **kwargs) + def __call__(self, *args, **kwargs): + inputs = self._parse_and_tokenize(*args, **kwargs) return self._forward(inputs) def _forward(self, inputs, return_tensors=False): @@ -550,18 +582,18 @@ class TextGenerationPipeline(Pipeline): with people, even a bishop, begging 
for his blessing. """ def __call__( - self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs + self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs ): - text_inputs = self._args_parser(*texts) + text_inputs = self._args_parser(*args) results = [] for prompt_text in text_inputs: # Manage correct placement of the tensors with self.device_placement(): if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]: - inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text) + inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text, pad_to_max_length=False) else: - inputs = self._parse_and_tokenize(prompt_text) + inputs = self._parse_and_tokenize(prompt_text, pad_to_max_length=False) # set input_ids to None to allow empty prompt if inputs["input_ids"].shape[-1] == 0: @@ -825,8 +857,8 @@ class NerPipeline(Pipeline): self._basic_tokenizer = BasicTokenizer(do_lower_case=False) self.ignore_labels = ignore_labels - def __call__(self, *texts, **kwargs): - inputs = self._args_parser(*texts, **kwargs) + def __call__(self, *args, **kwargs): + inputs = self._args_parser(*args, **kwargs) answers = [] for sentence in inputs: @@ -1016,7 +1048,7 @@ class QuestionAnsweringPipeline(Pipeline): else: return SquadExample(None, question, context, None, None, None) - def __call__(self, *texts, **kwargs): + def __call__(self, *args, **kwargs): """ Args: We support multiple use-cases, the following are exclusive: @@ -1046,7 +1078,7 @@ class QuestionAnsweringPipeline(Pipeline): raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"])) # Convert inputs to features - examples = self._args_parser(*texts, **kwargs) + examples = self._args_parser(*args, **kwargs) features_list = [ squad_convert_examples_to_features( [example], @@ -1383,11 +1415,11 @@ class TranslationPipeline(Pipeline): """ def __call__( - self, 
*texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs + self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs ): r""" Args: - *texts: (list of strings) texts to be translated + *args: (list of strings) texts to be translated return_text: (bool, default=True) whether to add a decoded "translation_text" to each result return_tensors: (bool, default=False) whether to return the raw "translation_token_ids" to each result @@ -1402,25 +1434,25 @@ class TranslationPipeline(Pipeline): prefix = self.model.config.prefix if self.model.config.prefix is not None else "" - if isinstance(texts[0], list): + if isinstance(args[0], list): assert ( self.tokenizer.pad_token_id is not None ), "Please make sure that the tokenizer has a pad_token_id when using a batch input" - texts = ([prefix + text for text in texts[0]],) + args = ([prefix + text for text in args[0]],) pad_to_max_length = True - elif isinstance(texts[0], str): - texts = (prefix + texts[0],) + elif isinstance(args[0], str): + args = (prefix + args[0],) pad_to_max_length = False else: raise ValueError( " `documents[0]`: {} have the wrong format. 
The should be either of type `str` or type `list`".format( - texts[0] + args[0] ) ) with self.device_placement(): - inputs = self._parse_and_tokenize(*texts, pad_to_max_length=pad_to_max_length) + inputs = self._parse_and_tokenize(*args, pad_to_max_length=pad_to_max_length) if self.framework == "pt": inputs = self.ensure_tensor_on_device(**inputs) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 0f91b813d7..d0bac672e9 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -2,7 +2,7 @@ import unittest from typing import Iterable, List, Optional from transformers import pipeline -from transformers.pipelines import Pipeline +from transformers.pipelines import DefaultArgumentHandler, Pipeline from .utils import require_tf, require_torch, slow @@ -86,6 +86,78 @@ TRANSLATION_FINETUNED_MODELS = { TF_TRANSLATION_FINETUNED_MODELS = {("patrickvonplaten/t5-tiny-random", "t5-small", "translation_en_to_fr")} +class DefaultArgumentHandlerTestCase(unittest.TestCase): + def setUp(self) -> None: + self.handler = DefaultArgumentHandler() + + def test_kwargs_x(self): + mono_data = {"X": "This is a sample input"} + mono_args = self.handler(**mono_data) + + self.assertTrue(isinstance(mono_args, list)) + self.assertEqual(len(mono_args), 1) + + multi_data = {"x": ["This is a sample input", "This is a second sample input"]} + multi_args = self.handler(**multi_data) + + self.assertTrue(isinstance(multi_args, list)) + self.assertEqual(len(multi_args), 2) + + def test_kwargs_data(self): + mono_data = {"data": "This is a sample input"} + mono_args = self.handler(**mono_data) + + self.assertTrue(isinstance(mono_args, list)) + self.assertEqual(len(mono_args), 1) + + multi_data = {"data": ["This is a sample input", "This is a second sample input"]} + multi_args = self.handler(**multi_data) + + self.assertTrue(isinstance(multi_args, list)) + self.assertEqual(len(multi_args), 2) + + def test_multi_kwargs(self): + mono_data = {"data": "This is a sample input", 
"X": "This is a sample input 2"} + mono_args = self.handler(**mono_data) + + self.assertTrue(isinstance(mono_args, list)) + self.assertEqual(len(mono_args), 2) + + multi_data = { + "data": ["This is a sample input", "This is a second sample input"], + "test": ["This is a sample input 2", "This is a second sample input 2"], + } + multi_args = self.handler(**multi_data) + + self.assertTrue(isinstance(multi_args, list)) + self.assertEqual(len(multi_args), 4) + + def test_args(self): + mono_data = "This is a sample input" + mono_args = self.handler(mono_data) + + self.assertTrue(isinstance(mono_args, list)) + self.assertEqual(len(mono_args), 1) + + mono_data = ["This is a sample input"] + mono_args = self.handler(mono_data) + + self.assertTrue(isinstance(mono_args, list)) + self.assertEqual(len(mono_args), 1) + + multi_data = ["This is a sample input", "This is a second sample input"] + multi_args = self.handler(multi_data) + + self.assertTrue(isinstance(multi_args, list)) + self.assertEqual(len(multi_args), 2) + + multi_data = ["This is a sample input", "This is a second sample input"] + multi_args = self.handler(*multi_data) + + self.assertTrue(isinstance(multi_args, list)) + self.assertEqual(len(multi_args), 2) + + class MonoColumnInputTestCase(unittest.TestCase): def _test_mono_column_pipeline( self,