Adding QoL for batch_size arg (like others enabled everywhere). (#15027)
* Adding QoL for `batch_size` arg (like others enabled everywhere). * Typo.
This commit is contained in:
@@ -299,6 +299,16 @@ class CommonPipelineTest(unittest.TestCase):
|
||||
|
||||
self.assertIsInstance(pipe, TextClassificationPipeline)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_batch_size_global(self):
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert")
|
||||
self.assertEqual(pipe._batch_size, None)
|
||||
self.assertEqual(pipe._num_workers, None)
|
||||
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert", batch_size=2, num_workers=1)
|
||||
self.assertEqual(pipe._batch_size, 2)
|
||||
self.assertEqual(pipe._num_workers, 1)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_override(self):
|
||||
class MyPipeline(TextClassificationPipeline):
|
||||
|
||||
Reference in New Issue
Block a user