Enabling Tapex in table question answering pipeline. (#16663)
* Enabling `Tapex` in table question answering pipeline. * Questions are independent for Tapex, making the test respect that. * Missing extra space.
This commit is contained in:
@@ -105,9 +105,10 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||||||
else MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
|
else MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
|
||||||
)
|
)
|
||||||
|
|
||||||
self.aggregate = bool(getattr(self.model.config, "aggregation_labels")) and bool(
|
self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool(
|
||||||
getattr(self.model.config, "num_aggregation_labels")
|
getattr(self.model.config, "num_aggregation_labels", None)
|
||||||
)
|
)
|
||||||
|
self.type = "tapas" if hasattr(self.model.config, "aggregation_labels") else None
|
||||||
|
|
||||||
def batch_inference(self, **inputs):
|
def batch_inference(self, **inputs):
|
||||||
return self.model(**inputs)
|
return self.model(**inputs)
|
||||||
@@ -335,7 +336,13 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||||||
forward_params["sequential"] = sequential
|
forward_params["sequential"] = sequential
|
||||||
return preprocess_params, forward_params, {}
|
return preprocess_params, forward_params, {}
|
||||||
|
|
||||||
def preprocess(self, pipeline_input, sequential=None, padding=True, truncation="drop_rows_to_fit"):
|
def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):
|
||||||
|
if truncation is None:
|
||||||
|
if self.type == "tapas":
|
||||||
|
truncation = "drop_rows_to_fit"
|
||||||
|
else:
|
||||||
|
truncation = "do_not_truncate"
|
||||||
|
|
||||||
table, query = pipeline_input["table"], pipeline_input["query"]
|
table, query = pipeline_input["table"], pipeline_input["query"]
|
||||||
if table.empty:
|
if table.empty:
|
||||||
raise ValueError("table is empty")
|
raise ValueError("table is empty")
|
||||||
@@ -347,7 +354,14 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||||||
|
|
||||||
def _forward(self, model_inputs, sequential=False):
|
def _forward(self, model_inputs, sequential=False):
|
||||||
table = model_inputs.pop("table")
|
table = model_inputs.pop("table")
|
||||||
outputs = self.sequential_inference(**model_inputs) if sequential else self.batch_inference(**model_inputs)
|
|
||||||
|
if self.type == "tapas":
|
||||||
|
if sequential:
|
||||||
|
outputs = self.sequential_inference(**model_inputs)
|
||||||
|
else:
|
||||||
|
outputs = self.batch_inference(**model_inputs)
|
||||||
|
else:
|
||||||
|
outputs = self.model.generate(**model_inputs)
|
||||||
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
|
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
|
||||||
return model_outputs
|
return model_outputs
|
||||||
|
|
||||||
@@ -355,37 +369,40 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||||||
inputs = model_outputs["model_inputs"]
|
inputs = model_outputs["model_inputs"]
|
||||||
table = model_outputs["table"]
|
table = model_outputs["table"]
|
||||||
outputs = model_outputs["outputs"]
|
outputs = model_outputs["outputs"]
|
||||||
if self.aggregate:
|
if self.type == "tapas":
|
||||||
logits, logits_agg = outputs[:2]
|
if self.aggregate:
|
||||||
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg)
|
logits, logits_agg = outputs[:2]
|
||||||
answer_coordinates_batch, agg_predictions = predictions
|
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg)
|
||||||
aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}
|
answer_coordinates_batch, agg_predictions = predictions
|
||||||
|
aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}
|
||||||
|
|
||||||
no_agg_label_index = self.model.config.no_aggregation_label_index
|
no_agg_label_index = self.model.config.no_aggregation_label_index
|
||||||
aggregators_prefix = {
|
aggregators_prefix = {
|
||||||
i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
|
i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
logits = outputs[0]
|
||||||
|
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits)
|
||||||
|
answer_coordinates_batch = predictions[0]
|
||||||
|
aggregators = {}
|
||||||
|
aggregators_prefix = {}
|
||||||
|
answers = []
|
||||||
|
for index, coordinates in enumerate(answer_coordinates_batch):
|
||||||
|
cells = [table.iat[coordinate] for coordinate in coordinates]
|
||||||
|
aggregator = aggregators.get(index, "")
|
||||||
|
aggregator_prefix = aggregators_prefix.get(index, "")
|
||||||
|
answer = {
|
||||||
|
"answer": aggregator_prefix + ", ".join(cells),
|
||||||
|
"coordinates": coordinates,
|
||||||
|
"cells": [table.iat[coordinate] for coordinate in coordinates],
|
||||||
|
}
|
||||||
|
if aggregator:
|
||||||
|
answer["aggregator"] = aggregator
|
||||||
|
|
||||||
|
answers.append(answer)
|
||||||
|
if len(answer) == 0:
|
||||||
|
raise PipelineException("Empty answer")
|
||||||
else:
|
else:
|
||||||
logits = outputs[0]
|
answers = [{"answer": answer} for answer in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)]
|
||||||
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits)
|
|
||||||
answer_coordinates_batch = predictions[0]
|
|
||||||
aggregators = {}
|
|
||||||
aggregators_prefix = {}
|
|
||||||
|
|
||||||
answers = []
|
|
||||||
for index, coordinates in enumerate(answer_coordinates_batch):
|
|
||||||
cells = [table.iat[coordinate] for coordinate in coordinates]
|
|
||||||
aggregator = aggregators.get(index, "")
|
|
||||||
aggregator_prefix = aggregators_prefix.get(index, "")
|
|
||||||
answer = {
|
|
||||||
"answer": aggregator_prefix + ", ".join(cells),
|
|
||||||
"coordinates": coordinates,
|
|
||||||
"cells": [table.iat[coordinate] for coordinate in coordinates],
|
|
||||||
}
|
|
||||||
if aggregator:
|
|
||||||
answer["aggregator"] = aggregator
|
|
||||||
|
|
||||||
answers.append(answer)
|
|
||||||
if len(answer) == 0:
|
|
||||||
raise PipelineException("Empty answer")
|
|
||||||
return answers if len(answers) > 1 else answers[0]
|
return answers if len(answers) > 1 else answers[0]
|
||||||
|
|||||||
@@ -632,3 +632,31 @@ class TQAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
|||||||
{"answer": "28 november 1967", "coordinates": [(2, 3)], "cells": ["28 november 1967"]},
|
{"answer": "28 november 1967", "coordinates": [(2, 3)], "cells": ["28 november 1967"]},
|
||||||
]
|
]
|
||||||
self.assertListEqual(results, expected_results)
|
self.assertListEqual(results, expected_results)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_large_model_pt_tapex(self):
|
||||||
|
model_id = "microsoft/tapex-large-finetuned-wtq"
|
||||||
|
table_querier = pipeline(
|
||||||
|
"table-question-answering",
|
||||||
|
model=model_id,
|
||||||
|
)
|
||||||
|
data = {
|
||||||
|
"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
|
||||||
|
"Age": ["56", "45", "59"],
|
||||||
|
"Number of movies": ["87", "53", "69"],
|
||||||
|
"Date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
|
||||||
|
}
|
||||||
|
queries = [
|
||||||
|
"How many movies has George Clooney played in?",
|
||||||
|
"How old is Mr Clooney ?",
|
||||||
|
"What's the date of birth of Leonardo ?",
|
||||||
|
]
|
||||||
|
results = table_querier(data, queries, sequential=True)
|
||||||
|
|
||||||
|
expected_results = [
|
||||||
|
{"answer": " 69"},
|
||||||
|
{"answer": " 59"},
|
||||||
|
{"answer": " 10 june 1996"},
|
||||||
|
]
|
||||||
|
self.assertListEqual(results, expected_results)
|
||||||
|
|||||||
Reference in New Issue
Block a user