Add TAPEX (#16473)
* Add TapexTokenizer
* Improve docstrings and provide option to provide answer
* Remove option for pretokenized inputs
* Add TAPEX to README
* Fix copies
* Remove option for pretokenized inputs
* Initial commit: add tapex fine-tuning examples on both table-based question answering and table-based fact verification.
* Draft a README file for running the script and introducing some background.
* Remove unused code lines in tabfact script.
* Disable the default `pad_to_max_length` option which is memory-consuming.
* Support `as_target_tokenizer` function for TapexTokenizer.
* Fix the do_lower_case behaviour of TapexTokenizer.
* Add unit tests for target scenarios and cased/uncased scenarios for both source and target.
* Replace the label BartTokenizer with TapexTokenizer's as_target_tokenizer function.
* Fix typos in tapex example README.
* Fix the evaluation script - remove the property `task_name`
* Make the label space more clear for tabfact tasks
* Using a new fine-tuning script for tapex-base on tabfact.
* Remove the lowercase code outside the tokenizer - we use the tokenizer to control whether do_lower_case
* Guarantee the hyper-parameter can be run without out-of-memory on 16GB card and report the new reproduced number on wikisql
* Remove the default tokenizer_name option.
* Provide evaluation command.
* Support for WikiTableQuestion dataset.
* Fix a typo in README.
* Fix the datasets's key name in WikiTableQuestions
* Run make fixup and move test to folder
* Fix quality
* Apply suggestions from code review
* Apply suggestions from code review
  Co-authored-by: Suraj Patil <surajp815@gmail.com>
* Apply suggestions from code review
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Apply some more suggestions from code review
* Improve docstrings
* Overwrite failing test
* Improve comment in example scripts
* Fix rebase
* Add TAPEX to Auto mapping
* Add TAPEX to auto config mappings
* Put TAPEX higher than BART in auto mapping
* Add TAPEX to doc tests

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
Co-authored-by: SivilTaram <qianlxc@outlook.com>
Co-authored-by: Niels Rogge <nielsrogge@nielss-mbp.home>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
288
examples/research_projects/tapex/README.md
Normal file
@@ -0,0 +1,288 @@
|
||||
<!---
Copyright 2022 The Microsoft Inc. and The HuggingFace Inc. Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
|
||||
|
||||
# Run Table Tasks with TAPEX
|
||||
|
||||
TAPEX is a table pre-training approach for table-related tasks. Built on generative language models (e.g., BART), it learns a neural SQL executor over a synthetic corpus and achieves state-of-the-art performance on several table-based question answering benchmarks and a table-based fact verification benchmark. More details can be found in the original paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/pdf/2107.07653.pdf).
|
||||
|
||||
> If you are familiar with [fairseq](https://github.com/pytorch/fairseq), you may also find [the official implementation](https://github.com/microsoft/Table-Pretraining), which is built on that framework, useful.
|
||||
|
||||
## Table Question Answering Tasks
|
||||
|
||||
### What is Table Question Answering
|
||||
|
||||

|
||||
|
||||
The task of Table Question Answering (TableQA) is to empower machines to answer users' questions over a given table. The resulting answer(s) can be a region in the table, or a number calculated by applying aggregation operators to a specific region.
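
The two kinds of answers described above can be illustrated with a small, self-contained sketch. It does not show how TAPEX works internally (TAPEX generates the answer as text); the toy table, the helper names `select_cells` and `aggregate`, and the set of operators are all assumptions made for illustration.

```python
# Toy table: every value here is made up for illustration.
TABLE = [
    {"team": "Lions", "won": 3, "drawn": 1},
    {"team": "Tigers", "won": 5, "drawn": 0},
    {"team": "Bears", "won": 2, "drawn": 2},
]


def select_cells(rows, column, keep=lambda row: True):
    """Return the cells of `column` for the rows accepted by `keep` (a table region)."""
    return [row[column] for row in rows if keep(row)]


def aggregate(values, operator):
    """Apply a (hypothetical) aggregation operator to a list of numeric cells."""
    operators = {"sum": sum, "count": len, "max": max, "min": min}
    return operators[operator](values)


# Answer that is a region of the table: which teams won more than two matches?
region_answer = select_cells(TABLE, "team", keep=lambda row: row["won"] > 2)

# Answer computed by aggregation: what is the total number of matches drawn?
number_answer = aggregate(select_cells(TABLE, "drawn"), "sum")

print(region_answer)  # ['Lions', 'Tigers']
print(number_answer)  # 3
```

A TableQA model is expected to produce such answers directly from the question and the table, without being given the selection or aggregation logic.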
|
||||
|
||||
### What Questions Can be Answered
|
||||
|
||||
Thanks to the power of generative models, TAPEX can handle almost all kinds of questions over tables, provided that training data is available. Below are some typical questions and their answers taken from [WikiTableQuestions](https://nlp.stanford.edu/blog/wikitablequestions-a-complex-real-world-question-understanding-dataset).
|
||||
|
||||
| Question | Answer |
| :---: | :---: |
| What is the years won for each team? | 2004, 2008, 2012 |
| How long did Taiki Tsuchiya last? | 4:27 |
| What is the total amount of matches drawn? | 1 |
| Besides Tiger Woods, what other player won between 2007 and 2009? | Camilo Villegas |
| What was the last Baekje Temple? | Uija |
| What is the difference between White voters and Black voters in 1948? | 0 |
| What is the average number of sailors for each country during the worlds qualification tournament? | 2 |
|
||||
|
||||
|
||||
### How to Fine-tune TAPEX on TableQA
|
||||
|
||||
We provide a script for fine-tuning TAPEX on TableQA with the [WikiSQL](https://github.com/salesforce/WikiSQL) benchmark.
The script is customized for TAPEX models and can be easily adapted to other benchmarks such as WikiTableQuestions;
this requires only small changes to the function `preprocess_tableqa_function`.
|
||||
|
||||
#### TAPEX-Base on WikiSQL
|
||||
|
||||
Here is how to run the script on WikiSQL with `tapex-base`:
> The default hyper-parameters should allow you to reproduce our reported tapex-base results on a single GPU with 16GB of memory. If you have more GPUs, you can reduce `gradient_accumulation_steps` accordingly.
|
||||
|
||||
```bash
export EXP_NAME=wikisql_tapex_base

python run_wikisql_with_tapex.py \
  --do_train \
  --do_eval \
  --output_dir $EXP_NAME \
  --model_name_or_path microsoft/tapex-base \
  --overwrite_output_dir \
  --per_device_train_batch_size 4 \
  --gradient_accumulation_steps 8 \
  --per_device_eval_batch_size 4 \
  --learning_rate 3e-5 \
  --logging_steps 10 \
  --eval_steps 1000 \
  --save_steps 1000 \
  --warmup_steps 1000 \
  --evaluation_strategy steps \
  --predict_with_generate \
  --num_beams 5 \
  --weight_decay 1e-2 \
  --label_smoothing_factor 0.1 \
  --max_steps 20000
```
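
The advice about reducing `gradient_accumulation_steps` on more GPUs follows from how the effective batch size is formed: it is the per-device batch size times the number of accumulation steps times the number of devices. A minimal sketch of that arithmetic, using the values from the command above; the helper names `effective_batch_size` and `accumulation_steps_for` are hypothetical and not part of the script:

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_gpus=1):
    """Number of examples that contribute to one optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_gpus


def accumulation_steps_for(target_batch_size, per_device_batch_size, num_gpus):
    """Accumulation steps needed to keep `target_batch_size` on `num_gpus` devices."""
    examples_per_step = per_device_batch_size * num_gpus
    if target_batch_size % examples_per_step != 0:
        raise ValueError("target batch size is not a multiple of per-device batch size x GPUs")
    return target_batch_size // examples_per_step


# Values from the single-GPU command above: 4 examples per device, 8 accumulation steps.
single_gpu = effective_batch_size(4, 8, num_gpus=1)
print(single_gpu)  # 32

# With 4 GPUs, the same effective batch size needs only 2 accumulation steps.
print(accumulation_steps_for(single_gpu, 4, num_gpus=4))  # 2
```

Keeping the effective batch size unchanged is what lets the other hyper-parameters (learning rate, warmup and total steps) stay as they are.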
|
||||
|
||||
#### TAPEX-Large on WikiSQL
|
||||
|
||||
Here is how to run the script on WikiSQL with `tapex-large`:
> The default hyper-parameters should allow you to reproduce our reported tapex-large results on a single GPU with 16GB of memory when using fp16. If you have more GPUs, you can reduce `gradient_accumulation_steps` accordingly. If you have not installed apex or another mixed-precision training library, you can disable the `predict_with_generate` option to save GPU memory and evaluate the model manually once fine-tuning has finished. Alternatively, just pick the last checkpoint, which usually performs well enough on the dataset.
|
||||
|
||||
```bash
export EXP_NAME=wikisql_tapex_large

python run_wikisql_with_tapex.py \
  --do_train \
  --do_eval \
  --output_dir $EXP_NAME \
  --model_name_or_path microsoft/tapex-large \
  --overwrite_output_dir \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 32 \
  --per_device_eval_batch_size 4 \
  --learning_rate 3e-5 \
  --logging_steps 10 \
  --eval_steps 1000 \
  --save_steps 1000 \
  --warmup_steps 1000 \
  --evaluation_strategy steps \
  --predict_with_generate \
  --num_beams 5 \
  --weight_decay 1e-2 \
  --label_smoothing_factor 0.1 \
  --max_steps 20000 \
  --fp16
```
|
||||
|
||||
#### TAPEX-Base on WikiTableQuestions
|
||||
|
||||
Here is how to run the script on WikiTableQuestions with `tapex-base`:
> The default hyper-parameters should allow you to reproduce our reported tapex-base results on a single GPU with 16GB of memory. If you have more GPUs, you can reduce `gradient_accumulation_steps` accordingly.
|
||||
|
||||
```bash
export EXP_NAME=wikitablequestions_tapex_base

python run_wikitablequestions_with_tapex.py \
  --do_train \
  --do_eval \
  --output_dir $EXP_NAME \
  --model_name_or_path microsoft/tapex-base \
  --overwrite_output_dir \
  --per_device_train_batch_size 4 \
  --gradient_accumulation_steps 8 \
  --per_device_eval_batch_size 4 \
  --learning_rate 3e-5 \
  --logging_steps 10 \
  --eval_steps 1000 \
  --save_steps 1000 \
  --warmup_steps 1000 \
  --evaluation_strategy steps \
  --predict_with_generate \
  --num_beams 5 \
  --weight_decay 1e-2 \
  --label_smoothing_factor 0.1 \
  --max_steps 20000
```
|
||||
|
||||
#### TAPEX-Large on WikiTableQuestions
|
||||
|
||||
Here is how to run the script on WikiTableQuestions with `tapex-large`:
> The default hyper-parameters should allow you to reproduce our reported tapex-large results on a single GPU with 16GB of memory when using fp16. If you have more GPUs, you can reduce `gradient_accumulation_steps` accordingly. If you have not installed apex or another mixed-precision training library, you can reduce `per_device_train_batch_size` and `per_device_eval_batch_size` and try again. You can also disable the `predict_with_generate` option to save GPU memory and evaluate the model manually once fine-tuning has finished. Alternatively, just pick the last checkpoint, which usually performs well enough on the dataset.
|
||||
|
||||
```bash
export EXP_NAME=wikitablequestions_tapex_large

python run_wikitablequestions_with_tapex.py \
  --do_train \
  --do_eval \
  --output_dir $EXP_NAME \
  --model_name_or_path microsoft/tapex-large \
  --overwrite_output_dir \
  --per_device_train_batch_size 2 \
  --gradient_accumulation_steps 12 \
  --per_device_eval_batch_size 4 \
  --learning_rate 3e-5 \
  --logging_steps 10 \
  --eval_steps 1000 \
  --save_steps 1000 \
  --warmup_steps 1000 \
  --evaluation_strategy steps \
  --predict_with_generate \
  --num_beams 5 \
  --weight_decay 1e-2 \
  --label_smoothing_factor 0.1 \
  --max_steps 20000 \
  --fp16
```
|
||||
|
||||
### How to Evaluate TAPEX Fine-tuned Models on TableQA
|
||||
|
||||
We provide fine-tuned model weights to reproduce our results. You can evaluate them using the following command:
|
||||
> You can also replace `microsoft/tapex-base-finetuned-wikisql` with a local directory to evaluate your own fine-tuned models. Note that for larger models you should reduce `per_device_eval_batch_size` to fit in GPU memory.
|
||||
|
||||
```bash
export EXP_NAME=wikisql_tapex_base_eval

python run_wikisql_with_tapex.py \
  --do_eval \
  --model_name_or_path microsoft/tapex-base-finetuned-wikisql \
  --output_dir $EXP_NAME \
  --per_device_eval_batch_size 4 \
  --predict_with_generate \
  --num_beams 5
```
|
||||
|
||||
## Table Fact Verification Tasks
|
||||
|
||||
### What is Table Fact Verification
|
||||
|
||||

|
||||
|
||||
The task of Table Fact Verification (TableFV) is to enable machines to judge whether a statement is supported by the facts in a given table. The result is a binary classification: `1` (entailed) or `0` (refused).
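
As a minimal sketch of this binary decision, the snippet below turns per-class scores into a label and computes accuracy over a few examples. The label ids follow the description above; the scores are made up for illustration, and the helpers `predict_label_id` and `accuracy` are hypothetical names, not functions from the fine-tuning script.

```python
# Label ids as described above: 1 = entailed, 0 = refused.
ID2LABEL = {0: "Refused", 1: "Entailed"}


def predict_label_id(logits):
    """Return the index of the highest-scoring class (argmax)."""
    return max(range(len(logits)), key=lambda index: logits[index])


def accuracy(batch_logits, gold_label_ids):
    """Fraction of examples whose predicted label id matches the gold one."""
    if not gold_label_ids:
        raise ValueError("need at least one example")
    predictions = [predict_label_id(logits) for logits in batch_logits]
    correct = sum(pred == gold for pred, gold in zip(predictions, gold_label_ids))
    return correct / len(gold_label_ids)


# Hypothetical classifier scores for three statements (columns: refused, entailed).
batch_logits = [[0.2, 1.5], [2.1, -0.3], [0.4, 0.9]]
gold_label_ids = [1, 0, 0]

print([ID2LABEL[predict_label_id(logits)] for logits in batch_logits])
# ['Entailed', 'Refused', 'Entailed']
print(accuracy(batch_logits, gold_label_ids))  # 2 of 3 predictions are correct
```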
|
||||
|
||||
### How to Fine-tune TAPEX on TableFV
|
||||
|
||||
#### TAPEX-Base on TabFact
|
||||
|
||||
We provide a script for fine-tuning TAPEX on TableFV with the [TabFact](https://github.com/wenhuchen/Table-Fact-Checking) benchmark.
|
||||
|
||||
Here is how to run the script on TabFact with `tapex-base`:
> The default hyper-parameters should allow you to reproduce our reported tapex-base results on a single GPU with 16GB of memory. If you have more GPUs, you can reduce `gradient_accumulation_steps` accordingly. Note that `eval_accumulation_steps` is necessary; without it, predictions accumulate on the GPU during evaluation and memory usage keeps growing.
|
||||
|
||||
```bash
export EXP_NAME=tabfact_tapex_base

python run_tabfact_with_tapex.py \
  --do_train \
  --do_eval \
  --output_dir $EXP_NAME \
  --model_name_or_path microsoft/tapex-base \
  --overwrite_output_dir \
  --per_device_train_batch_size 3 \
  --gradient_accumulation_steps 16 \
  --per_device_eval_batch_size 12 \
  --eval_accumulation_steps 6 \
  --warmup_steps 1000 \
  --logging_steps 10 \
  --learning_rate 3e-5 \
  --eval_steps 1000 \
  --save_steps 1000 \
  --evaluation_strategy steps \
  --weight_decay 1e-2 \
  --max_steps 30000 \
  --max_grad_norm 0.1
```
|
||||
|
||||
#### TAPEX-Large on TabFact
|
||||
|
||||
Here is how to run the script on TabFact with `tapex-large`:
> The default hyper-parameters should allow you to reproduce our reported tapex-large results on a single GPU with 24GB of memory. We could not reduce the memory consumption further, since the model input in TabFact usually contains nearly 1000 tokens. If you have more GPUs, you can reduce `gradient_accumulation_steps` accordingly. Note that `eval_accumulation_steps` is necessary; without it, predictions accumulate on the GPU during evaluation and memory usage keeps growing.
|
||||
|
||||
```bash
export EXP_NAME=tabfact_tapex_large

python run_tabfact_with_tapex.py \
  --do_train \
  --do_eval \
  --output_dir $EXP_NAME \
  --model_name_or_path microsoft/tapex-large \
  --overwrite_output_dir \
  --per_device_train_batch_size 2 \
  --gradient_accumulation_steps 18 \
  --per_device_eval_batch_size 4 \
  --eval_accumulation_steps 12 \
  --warmup_steps 1000 \
  --logging_steps 10 \
  --learning_rate 3e-5 \
  --eval_steps 1000 \
  --save_steps 1000 \
  --evaluation_strategy steps \
  --weight_decay 1e-2 \
  --max_steps 30000 \
  --max_grad_norm 0.1
```
|
||||
|
||||
### How to Evaluate TAPEX Fine-tuned Models on TableFV
|
||||
|
||||
We provide fine-tuned model weights to reproduce our results. You can evaluate them using the following command:
|
||||
> You can also replace `microsoft/tapex-base-finetuned-tabfact` with a local directory to evaluate your own fine-tuned models. Note that for larger models you should reduce `per_device_eval_batch_size` to fit in GPU memory.
|
||||
|
||||
```bash
export EXP_NAME=tabfact_tapex_base_eval

python run_tabfact_with_tapex.py \
  --do_eval \
  --model_name_or_path microsoft/tapex-base-finetuned-tabfact \
  --output_dir $EXP_NAME \
  --per_device_eval_batch_size 12 \
  --eval_accumulation_steps 6
```
|
||||
|
||||
## Reproduced Results
|
||||
|
||||
Running the commands above, we obtain the following results on the dev set of each benchmark:
|
||||
|
||||
| Task | Model Size | Metric | Result |
|:---:|:---:|:---:|:---:|
| WikiSQL (Weak) | Base | Denotation Accuracy | 88.1 |
| WikiSQL (Weak) | Large | Denotation Accuracy | 89.5 |
| WikiTableQuestions | Base | Denotation Accuracy | 47.1 |
| WikiTableQuestions | Large | Denotation Accuracy | 57.2 |
| TabFact | Base | Accuracy | 78.7 |
| TabFact | Large | Accuracy | 83.6 |
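
Denotation accuracy measures how often the predicted answer matches the reference answer, regardless of how the answer was derived. The sketch below is a simplified illustration and not the evaluator used by the scripts: it assumes answers are comma-separated strings and compares them as case-insensitive sets, and the names `normalize` and `denotation_accuracy` are hypothetical.

```python
def normalize(answer):
    """Split a comma-separated answer into a set of lower-cased, stripped items."""
    return frozenset(part.strip().lower() for part in answer.split(",") if part.strip())


def denotation_accuracy(predictions, references):
    """Fraction of predictions whose answer set equals the reference answer set."""
    if not references or len(predictions) != len(references):
        raise ValueError("predictions and references must be non-empty and of equal length")
    hits = sum(normalize(pred) == normalize(ref) for pred, ref in zip(predictions, references))
    return hits / len(references)


# Hypothetical model outputs compared with reference answers.
predictions = ["2004, 2008, 2012", "4:27", "2"]
references = ["2012, 2004, 2008", "4:27", "1"]

# The first two match (item order is ignored), the third does not.
print(denotation_accuracy(predictions, references))
```

Comparing answers as sets ignores item order and duplicates, which is a deliberate simplification here; a real evaluator may treat numbers, dates and duplicates differently.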
|
||||
4
examples/research_projects/tapex/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
numpy
|
||||
datasets
|
||||
pandas
|
||||
nltk
|
||||
459
examples/research_projects/tapex/run_tabfact_with_tapex.py
Normal file
@@ -0,0 +1,459 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The Microsoft and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Fine-tuning the library models for tapex on table-based fact verification tasks.
|
||||
Adapted from script: https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-classification/run_glue.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
BartForSequenceClassification,
|
||||
DataCollatorWithPadding,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
TapexTokenizer,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.17.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
|
||||
Using `HfArgumentParser` we can turn this class
|
||||
into argparse arguments to be able to specify them on
|
||||
the command line.
|
||||
"""
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default="tab_fact", metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default="tab_fact",
|
||||
metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
train_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "A csv or a json file containing the training data."}
|
||||
)
|
||||
validation_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "A csv or a json file containing the validation data."}
|
||||
)
|
||||
test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is not None:
|
||||
pass
|
||||
elif self.train_file is None or self.validation_file is None:
|
||||
raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
|
||||
else:
|
||||
train_extension = self.train_file.split(".")[-1]
|
||||
assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
validation_extension = self.validation_file.split(".")[-1]
|
||||
assert (
|
||||
validation_extension == train_extension
|
||||
), "`validation_file` should have the same extension (csv or json) as `train_file`."
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
||||
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For JSON files, this script will use the `statement` column for the input statement and the `table_text` column for the corresponding table.
|
||||
#
|
||||
# If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
|
||||
# single column. You can easily tweak this behavior (see below)
|
||||
#
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
raw_datasets = load_dataset(
|
||||
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
|
||||
)
|
||||
else:
|
||||
# Loading a dataset from your local files.
|
||||
# CSV/JSON training and evaluation files are needed.
|
||||
data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
|
||||
|
||||
# Get the test dataset: you can provide your own CSV/JSON test file (see below)
|
||||
# when you use `do_predict` without specifying a GLUE benchmark task.
|
||||
if training_args.do_predict:
|
||||
if data_args.test_file is not None:
|
||||
train_extension = data_args.train_file.split(".")[-1]
|
||||
test_extension = data_args.test_file.split(".")[-1]
|
||||
assert (
|
||||
test_extension == train_extension
|
||||
), "`test_file` should have the same extension (csv or json) as `train_file`."
|
||||
data_files["test"] = data_args.test_file
|
||||
else:
|
||||
raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
|
||||
|
||||
for key in data_files.keys():
|
||||
logger.info(f"load a local file for {key}: {data_files[key]}")
|
||||
|
||||
if data_args.train_file.endswith(".csv"):
|
||||
# Loading a dataset from local csv files
|
||||
raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir)
|
||||
else:
|
||||
# Loading a dataset from local json files
|
||||
raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
|
||||
# See more about loading any type of standard or custom dataset at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
# Labels
|
||||
label_list = raw_datasets["train"].features["label"].names
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# load tapex tokenizer
|
||||
tokenizer = TapexTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
add_prefix_space=True,
|
||||
)
|
||||
model = BartForSequenceClassification.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Padding strategy
|
||||
if data_args.pad_to_max_length:
|
||||
padding = "max_length"
|
||||
else:
|
||||
# We will pad later, dynamically at batch creation, to the max sequence length in each batch
|
||||
padding = False
|
||||
|
||||
# Some models have set the order of the labels to use, so let's make sure we do use it.
|
||||
model.config.label2id = {"Refused": 0, "Entailed": 1}
|
||||
model.config.id2label = {0: "Refused", 1: "Entailed"}
|
||||
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warning(
    f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
    f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
def preprocess_tabfact_function(examples):
|
||||
# Tokenize the texts
|
||||
def _convert_table_text_to_pandas(_table_text):
    """Builds a pandas DataFrame from the flattened table text `_table_text`.

    Rows are separated by newlines and cells by `#`; the first row holds the column names.
    An example _table_text can be: round#clubs remaining\nfirst round#156\n
    """
    # Split the text into rows, then each row into its cells.
    _table_content = [_table_row.split("#") for _table_row in _table_text.strip("\n").split("\n")]
    # The first row is the header; the remaining rows are the records.
    _table_pd = pd.DataFrame.from_records(_table_content[1:], columns=_table_content[0])
    return _table_pd
|
||||
|
||||
questions = examples["statement"]
|
||||
tables = list(map(_convert_table_text_to_pandas, examples["table_text"]))
|
||||
result = tokenizer(tables, questions, padding=padding, max_length=max_seq_length, truncation=True)
|
||||
|
||||
result["label"] = examples["label"]
|
||||
return result
|
||||
|
||||
    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_tabfact_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if training_args.do_eval:
        # Only the "validation" split is read below, so that is the split we check for.
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

    if training_args.do_predict or data_args.test_file is not None:
        # Only the "test" split is read below, so that is the split we check for.
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))

    # Log a few random samples from the training set:
    if training_args.do_train:
        for index in random.sample(range(len(train_dataset)), 3):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.argmax(preds, axis=1)
        return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate(eval_dataset=eval_dataset)
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        # Remove the `label` column because it contains -1 and the Trainer won't like that.
        predict_dataset = predict_dataset.remove_columns("label")
        predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
        predictions = np.argmax(predictions, axis=1)

        output_predict_file = os.path.join(training_args.output_dir, "predict_results_tabfact.txt")
        if trainer.is_world_process_zero():
            with open(output_predict_file, "w") as writer:
                logger.info("***** Predict Results *****")
                writer.write("index\tprediction\n")
                for index, item in enumerate(predictions):
                    item = label_list[item]
                    writer.write(f"{index}\t{item}\n")

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
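The `_convert_table_text_to_pandas` helper in the TabFact script splits a `#`-delimited table string into a header row and data rows before building a DataFrame. The sketch below reproduces only that splitting step in plain Python, without pandas, so the expected input format is easy to inspect. The name `split_table_text` is hypothetical and is not part of the script.

```python
from typing import List, Tuple


def split_table_text(table_text: str) -> Tuple[List[str], List[List[str]]]:
    """Split a '#'-delimited table string into (header, rows).

    Hypothetical helper: it mirrors the splitting done in the script,
    but returns plain lists instead of a pandas DataFrame.
    """
    # Drop leading/trailing newlines, then split into lines and cells.
    content = [row.split("#") for row in table_text.strip("\n").split("\n")]
    header, rows = content[0], content[1:]
    return header, rows


header, rows = split_table_text("round#clubs remaining\nfirst round#156\n")
print(header)  # ['round', 'clubs remaining']
print(rows)    # [['first round', '156']]
```

In the script itself, the same header and rows are passed to `pd.DataFrame.from_records`, and the resulting tables are given to the tokenizer together with the statements.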
629
examples/research_projects/tapex/run_wikisql_with_tapex.py
Normal file
@@ -0,0 +1,629 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The Microsoft and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Fine-tuning the library models for tapex on table-based question answering tasks.
Adapted from script: https://github.com/huggingface/transformers/blob/master/examples/pytorch/summarization/run_summarization.py
"""

import logging
import os
import sys
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional

import nltk  # Here to have a nice missing dependency error message early on
import numpy as np
import pandas as pd
from datasets import load_dataset

import transformers
from filelock import FileLock
from transformers import (
    AutoConfig,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TapexTokenizer,
    set_seed,
)
from transformers.file_utils import is_offline_mode
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version
from wikisql_utils import _TYPE_CONVERTER, retrieve_wikisql_query_answer_tapas


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.17.0.dev0")

logger = logging.getLogger(__name__)

try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name. "
            "By default the tokenizer is loaded from `model_name_or_path`."
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use a fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default="wikisql", metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics (denotation accuracy) on "
            "(a jsonlines or csv file)."
        },
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input test data file to evaluate the metrics (denotation accuracy) on "
            "(a jsonlines or csv file)."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to model maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
            "which is used during ``evaluate`` and ``predict``."
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


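`DataTrainingArguments.__post_init__` does two things: it validates the file extensions and it falls back to `max_target_length` when `val_max_target_length` is unset. The sketch below shows the same pattern on a reduced dataclass that uses only the standard library. `DemoArgs` is a hypothetical stand-in, and it raises `ValueError` instead of using `assert`, which is a deliberate substitution rather than what the script does.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DemoArgs:
    """Hypothetical, reduced stand-in for DataTrainingArguments."""

    train_file: Optional[str] = None
    max_target_length: int = 128
    val_max_target_length: Optional[int] = None

    def __post_init__(self):
        # Validate the extension of the optional training file.
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            if extension not in ("csv", "json"):
                raise ValueError("`train_file` should be a csv or a json file.")
        # Fall back to the training target length when no validation length is given.
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


args = DemoArgs(train_file="train.json", max_target_length=64)
print(args.val_max_target_length)  # 64
```

Because `__post_init__` runs right after the generated `__init__`, the fallback is already applied by the time `HfArgumentParser` hands the dataclass to `main()`.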
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For JSON files, this script will use the `question` column for the input question and `table` column for the corresponding table.
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # IMPORTANT: the initial BART model's decoding is penalized by no_repeat_ngram_size, so
    # we disable it here to avoid problematic generation.
    config.no_repeat_ngram_size = 0
    config.max_length = 1024
    config.early_stopping = False

    # Load the TAPEX tokenizer
    tokenizer = TapexTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        add_prefix_space=True,
    )

    # Load the BART-based TAPEX model (default tapex-large)
    model = BartForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    elif training_args.do_eval:
        column_names = datasets["validation"].column_names
    elif training_args.do_predict:
        column_names = datasets["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

    if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
        logger.warning(
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
            f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
        )

    def preprocess_tableqa_function(examples, is_training=False):
        """
        The `is_training` flag indicates whether the answers (the supervision) may be used
        to truncate the table content when it is too long.
        """

        # This function is specific to WikiSQL, since the util function needs the table data structure
        # to retrieve the WikiSQL answer for each question.
        def _convert_table_types(_table):
            """Runs the type converter over the table cells."""
            ret_table = deepcopy(_table)
            types = ret_table["types"]
            ret_table["real_rows"] = ret_table["rows"]
            typed_rows = []
            for row in ret_table["rows"]:
                typed_row = []
                for column, cell_value in enumerate(row):
                    typed_row.append(_TYPE_CONVERTER[types[column]](cell_value))
                typed_rows.append(typed_row)
            ret_table["rows"] = typed_rows
            return ret_table

        questions = [question.lower() for question in examples["question"]]
        example_tables = examples["table"]
        example_sqls = examples["sql"]
        tables = [
            pd.DataFrame.from_records(example_table["rows"], columns=example_table["header"])
            for example_table in example_tables
        ]

        # Use the TAPAS utils to obtain the WikiSQL answer
        answers = []
        for example_sql, example_table in zip(example_sqls, example_tables):
            tapas_table = _convert_table_types(example_table)
            answer_list: List[str] = retrieve_wikisql_query_answer_tapas(tapas_table, example_sql)
            # you can choose other delimiters to split each answer
            answers.append(answer_list)

        # IMPORTANT: answers must not be passed during evaluation. During training they are only used to
        # truncate large tables in the train set!
        if is_training:
            model_inputs = tokenizer(
                table=tables,
                query=questions,
                answer=answers,
                max_length=data_args.max_source_length,
                padding=padding,
                truncation=True,
            )
        else:
            model_inputs = tokenizer(
                table=tables, query=questions, max_length=data_args.max_source_length, padding=padding, truncation=True
            )

        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                answer=[", ".join(answer) for answer in answers],
                max_length=max_target_length,
                padding=padding,
                truncation=True,
            )

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]

        return model_inputs

    # in training, we can use the answer as extra information to truncate large tables
    preprocess_tableqa_function_training = partial(preprocess_tableqa_function, is_training=True)

    if training_args.do_train:
        if "train" not in datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_tableqa_function_training,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = datasets["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
        eval_dataset = eval_dataset.map(
            preprocess_tableqa_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = datasets["test"]
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
        predict_dataset = predict_dataset.map(
            preprocess_tableqa_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8 if training_args.fp16 else None,
    )

    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        return preds, labels

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        delimiter = ", "

        # define example evaluation
        def evaluate_example(predict_str: str, ground_str: str):
            predict_spans = predict_str.split(delimiter)
            ground_spans = ground_str.split(delimiter)
            predict_values = defaultdict(lambda: 0)
            ground_values = defaultdict(lambda: 0)
            for span in predict_spans:
                try:
                    predict_values[float(span)] += 1
                except ValueError:
                    predict_values[span.strip()] += 1
            for span in ground_spans:
                try:
                    ground_values[float(span)] += 1
                except ValueError:
                    ground_values[span.strip()] += 1
            is_correct = predict_values == ground_values
            return is_correct

        def get_denotation_accuracy(predictions: List[str], references: List[str]):
            assert len(predictions) == len(references)
            correct_num = 0
            for predict_str, ground_str in zip(predictions, references):
                is_correct = evaluate_example(predict_str.lower(), ground_str.lower())
                if is_correct:
                    correct_num += 1
            return correct_num / len(predictions)

        accuracy = get_denotation_accuracy(decoded_preds, decoded_labels)
        result = {"denotation_accuracy": accuracy}

        return result

    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
    )

    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate(
            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
        )
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        predict_results = trainer.predict(
            predict_dataset,
            metric_key_prefix="predict",
            max_length=data_args.val_max_target_length,
            num_beams=data_args.num_beams,
        )
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = tokenizer.batch_decode(
                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                output_prediction_file = os.path.join(training_args.output_dir, "tapex_predictions.txt")
                with open(output_prediction_file, "w") as writer:
                    writer.write("\n".join(predictions))

    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
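The `compute_metrics` function in the WikiSQL script scores a prediction as correct when its comma-separated answer spans form the same multiset as the reference, with numeric spans compared by value. The sketch below isolates that idea with the standard library only. It uses `collections.Counter` where the script uses `defaultdict`, and the names `to_multiset` and `denotation_accuracy` are hypothetical.

```python
from collections import Counter
from typing import List


def to_multiset(answer: str, delimiter: str = ", ") -> Counter:
    """Turn a delimited answer string into a multiset of values."""
    values = Counter()
    for span in answer.split(delimiter):
        try:
            values[float(span)] += 1  # numeric spans compare by value
        except ValueError:
            values[span.strip()] += 1  # everything else compares as text
    return values


def denotation_accuracy(predictions: List[str], references: List[str]) -> float:
    """Fraction of predictions whose answer multiset equals the reference."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    correct = sum(
        to_multiset(pred.lower()) == to_multiset(ref.lower())
        for pred, ref in zip(predictions, references)
    )
    return correct / len(predictions)


print(denotation_accuracy(["2.0, Paris", "rome"], ["paris, 2", "Milan"]))  # 0.5
```

The first pair matches even though the order and number formatting differ, because both sides reduce to the multiset `{2.0, "paris"}`; the second pair does not match.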
@@ -0,0 +1,605 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The Microsoft and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Fine-tuning the library models for tapex on table-based question answering tasks.
Adapted from script: https://github.com/huggingface/transformers/blob/master/examples/pytorch/summarization/run_summarization.py
"""

import logging
import os
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional

import nltk  # Here to have a nice missing dependency error message early on
import numpy as np
import pandas as pd
from datasets import load_dataset

import transformers
from filelock import FileLock
from transformers import (
    AutoConfig,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TapexTokenizer,
    set_seed,
)
from transformers.file_utils import is_offline_mode
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.17.0.dev0")

logger = logging.getLogger(__name__)

try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name. "
            "By default we use BART-large tokenizer for TAPEX-large."
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default="wikitablequestions", metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics (denotation accuracy) on "
            "(a jsonlines or csv file)."
        },
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input test data file to evaluate the metrics (denotation accuracy) on "
            "(a jsonlines or csv file)."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to model maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
            "which is used during ``evaluate`` and ``predict``."
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For JSON files, this script will use the `question` column for the input question and `table` column for the corresponding table.
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # IMPORTANT: the initial BART model's decoding is penalized by no_repeat_ngram_size, and thus
    # we should disable it here to avoid problematic generation
    config.no_repeat_ngram_size = 0
    config.max_length = 1024
    config.early_stopping = False

    # load tapex tokenizer
    tokenizer = TapexTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        add_prefix_space=True,
    )

    # load Bart based Tapex model (default tapex-large)
    model = BartForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    elif training_args.do_eval:
        column_names = datasets["validation"].column_names
    elif training_args.do_predict:
        column_names = datasets["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

    if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
        logger.warning(
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
            f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
        )

    def preprocess_tableqa_function(examples, is_training=False):
        """
        The `is_training` flag indicates whether the answers (the supervision) may be used
        to truncate the table content when truncation is required.
        """

        questions = [question.lower() for question in examples["question"]]
        example_tables = examples["table"]
        tables = [
            pd.DataFrame.from_records(example_table["rows"], columns=example_table["header"])
            for example_table in example_tables
        ]

        # using wikitablequestion's answer set
        answers = examples["answers"]

        # IMPORTANT: we cannot pass answers during evaluation; answers passed during training are used to
        # truncate large tables in the train set!
        if is_training:
            model_inputs = tokenizer(
                table=tables,
                query=questions,
                answer=answers,
                max_length=data_args.max_source_length,
                padding=padding,
                truncation=True,
            )
        else:
            model_inputs = tokenizer(
                table=tables, query=questions, max_length=data_args.max_source_length, padding=padding, truncation=True
            )

        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                answer=[", ".join(answer) for answer in answers],
                max_length=max_target_length,
                padding=padding,
                truncation=True,
            )

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]

        return model_inputs

    # in training, we can use the answer as extra information to truncate large tables
    preprocess_tableqa_function_training = partial(preprocess_tableqa_function, is_training=True)

    if training_args.do_train:
        if "train" not in datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_tableqa_function_training,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = datasets["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
        eval_dataset = eval_dataset.map(
            preprocess_tableqa_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = datasets["test"]
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
        predict_dataset = predict_dataset.map(
            preprocess_tableqa_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8 if training_args.fp16 else None,
    )

    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        return preds, labels

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        delimiter = ", "

        # define example evaluation
        def evaluate_example(predict_str: str, ground_str: str):
            predict_spans = predict_str.split(delimiter)
            ground_spans = ground_str.split(delimiter)
            predict_values = defaultdict(lambda: 0)
            ground_values = defaultdict(lambda: 0)
            for span in predict_spans:
                try:
                    predict_values[float(span)] += 1
                except ValueError:
                    predict_values[span.strip()] += 1
            for span in ground_spans:
                try:
                    ground_values[float(span)] += 1
                except ValueError:
                    ground_values[span.strip()] += 1
            _is_correct = predict_values == ground_values
            return _is_correct

        def get_denotation_accuracy(predictions: List[str], references: List[str]):
            assert len(predictions) == len(references)
            correct_num = 0
            for predict_str, ground_str in zip(predictions, references):
                is_correct = evaluate_example(predict_str.lower(), ground_str.lower())
                if is_correct:
                    correct_num += 1
            return correct_num / len(predictions)

        accuracy = get_denotation_accuracy(decoded_preds, decoded_labels)
        result = {"denotation_accuracy": accuracy}

        return result

    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
    )

    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate(
            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
        )
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        predict_results = trainer.predict(
            predict_dataset,
            metric_key_prefix="predict",
            max_length=data_args.val_max_target_length,
            num_beams=data_args.num_beams,
        )
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = tokenizer.batch_decode(
                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                output_prediction_file = os.path.join(training_args.output_dir, "tapex_predictions.txt")
                with open(output_prediction_file, "w") as writer:
                    writer.write("\n".join(predictions))

    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
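
The `compute_metrics` function in this script scores predictions by denotation accuracy: each answer string is split on `", "`, numeric spans are compared as floats, and a prediction counts as correct only when its multiset of values equals the reference multiset. A minimal standalone sketch of that comparison follows. It uses `collections.Counter` in place of the script's `defaultdict`, and the helper names `to_multiset` and `denotation_accuracy` are illustrative, not part of the script.

```python
from collections import Counter
from typing import List


def to_multiset(answer: str, delimiter: str = ", ") -> Counter:
    """Split an answer string into spans; numeric spans are keyed by their float value."""
    values = Counter()
    for span in answer.split(delimiter):
        try:
            values[float(span)] += 1
        except ValueError:
            values[span.strip()] += 1
    return values


def denotation_accuracy(predictions: List[str], references: List[str]) -> float:
    """Fraction of predictions whose multiset of answers equals the reference multiset."""
    assert len(predictions) == len(references)
    correct = sum(
        to_multiset(pred.lower()) == to_multiset(ref.lower())
        for pred, ref in zip(predictions, references)
    )
    return correct / len(predictions)


if __name__ == "__main__":
    preds = ["1.0, paris", "London", "3"]
    refs = ["Paris, 1", "london", "4"]
    print(denotation_accuracy(preds, refs))
```

Because the comparison is on multisets, the order of the answers does not matter, and `"1.0"` matches `"1"` since both parse to the same float.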
259
examples/research_projects/tapex/wikisql_utils.py
Normal file
@@ -0,0 +1,259 @@
# coding=utf-8
# Copyright 2022 The Microsoft, The Google and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import enum
import functools
import math
import re

# The following script is adapted from the script of TaPas.
# Original: https://github.com/google-research/tapas/master/wikisql_utils.py
from typing import Any, List, Text

import six


EMPTY_ANSWER = "none"
EMPTY_ANSWER_AGG = "none"


def _split_thousands(delimiter, value):
    split = value.split(delimiter)
    return len(split) > 1 and any(map(lambda x: len(x) == 3, split))


def convert_to_float(value):
    """Converts value to a float using a series of increasingly complex heuristics.
    Args:
      value: object that needs to be converted. Allowed types include
        float/int/strings.
    Returns:
      A float interpretation of value.
    Raises:
      ValueError if the float conversion of value fails.
    """
    if isinstance(value, float):
        return value
    if isinstance(value, int):
        return float(value)
    if not isinstance(value, six.string_types):
        raise ValueError("Argument value is not a string. Can't parse it as float")
    sanitized = value

    try:
        # Example: 1,000.7
        if "." in sanitized and "," in sanitized:
            return float(sanitized.replace(",", ""))
        # 1,000
        if "," in sanitized and _split_thousands(",", sanitized):
            return float(sanitized.replace(",", ""))
        # 5,5556
        if "," in sanitized and sanitized.count(",") == 1 and not _split_thousands(",", sanitized):
            return float(sanitized.replace(",", "."))
        # 0.0.0.1
        if sanitized.count(".") > 1:
            return float(sanitized.replace(".", ""))
        # 0,0,0,1
        if sanitized.count(",") > 1:
            return float(sanitized.replace(",", ""))
        return float(sanitized)
    except ValueError:
        # Avoid adding the sanitized value in the error message.
        raise ValueError("Unable to convert value to float")


def _normalize_float(answer):
    if answer is None:
        return None
    try:
        value = convert_to_float(answer)
        if isinstance(value, float) and math.isnan(value):
            return None
        return value
    except ValueError:
        return answer.lower()


_TYPE_CONVERTER = {
    "text": lambda x: x,
    "real": convert_to_float,
}


class _Aggregation(enum.Enum):
    """Aggregations as defined by WikiSQL. Indexes match the data."""

    NONE = 0
    MAX = 1
    MIN = 2
    COUNT = 3
    SUM = 4
    AVERAGE = 5


class _Operator(enum.Enum):
    """The boolean operators used by WikiSQL. Indexes match the data."""

    EQUALS = 0
    GREATER = 1
    LESSER = 2


@dataclasses.dataclass
class _Condition:
    """Represents a SQL where clause (e.g. A = "a" or B > 5)."""

    column: Text
    operator: _Operator
    cmp_value: Any


_TOKENIZER = re.compile(r"\w+|[^\w\s]+", re.UNICODE | re.MULTILINE | re.DOTALL)


def _normalize_for_match(x):
    return [t for t in _TOKENIZER.findall(x.lower())]


def _compare(operator, src, tgt):
    if operator == _Operator.EQUALS:
        return src == tgt
    elif operator == _Operator.GREATER:
        return src > tgt
    elif operator == _Operator.LESSER:
        return src < tgt
    raise ValueError(f"Unknown operator: {operator}")


def _parse_value(table, column, cell_value):
    """Converts numeric values to floats and keeps everything else as a string."""
    types = table["types"]
    return _TYPE_CONVERTER[types[column]](cell_value)


def _is_string(x):
    return isinstance(x, str)


def _respect_conditions(table, row, conditions):
    """True if 'row' satisfies all 'conditions'."""
    for cond in conditions:
        table_value = row[cond.column]

        cmp_value = _parse_value(table, cond.column, cond.cmp_value)

        if _is_string(table_value) and _is_string(cmp_value):
            table_value = _normalize_for_match(table_value)
            cmp_value = _normalize_for_match(cmp_value)

        if not isinstance(table_value, type(cmp_value)):
            raise ValueError("Type difference {} != {}".format(type(table_value), type(cmp_value)))

        if not _compare(cond.operator, table_value, cmp_value):
            return False
    return True


def _get_float_answer(table, answer_coordinates, aggregation_op):
    """Applies operation to produce reference float answer."""
    if not answer_coordinates:
        if aggregation_op == _Aggregation.COUNT:
            return 0.0
        else:
            return EMPTY_ANSWER_AGG

    # Count can support non numeric answers.
    if aggregation_op == _Aggregation.COUNT:
        return float(len(answer_coordinates))

    # If there is just one answer, return it if it is a float, otherwise try a conversion.
    values = [table["rows"][i][j] for (i, j) in answer_coordinates]
    if len(answer_coordinates) == 1:
        try:
            return convert_to_float(values[0])
        except ValueError as e:
            if aggregation_op != _Aggregation.NONE:
                raise e

    if aggregation_op == _Aggregation.NONE:
        return None

    # Other aggregations only support numeric values. Bail out if we have strings.
    if not all((isinstance(v, (int, float)) for v in values)):
        return None

    if aggregation_op == _Aggregation.SUM:
        return float(sum(values))
    elif aggregation_op == _Aggregation.AVERAGE:
        return sum(values) / len(answer_coordinates)
    else:
        raise ValueError(f"Unknown aggregation: {aggregation_op}")


def _get_answer_coordinates(table, sql_query):
    """Retrieves reference coordinates by executing SQL."""
    # MAX and MIN are automatically supported by the model.
    aggregation_op_index = sql_query["agg"]
    if aggregation_op_index >= 3:
        aggregation_op = _Aggregation(aggregation_op_index)
    else:
        aggregation_op = _Aggregation.NONE

    target_column = sql_query["sel"]
    conditions = [
        _Condition(column, _Operator(operator), cmp_value)
        for column, operator, cmp_value in zip(
            sql_query["conds"]["column_index"], sql_query["conds"]["operator_index"], sql_query["conds"]["condition"]
        )
    ]

    indices = []
    for row in range(len(table["rows"])):
        if _respect_conditions(table, table["rows"][row], conditions):
            indices.append((row, target_column))

    if not indices:
        return [], aggregation_op

    if len(indices) == 1:
        return indices, aggregation_op

    # Parsing of MIN/MAX.
    if aggregation_op_index in (1, 2):
        operators = {2: min, 1: max}
        values = [(table["rows"][i][j], index) for index, (i, j) in enumerate(indices)]
        reduced = functools.reduce(operators[sql_query["agg"]], values)

        ret = [indices[reduced[1]]]
        return ret, _Aggregation.NONE

    return indices, aggregation_op


def _get_answer_text(table, answer_coordinates, float_answer):
    if float_answer is not None:
        return [str(float_answer)]
    return [str(table["real_rows"][r][c]) for r, c in answer_coordinates]


def retrieve_wikisql_query_answer_tapas(table, example) -> List:
    answer_coordinates, aggregation_op = _get_answer_coordinates(table, example)
    float_answer = _get_float_answer(table, answer_coordinates, aggregation_op)
    answer_text = _get_answer_text(table, answer_coordinates, float_answer)
    # keep the data the same as in TaPas
    if len(answer_text) == 0:
        answer_text = [EMPTY_ANSWER]
    return answer_text
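
The MIN/MAX branch of `_get_answer_coordinates` pairs each matching cell value with its position in `indices` and reduces the pairs with `min` or `max`, so the winning tuple carries the position of the selected coordinate. A small standalone sketch of that idea follows. The function name `select_extreme` and the toy table are assumptions made for illustration and do not appear in this file.

```python
import functools
from typing import List, Tuple


def select_extreme(rows: List[list], indices: List[Tuple[int, int]], use_max: bool) -> Tuple[int, int]:
    """Return the (row, column) coordinate whose cell holds the largest or smallest value."""
    # Pair each cell value with its position in `indices`; tuples compare by value first.
    values = [(rows[i][j], position) for position, (i, j) in enumerate(indices)]
    reduced = functools.reduce(max if use_max else min, values)
    # The second tuple element is the position of the winning coordinate.
    return indices[reduced[1]]


if __name__ == "__main__":
    rows = [["a", 3.0], ["b", 7.5], ["c", 1.2]]
    matching = [(0, 1), (1, 1), (2, 1)]
    print(select_extreme(rows, matching, use_max=True))
    print(select_extreme(rows, matching, use_max=False))
```

When two cells tie on value, the tuple comparison falls back to the position, so the maximum keeps the later match and the minimum keeps the earlier one.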