Adding batch_size test to QA pipeline. (#17330)
This commit is contained in:
@@ -106,6 +106,13 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
)
|
||||
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||
|
||||
# Using batch is OK
|
||||
new_outputs = question_answerer(
|
||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
|
||||
)
|
||||
self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||
self.assertEqual(outputs, new_outputs)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
question_answerer = pipeline(
|
||||
|
||||
Reference in New Issue
Block a user