Adding a test to prevent late failure in the Table question answering (#9808)
pipeline. - If table is empty then the line that contain `answer[0]` will fail. - This PR add a check to prevent `answer[0]`. - Also adds an early check for presence of `table` and `query` to prevent late failure and give better error message. - Adds a few tests to make sure these errors are correctly raised.
This commit is contained in:
@@ -131,6 +131,49 @@ class TQAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
self.assertIsInstance(table_querier.model.config.aggregation_labels, dict)
|
||||
self.assertIsInstance(table_querier.model.config.no_aggregation_label_index, int)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
table_querier(
|
||||
{
|
||||
"table": {},
|
||||
"query": "how many movies has george clooney played in?",
|
||||
}
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
table_querier(
|
||||
{
|
||||
"query": "how many movies has george clooney played in?",
|
||||
}
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
table_querier(
|
||||
{
|
||||
"table": {
|
||||
"Repository": ["Transformers", "Datasets", "Tokenizers"],
|
||||
"Stars": ["36542", "4512", "3934"],
|
||||
"Contributors": ["651", "77", "34"],
|
||||
"Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
|
||||
},
|
||||
"query": "",
|
||||
}
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
table_querier(
|
||||
{
|
||||
"table": {
|
||||
"Repository": ["Transformers", "Datasets", "Tokenizers"],
|
||||
"Stars": ["36542", "4512", "3934"],
|
||||
"Contributors": ["651", "77", "34"],
|
||||
"Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def test_empty_errors(self):
|
||||
table_querier = pipeline(
|
||||
"table-question-answering",
|
||||
model="lysandre/tiny-tapas-random-wtq",
|
||||
tokenizer="lysandre/tiny-tapas-random-wtq",
|
||||
)
|
||||
mono_result = table_querier(self.valid_inputs[0], sequential=True)
|
||||
multi_result = table_querier(self.valid_inputs, sequential=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user