Fix QA argument handler (#8765)
* Fix QA argument handler * Attempt to get a better fix for QA (#8768) Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@@ -1624,7 +1624,17 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
|
|||||||
elif "data" in kwargs:
|
elif "data" in kwargs:
|
||||||
inputs = kwargs["data"]
|
inputs = kwargs["data"]
|
||||||
elif "question" in kwargs and "context" in kwargs:
|
elif "question" in kwargs and "context" in kwargs:
|
||||||
|
if isinstance(kwargs["question"], list) and isinstance(kwargs["context"], str):
|
||||||
|
inputs = [{"question": Q, "context": kwargs["context"]} for Q in kwargs["question"]]
|
||||||
|
elif isinstance(kwargs["question"], list) and isinstance(kwargs["context"], list):
|
||||||
|
if len(kwargs["question"]) != len(kwargs["context"]):
|
||||||
|
raise ValueError("Questions and contexts don't have the same lengths")
|
||||||
|
|
||||||
|
inputs = [{"question": Q, "context": C} for Q, C in zip(kwargs["question"], kwargs["context"])]
|
||||||
|
elif isinstance(kwargs["question"], str) and isinstance(kwargs["context"], str):
|
||||||
inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
|
inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
|
||||||
|
else:
|
||||||
|
raise ValueError("Arguments can't be understood")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown arguments {}".format(kwargs))
|
raise ValueError("Unknown arguments {}".format(kwargs))
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,17 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
"question": "In what field is HuggingFace working ?",
|
"question": "In what field is HuggingFace working ?",
|
||||||
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
|
||||||
|
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
|
||||||
|
"context": [
|
||||||
|
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||||
|
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||||
|
],
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
def _test_pipeline(self, nlp: Pipeline):
|
def _test_pipeline(self, nlp: Pipeline):
|
||||||
@@ -80,6 +91,11 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
self.assertEqual(len(normalized), 1)
|
self.assertEqual(len(normalized), 1)
|
||||||
self.assertEqual({type(el) for el in normalized}, {SquadExample})
|
self.assertEqual({type(el) for el in normalized}, {SquadExample})
|
||||||
|
|
||||||
|
normalized = qa(question=[Q, Q], context=C)
|
||||||
|
self.assertEqual(type(normalized), list)
|
||||||
|
self.assertEqual(len(normalized), 2)
|
||||||
|
self.assertEqual({type(el) for el in normalized}, {SquadExample})
|
||||||
|
|
||||||
normalized = qa({"question": Q, "context": C})
|
normalized = qa({"question": Q, "context": C})
|
||||||
self.assertEqual(type(normalized), list)
|
self.assertEqual(type(normalized), list)
|
||||||
self.assertEqual(len(normalized), 1)
|
self.assertEqual(len(normalized), 1)
|
||||||
@@ -159,6 +175,26 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
qa([{"question": Q, "context": C}, {"question": Q, "context": ""}])
|
qa([{"question": Q, "context": C}, {"question": Q, "context": ""}])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
qa(question={"This": "Is weird"}, context="This is a context")
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
qa(question=[Q, Q], context=[C, C, C])
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
qa(question=[Q, Q, Q], context=[C, C])
|
||||||
|
|
||||||
|
def test_argument_handler_old_format(self):
|
||||||
|
qa = QuestionAnsweringArgumentHandler()
|
||||||
|
|
||||||
|
Q = "Where was HuggingFace founded ?"
|
||||||
|
C = "HuggingFace was founded in Paris"
|
||||||
|
# Backward compatibility for this
|
||||||
|
normalized = qa(question=[Q, Q], context=[C, C])
|
||||||
|
self.assertEqual(type(normalized), list)
|
||||||
|
self.assertEqual(len(normalized), 2)
|
||||||
|
self.assertEqual({type(el) for el in normalized}, {SquadExample})
|
||||||
|
|
||||||
def test_argument_handler_error_handling_odd(self):
|
def test_argument_handler_error_handling_odd(self):
|
||||||
qa = QuestionAnsweringArgumentHandler()
|
qa = QuestionAnsweringArgumentHandler()
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
|
|||||||
Reference in New Issue
Block a user