From 195fbbb6cfc6c6279cef6be12b05a53d589b0de8 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 14 Apr 2022 09:06:14 +0200 Subject: [PATCH] Enabling `Tapex` in table question answering pipeline. (#16663) * Enabling `Tapex` in table question answering pipeline. * Questions are independent for Tapex, making the test respect that. * Missing extra space. --- .../pipelines/table_question_answering.py | 85 +++++++++++-------- ...test_pipelines_table_question_answering.py | 28 ++++++ 2 files changed, 79 insertions(+), 34 deletions(-) diff --git a/src/transformers/pipelines/table_question_answering.py b/src/transformers/pipelines/table_question_answering.py index c13753032d..d94bb6d061 100644 --- a/src/transformers/pipelines/table_question_answering.py +++ b/src/transformers/pipelines/table_question_answering.py @@ -105,9 +105,10 @@ class TableQuestionAnsweringPipeline(Pipeline): else MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING ) - self.aggregate = bool(getattr(self.model.config, "aggregation_labels")) and bool( - getattr(self.model.config, "num_aggregation_labels") + self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool( + getattr(self.model.config, "num_aggregation_labels", None) ) + self.type = "tapas" if hasattr(self.model.config, "aggregation_labels") else None def batch_inference(self, **inputs): return self.model(**inputs) @@ -335,7 +336,13 @@ class TableQuestionAnsweringPipeline(Pipeline): forward_params["sequential"] = sequential return preprocess_params, forward_params, {} - def preprocess(self, pipeline_input, sequential=None, padding=True, truncation="drop_rows_to_fit"): + def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None): + if truncation is None: + if self.type == "tapas": + truncation = "drop_rows_to_fit" + else: + truncation = "do_not_truncate" + table, query = pipeline_input["table"], pipeline_input["query"] if table.empty: raise ValueError("table is empty") @@ -347,7 +354,14 @@ class
TableQuestionAnsweringPipeline(Pipeline): def _forward(self, model_inputs, sequential=False): table = model_inputs.pop("table") - outputs = self.sequential_inference(**model_inputs) if sequential else self.batch_inference(**model_inputs) + + if self.type == "tapas": + if sequential: + outputs = self.sequential_inference(**model_inputs) + else: + outputs = self.batch_inference(**model_inputs) + else: + outputs = self.model.generate(**model_inputs) model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs} return model_outputs @@ -355,37 +369,40 @@ class TableQuestionAnsweringPipeline(Pipeline): inputs = model_outputs["model_inputs"] table = model_outputs["table"] outputs = model_outputs["outputs"] - if self.aggregate: - logits, logits_agg = outputs[:2] - predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg) - answer_coordinates_batch, agg_predictions = predictions - aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)} + if self.type == "tapas": + if self.aggregate: + logits, logits_agg = outputs[:2] + predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg) + answer_coordinates_batch, agg_predictions = predictions + aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)} - no_agg_label_index = self.model.config.no_aggregation_label_index - aggregators_prefix = { - i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index - } + no_agg_label_index = self.model.config.no_aggregation_label_index + aggregators_prefix = { + i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index + } + else: + logits = outputs[0] + predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits) + answer_coordinates_batch = predictions[0] + aggregators = {} + aggregators_prefix = {} + answers = [] + for index, 
coordinates in enumerate(answer_coordinates_batch): + cells = [table.iat[coordinate] for coordinate in coordinates] + aggregator = aggregators.get(index, "") + aggregator_prefix = aggregators_prefix.get(index, "") + answer = { + "answer": aggregator_prefix + ", ".join(cells), + "coordinates": coordinates, + "cells": [table.iat[coordinate] for coordinate in coordinates], + } + if aggregator: + answer["aggregator"] = aggregator + + answers.append(answer) + if len(answer) == 0: + raise PipelineException("Empty answer") else: - logits = outputs[0] - predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits) - answer_coordinates_batch = predictions[0] - aggregators = {} - aggregators_prefix = {} + answers = [{"answer": answer} for answer in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)] - answers = [] - for index, coordinates in enumerate(answer_coordinates_batch): - cells = [table.iat[coordinate] for coordinate in coordinates] - aggregator = aggregators.get(index, "") - aggregator_prefix = aggregators_prefix.get(index, "") - answer = { - "answer": aggregator_prefix + ", ".join(cells), - "coordinates": coordinates, - "cells": [table.iat[coordinate] for coordinate in coordinates], - } - if aggregator: - answer["aggregator"] = aggregator - - answers.append(answer) - if len(answer) == 0: - raise PipelineException("Empty answer") return answers if len(answers) > 1 else answers[0] diff --git a/tests/pipelines/test_pipelines_table_question_answering.py b/tests/pipelines/test_pipelines_table_question_answering.py index 0793d6586c..86bbf991b0 100644 --- a/tests/pipelines/test_pipelines_table_question_answering.py +++ b/tests/pipelines/test_pipelines_table_question_answering.py @@ -632,3 +632,31 @@ class TQAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): {"answer": "28 november 1967", "coordinates": [(2, 3)], "cells": ["28 november 1967"]}, ] self.assertListEqual(results, expected_results) + + @slow + @require_torch + def 
test_large_model_pt_tapex(self): + model_id = "microsoft/tapex-large-finetuned-wtq" + table_querier = pipeline( + "table-question-answering", + model=model_id, + ) + data = { + "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + "Age": ["56", "45", "59"], + "Number of movies": ["87", "53", "69"], + "Date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"], + } + queries = [ + "How many movies has George Clooney played in?", + "How old is Mr Clooney ?", + "What's the date of birth of Leonardo ?", + ] + results = table_querier(data, queries, sequential=True) + + expected_results = [ + {"answer": " 69"}, + {"answer": " 59"}, + {"answer": " 10 june 1996"}, + ] + self.assertListEqual(results, expected_results)