From 65cb94ff773dbba5b41028bf635c6e23e42e8e94 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 5 Jan 2022 12:16:23 +0100 Subject: [PATCH] Adding QoL for `batch_size` arg (like others enabled everywhere). (#15027) * Adding QoL for `batch_size` arg (like others enabled everywhere). * Typo. --- src/transformers/pipelines/base.py | 16 +++++++++++++++- tests/test_pipelines_common.py | 10 ++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index fe98c36530..5445b718e3 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -742,6 +742,8 @@ class Pipeline(_ScikitCompat): self.model.config.update(task_specific_params.get(task)) self.call_count = 0 + self._batch_size = kwargs.pop("batch_size", None) + self._num_workers = kwargs.pop("num_workers", None) self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) def save_pretrained(self, save_directory: str): @@ -947,9 +949,21 @@ class Pipeline(_ScikitCompat): final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params) return final_iterator - def __call__(self, inputs, *args, num_workers=0, batch_size=1, **kwargs): + def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs): if args: logger.warning(f"Ignoring args : {args}") + + if num_workers is None: + if self._num_workers is None: + num_workers = 0 + else: + num_workers = self._num_workers + if batch_size is None: + if self._batch_size is None: + batch_size = 1 + else: + batch_size = self._batch_size + preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(**kwargs) # Fuse __init__ params and __call__ params without modifying the __init__ ones. diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index c832f3c891..05fa383ce4 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -299,6 +299,16 @@ class CommonPipelineTest(unittest.TestCase): self.assertIsInstance(pipe, TextClassificationPipeline) + @require_torch + def test_pipeline_batch_size_global(self): + pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert") + self.assertEqual(pipe._batch_size, None) + self.assertEqual(pipe._num_workers, None) + + pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", batch_size=2, num_workers=1) + self.assertEqual(pipe._batch_size, 2) + self.assertEqual(pipe._num_workers, 1) + @require_torch def test_pipeline_override(self): class MyPipeline(TextClassificationPipeline):