Adding a test to prevent late failure in the Table question answering (#9808)
pipeline. - If the table is empty, then the line that contains `answer[0]` will fail. - This PR adds a check that guards the `answer[0]` access. - Also adds an early check for the presence of `table` and `query` to prevent late failure and give a better error message. - Adds a few tests to make sure these errors are correctly raised.
This commit is contained in:
@@ -3,7 +3,7 @@ import collections
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ..file_utils import add_end_docstrings, is_torch_available, requires_pandas
|
from ..file_utils import add_end_docstrings, is_torch_available, requires_pandas
|
||||||
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
|
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline, PipelineException
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -239,6 +239,10 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||||||
batched_answers = []
|
batched_answers = []
|
||||||
for pipeline_input in pipeline_inputs:
|
for pipeline_input in pipeline_inputs:
|
||||||
table, query = pipeline_input["table"], pipeline_input["query"]
|
table, query = pipeline_input["table"], pipeline_input["query"]
|
||||||
|
if table.empty:
|
||||||
|
raise ValueError("table is empty")
|
||||||
|
if not query:
|
||||||
|
raise ValueError("query is empty")
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
table, query, return_tensors=self.framework, truncation="drop_rows_to_fit", padding=padding
|
table, query, return_tensors=self.framework, truncation="drop_rows_to_fit", padding=padding
|
||||||
)
|
)
|
||||||
@@ -276,5 +280,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||||||
answer["aggregator"] = aggregator
|
answer["aggregator"] = aggregator
|
||||||
|
|
||||||
answers.append(answer)
|
answers.append(answer)
|
||||||
|
if len(answer) == 0:
|
||||||
|
raise PipelineException("Empty answer")
|
||||||
batched_answers.append(answers if len(answers) > 1 else answers[0])
|
batched_answers.append(answers if len(answers) > 1 else answers[0])
|
||||||
return batched_answers if len(batched_answers) > 1 else batched_answers[0]
|
return batched_answers if len(batched_answers) > 1 else batched_answers[0]
|
||||||
|
|||||||
@@ -131,6 +131,49 @@ class TQAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
self.assertIsInstance(table_querier.model.config.aggregation_labels, dict)
|
self.assertIsInstance(table_querier.model.config.aggregation_labels, dict)
|
||||||
self.assertIsInstance(table_querier.model.config.no_aggregation_label_index, int)
|
self.assertIsInstance(table_querier.model.config.no_aggregation_label_index, int)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
table_querier(
|
||||||
|
{
|
||||||
|
"table": {},
|
||||||
|
"query": "how many movies has george clooney played in?",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
table_querier(
|
||||||
|
{
|
||||||
|
"query": "how many movies has george clooney played in?",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
table_querier(
|
||||||
|
{
|
||||||
|
"table": {
|
||||||
|
"Repository": ["Transformers", "Datasets", "Tokenizers"],
|
||||||
|
"Stars": ["36542", "4512", "3934"],
|
||||||
|
"Contributors": ["651", "77", "34"],
|
||||||
|
"Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
|
||||||
|
},
|
||||||
|
"query": "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
table_querier(
|
||||||
|
{
|
||||||
|
"table": {
|
||||||
|
"Repository": ["Transformers", "Datasets", "Tokenizers"],
|
||||||
|
"Stars": ["36542", "4512", "3934"],
|
||||||
|
"Contributors": ["651", "77", "34"],
|
||||||
|
"Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_empty_errors(self):
|
||||||
|
table_querier = pipeline(
|
||||||
|
"table-question-answering",
|
||||||
|
model="lysandre/tiny-tapas-random-wtq",
|
||||||
|
tokenizer="lysandre/tiny-tapas-random-wtq",
|
||||||
|
)
|
||||||
mono_result = table_querier(self.valid_inputs[0], sequential=True)
|
mono_result = table_querier(self.valid_inputs[0], sequential=True)
|
||||||
multi_result = table_querier(self.valid_inputs, sequential=True)
|
multi_result = table_querier(self.valid_inputs, sequential=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user