Rewritten batch support in pipelines. (#4154)
* Rewritten batch support in pipelines. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Fix imports sorting 🔧 Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Set pad_to_max_length=True by default on Pipeline. * Set pad_to_max_length=False for generation pipelines. Most generation models don't have a padding token. * Address @joeddav review comment: Uniformized *args. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Address @joeddav review comment: Uniformized *args (second). Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
@@ -2,7 +2,7 @@ import unittest
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import Pipeline
|
||||
from transformers.pipelines import DefaultArgumentHandler, Pipeline
|
||||
|
||||
from .utils import require_tf, require_torch, slow
|
||||
|
||||
@@ -86,6 +86,78 @@ TRANSLATION_FINETUNED_MODELS = {
|
||||
TF_TRANSLATION_FINETUNED_MODELS = {("patrickvonplaten/t5-tiny-random", "t5-small", "translation_en_to_fr")}
|
||||
|
||||
|
||||
class DefaultArgumentHandlerTestCase(unittest.TestCase):
    """Tests for how ``DefaultArgumentHandler`` normalizes call arguments.

    Every test calls the handler with positional or keyword inputs and checks
    two things about the result: that it is a ``list`` and that it has the
    expected number of items.
    """

    def setUp(self) -> None:
        """Create a fresh handler for each test."""
        self.handler = DefaultArgumentHandler()

    def test_kwargs_x(self):
        """A single keyword argument holding a string or a list of strings."""
        # NOTE(review): the two cases use differently-cased keys ("X" vs "x");
        # presumably the handler ignores keyword names -- confirm.
        mono_data = {"X": "This is a sample input"}
        mono_args = self.handler(**mono_data)

        self.assertIsInstance(mono_args, list)
        self.assertEqual(len(mono_args), 1)

        multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
        multi_args = self.handler(**multi_data)

        self.assertIsInstance(multi_args, list)
        self.assertEqual(len(multi_args), 2)

    def test_kwargs_data(self):
        """Same as ``test_kwargs_x`` but with the keyword named ``data``."""
        mono_data = {"data": "This is a sample input"}
        mono_args = self.handler(**mono_data)

        self.assertIsInstance(mono_args, list)
        self.assertEqual(len(mono_args), 1)

        multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
        multi_args = self.handler(**multi_data)

        self.assertIsInstance(multi_args, list)
        self.assertEqual(len(multi_args), 2)

    def test_multi_kwargs(self):
        """Several keyword arguments passed in the same call."""
        # Two keyword strings are expected to give two items.
        mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
        mono_args = self.handler(**mono_data)

        self.assertIsInstance(mono_args, list)
        self.assertEqual(len(mono_args), 2)

        # Two keyword lists of two strings each are expected to give four items.
        multi_data = {
            "data": ["This is a sample input", "This is a second sample input"],
            "test": ["This is a sample input 2", "This is a second sample input 2"],
        }
        multi_args = self.handler(**multi_data)

        self.assertIsInstance(multi_args, list)
        self.assertEqual(len(multi_args), 4)

    def test_args(self):
        """Positional inputs: bare string, one-item list, list, unpacked list."""
        # A bare string.
        mono_data = "This is a sample input"
        mono_args = self.handler(mono_data)

        self.assertIsInstance(mono_args, list)
        self.assertEqual(len(mono_args), 1)

        # A list holding a single string.
        mono_data = ["This is a sample input"]
        mono_args = self.handler(mono_data)

        self.assertIsInstance(mono_args, list)
        self.assertEqual(len(mono_args), 1)

        # A list of two strings passed as one positional argument.
        multi_data = ["This is a sample input", "This is a second sample input"]
        multi_args = self.handler(multi_data)

        self.assertIsInstance(multi_args, list)
        self.assertEqual(len(multi_args), 2)

        # The same two strings unpacked into two positional arguments.
        multi_data = ["This is a sample input", "This is a second sample input"]
        multi_args = self.handler(*multi_data)

        self.assertIsInstance(multi_args, list)
        self.assertEqual(len(multi_args), 2)
|
||||
|
||||
|
||||
class MonoColumnInputTestCase(unittest.TestCase):
|
||||
def _test_mono_column_pipeline(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user