From 285c6262a84490270d2f1a1c06ee9ccfc1b60e8f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 27 Jan 2021 10:10:53 +0100 Subject: [PATCH] Adding a test to prevent late failure in the Table question answering (#9808) pipeline. - If table is empty then the line that contains `answer[0]` will fail. - This PR adds a check to prevent `answer[0]`. - Also adds an early check for the presence of `table` and `query` to prevent late failure and give a better error message. - Adds a few tests to make sure these errors are correctly raised. --- .../pipelines/table_question_answering.py | 8 +++- ...test_pipelines_table_question_answering.py | 43 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/table_question_answering.py b/src/transformers/pipelines/table_question_answering.py index 7039c51621..865941f249 100644 --- a/src/transformers/pipelines/table_question_answering.py +++ b/src/transformers/pipelines/table_question_answering.py @@ -3,7 +3,7 @@ import collections import numpy as np from ..file_utils import add_end_docstrings, is_torch_available, requires_pandas -from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline +from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline, PipelineException if is_torch_available(): @@ -239,6 +239,10 @@ class TableQuestionAnsweringPipeline(Pipeline): batched_answers = [] for pipeline_input in pipeline_inputs: table, query = pipeline_input["table"], pipeline_input["query"] + if table.empty: + raise ValueError("table is empty") + if not query: + raise ValueError("query is empty") inputs = self.tokenizer( table, query, return_tensors=self.framework, truncation="drop_rows_to_fit", padding=padding ) @@ -276,5 +280,7 @@ class TableQuestionAnsweringPipeline(Pipeline): answer["aggregator"] = aggregator answers.append(answer) + if len(answer) == 0: + raise PipelineException("Empty answer") batched_answers.append(answers if len(answers) > 1 else answers[0]) return 
batched_answers if len(batched_answers) > 1 else batched_answers[0] diff --git a/tests/test_pipelines_table_question_answering.py b/tests/test_pipelines_table_question_answering.py index 1856d046ed..8b95f35175 100644 --- a/tests/test_pipelines_table_question_answering.py +++ b/tests/test_pipelines_table_question_answering.py @@ -131,6 +131,49 @@ class TQAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): self.assertIsInstance(table_querier.model.config.aggregation_labels, dict) self.assertIsInstance(table_querier.model.config.no_aggregation_label_index, int) + with self.assertRaises(ValueError): + table_querier( + { + "table": {}, + "query": "how many movies has george clooney played in?", + } + ) + with self.assertRaises(ValueError): + table_querier( + { + "query": "how many movies has george clooney played in?", + } + ) + with self.assertRaises(ValueError): + table_querier( + { + "table": { + "Repository": ["Transformers", "Datasets", "Tokenizers"], + "Stars": ["36542", "4512", "3934"], + "Contributors": ["651", "77", "34"], + "Programming language": ["Python", "Python", "Rust, Python and NodeJS"], + }, + "query": "", + } + ) + with self.assertRaises(ValueError): + table_querier( + { + "table": { + "Repository": ["Transformers", "Datasets", "Tokenizers"], + "Stars": ["36542", "4512", "3934"], + "Contributors": ["651", "77", "34"], + "Programming language": ["Python", "Python", "Rust, Python and NodeJS"], + }, + } + ) + + def test_empty_errors(self): + table_querier = pipeline( + "table-question-answering", + model="lysandre/tiny-tapas-random-wtq", + tokenizer="lysandre/tiny-tapas-random-wtq", + ) mono_result = table_querier(self.valid_inputs[0], sequential=True) multi_result = table_querier(self.valid_inputs, sequential=True)