Adding batch_size test to QA pipeline. (#17330)
This commit is contained in:
@@ -106,6 +106,13 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
|||||||
)
|
)
|
||||||
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||||
|
|
||||||
|
# Using batch is OK
|
||||||
|
new_outputs = question_answerer(
|
||||||
|
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20, batch_size=2
|
||||||
|
)
|
||||||
|
self.assertEqual(new_outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||||
|
self.assertEqual(outputs, new_outputs)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt(self):
|
def test_small_model_pt(self):
|
||||||
question_answerer = pipeline(
|
question_answerer = pipeline(
|
||||||
|
|||||||
Reference in New Issue
Block a user