Fix QA argument handler (#8765)

* Fix QA argument handler

* Attempt to get a better fix for QA (#8768)

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Lysandre Debut
2020-11-25 14:02:15 -05:00
committed by GitHub
parent 4821ea5aeb
commit 138f45c184
2 changed files with 47 additions and 1 deletions

View File

@@ -23,6 +23,17 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
"question": "In what field is HuggingFace working ?",
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
},
{
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
},
{
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
"context": [
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
],
},
]
def _test_pipeline(self, nlp: Pipeline):
@@ -80,6 +91,11 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa(question=[Q, Q], context=C)
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 2)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa({"question": Q, "context": C})
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
@@ -159,6 +175,26 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
with self.assertRaises(ValueError):
qa([{"question": Q, "context": C}, {"question": Q, "context": ""}])
with self.assertRaises(ValueError):
qa(question={"This": "Is weird"}, context="This is a context")
with self.assertRaises(ValueError):
qa(question=[Q, Q], context=[C, C, C])
with self.assertRaises(ValueError):
qa(question=[Q, Q, Q], context=[C, C])
def test_argument_handler_old_format(self):
qa = QuestionAnsweringArgumentHandler()
Q = "Where was HuggingFace founded ?"
C = "HuggingFace was founded in Paris"
# Backward compatibility for this
normalized = qa(question=[Q, Q], context=[C, C])
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 2)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
def test_argument_handler_error_handling_odd(self):
qa = QuestionAnsweringArgumentHandler()
with self.assertRaises(ValueError):