Adding batch_size test to QA pipeline. (#17330)

This commit is contained in:
Nicolas Patry
2022-05-19 20:28:12 +02:00
committed by GitHub
parent a4386d7e40
commit 2b282296f1

View File

@@ -106,6 +106,13 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
) )
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)}) self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
# Using batch is OK
new_outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
)
self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
self.assertEqual(outputs, new_outputs)
@require_torch @require_torch
def test_small_model_pt(self): def test_small_model_pt(self):
question_answerer = pipeline( question_answerer = pipeline(