From 2b282296f14e9cde3e0c21013a1ac01fbdd4da00 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 19 May 2022 20:28:12 +0200 Subject: [PATCH] Adding `batch_size` test to QA pipeline. (#17330) --- tests/pipelines/test_pipelines_question_answering.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py index b775f7b7d3..f34237612c 100644 --- a/tests/pipelines/test_pipelines_question_answering.py +++ b/tests/pipelines/test_pipelines_question_answering.py @@ -106,6 +106,13 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): ) self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)}) + # Using batch is OK + new_outputs = question_answerer( + question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2 + ) + self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)}) + self.assertEqual(outputs, new_outputs) + @require_torch def test_small_model_pt(self): question_answerer = pipeline(