Adding QoL for batch_size arg (like others enabled everywhere). (#15027)

* Adding QoL for `batch_size` arg (like others enabled everywhere).

* Typo.
This commit is contained in:
Nicolas Patry
2022-01-05 12:16:23 +01:00
committed by GitHub
parent e34dd055e9
commit 65cb94ff77
2 changed files with 25 additions and 1 deletions

View File

@@ -299,6 +299,16 @@ class CommonPipelineTest(unittest.TestCase):
self.assertIsInstance(pipe, TextClassificationPipeline)
@require_torch
def test_pipeline_batch_size_global(self):
pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert")
self.assertEqual(pipe._batch_size, None)
self.assertEqual(pipe._num_workers, None)
pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", batch_size=2, num_workers=1)
self.assertEqual(pipe._batch_size, 2)
self.assertEqual(pipe._num_workers, 1)
@require_torch
def test_pipeline_override(self):
class MyPipeline(TextClassificationPipeline):