Merge branch 'master' into fix-xlnet-squad2.0
This commit is contained in:
@@ -3,6 +3,17 @@
|
||||
In this section a few examples are put together. All of these examples work for several models, making use of the very
|
||||
similar API between the different models.
|
||||
|
||||
**Important**
|
||||
To run the latest versions of the examples, you have to install from source and install some specific requirements for the examples.
|
||||
Execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/transformers
|
||||
cd transformers
|
||||
pip install [--editable] .
|
||||
pip install -r ./examples/requirements.txt
|
||||
```
|
||||
|
||||
| Section | Description |
|
||||
|----------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [TensorFlow 2.0 models on GLUE](#TensorFlow-2.0-Bert-models-on-GLUE) | Examples running BERT TensorFlow 2.0 model on the GLUE tasks.
|
||||
@@ -12,7 +23,7 @@ similar API between the different models.
|
||||
| [SQuAD](#squad) | Using BERT/RoBERTa/XLNet/XLM for question answering, examples with distributed training. |
|
||||
| [Multiple Choice](#multiple-choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks.
|
||||
| [Named Entity Recognition](#named-entity-recognition) | Using BERT for Named Entity Recognition (NER) on the CoNLL 2003 dataset, examples with distributed training. |
|
||||
| [Abstractive summarization](#abstractive-summarization) | Fine-tuning the library models for abstractive summarization tasks on the CNN/Daily Mail dataset. |
|
||||
| [XNLI](#xnli) | Examples running BERT/XLM on the XNLI benchmark. |
|
||||
|
||||
## TensorFlow 2.0 Bert models on GLUE
|
||||
|
||||
@@ -506,7 +517,8 @@ Larger batch size may improve the performance while costing more memory.
|
||||
|
||||
## Named Entity Recognition
|
||||
|
||||
Based on the script [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py).
|
||||
Based on the scripts [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py) for Pytorch and
|
||||
[`run_tf_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_tf_ner.py) for Tensorflow 2.
|
||||
This example fine-tune Bert Multilingual on GermEval 2014 (German NER).
|
||||
Details and results for the fine-tuning provided by @stefan-it.
|
||||
|
||||
@@ -551,7 +563,7 @@ The GermEval 2014 dataset has much more labels than CoNLL-2002/2003 datasets, so
|
||||
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt
|
||||
```
|
||||
|
||||
### Training
|
||||
### Prepare the run
|
||||
|
||||
Additional environment variables must be set:
|
||||
|
||||
@@ -563,6 +575,8 @@ export SAVE_STEPS=750
|
||||
export SEED=1
|
||||
```
|
||||
|
||||
### Run the Pytorch version
|
||||
|
||||
To start training, just run:
|
||||
|
||||
```bash
|
||||
@@ -583,7 +597,7 @@ python3 run_ner.py --data_dir ./ \
|
||||
|
||||
If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets.
|
||||
|
||||
### Evaluation
|
||||
#### Evaluation
|
||||
|
||||
Evaluation on development dataset outputs the following for our example:
|
||||
|
||||
@@ -605,7 +619,7 @@ On the test dataset the following results could be achieved:
|
||||
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
|
||||
```
|
||||
|
||||
### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)
|
||||
#### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)
|
||||
|
||||
Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run):
|
||||
|
||||
@@ -615,30 +629,108 @@ Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) a
|
||||
| `roberta-large` | 95.96 | 91.87
|
||||
| `distilbert-base-uncased` | 94.34 | 90.32
|
||||
|
||||
## Abstractive summarization
|
||||
### Run the Tensorflow 2 version
|
||||
|
||||
Based on the script
|
||||
[`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py).
|
||||
|
||||
Before running this script you should download **both** CNN and Daily Mail
|
||||
datasets from [Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the
|
||||
links next to "Stories") in the same folder. Then uncompress the archives by running:
|
||||
To start training, just run:
|
||||
|
||||
```bash
|
||||
tar -xvf cnn_stories.tgz && tar -xvf dailymail_stories.tgz
|
||||
python3 run_tf_ner.py --data_dir ./ \
|
||||
--model_type bert \
|
||||
--labels ./labels.txt \
|
||||
--model_name_or_path $BERT_MODEL \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_seq_length $MAX_LENGTH \
|
||||
--num_train_epochs $NUM_EPOCHS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--save_steps $SAVE_STEPS \
|
||||
--seed $SEED \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_predict
|
||||
```
|
||||
|
||||
note that the finetuning script **will not work** if you do not download both
|
||||
datasets. We will refer as `$DATA_PATH` the path to where you uncompressed both
|
||||
archive.
|
||||
Such as the Pytorch version, if your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets.
|
||||
|
||||
#### Evaluation
|
||||
|
||||
Evaluation on development dataset outputs the following for our example:
|
||||
```bash
|
||||
precision recall f1-score support
|
||||
|
||||
LOCderiv 0.7619 0.6154 0.6809 52
|
||||
PERpart 0.8724 0.8997 0.8858 4057
|
||||
OTHpart 0.9360 0.9466 0.9413 711
|
||||
ORGpart 0.7015 0.6989 0.7002 269
|
||||
LOCpart 0.7668 0.8488 0.8057 496
|
||||
LOC 0.8745 0.9191 0.8963 235
|
||||
ORGderiv 0.7723 0.8571 0.8125 91
|
||||
OTHderiv 0.4800 0.6667 0.5581 18
|
||||
OTH 0.5789 0.6875 0.6286 16
|
||||
PERderiv 0.5385 0.3889 0.4516 18
|
||||
PER 0.5000 0.5000 0.5000 2
|
||||
ORG 0.0000 0.0000 0.0000 3
|
||||
|
||||
micro avg 0.8574 0.8862 0.8715 5968
|
||||
macro avg 0.8575 0.8862 0.8713 5968
|
||||
```
|
||||
|
||||
On the test dataset the following results could be achieved:
|
||||
```bash
|
||||
precision recall f1-score support
|
||||
|
||||
PERpart 0.8847 0.8944 0.8896 9397
|
||||
OTHpart 0.9376 0.9353 0.9365 1639
|
||||
ORGpart 0.7307 0.7044 0.7173 697
|
||||
LOC 0.9133 0.9394 0.9262 561
|
||||
LOCpart 0.8058 0.8157 0.8107 1150
|
||||
ORG 0.0000 0.0000 0.0000 8
|
||||
OTHderiv 0.5882 0.4762 0.5263 42
|
||||
PERderiv 0.6571 0.5227 0.5823 44
|
||||
OTH 0.4906 0.6667 0.5652 39
|
||||
ORGderiv 0.7016 0.7791 0.7383 172
|
||||
LOCderiv 0.8256 0.6514 0.7282 109
|
||||
PER 0.0000 0.0000 0.0000 11
|
||||
|
||||
micro avg 0.8722 0.8774 0.8748 13869
|
||||
macro avg 0.8712 0.8774 0.8740 13869
|
||||
```
|
||||
|
||||
## XNLI
|
||||
|
||||
Based on the script [`run_xnli.py`](https://github.com/huggingface/transformers/blob/master/examples/run_xnli.py).
|
||||
|
||||
[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-ressource language such as English and low-ressource languages such as Swahili).
|
||||
|
||||
#### Fine-tuning on XNLI
|
||||
|
||||
This example code fine-tunes mBERT (multi-lingual BERT) on the XNLI dataset. It runs in 106 mins
|
||||
on a single tesla V100 16GB. The data for XNLI can be downloaded with the following links and should be both saved (and un-zipped) in a
|
||||
`$XNLI_DIR` directory.
|
||||
|
||||
* [XNLI 1.0](https://www.nyu.edu/projects/bowman/xnli/XNLI-1.0.zip)
|
||||
* [XNLI-MT 1.0](https://www.nyu.edu/projects/bowman/xnli/XNLI-MT-1.0.zip)
|
||||
|
||||
```bash
|
||||
export DATA_PATH=/path/to/dataset/
|
||||
export XNLI_DIR=/path/to/XNLI
|
||||
|
||||
python run_summarization_finetuning.py \
|
||||
--output_dir=output \
|
||||
--model_type=bert2bert \
|
||||
--model_name_or_path=bert2bert \
|
||||
--do_train \
|
||||
--data_path=$DATA_PATH \
|
||||
python run_xnli.py \
|
||||
--model_type bert \
|
||||
--model_name_or_path bert-base-multilingual-cased \
|
||||
--language de \
|
||||
--train_language en \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--data_dir $XNLI_DIR \
|
||||
--per_gpu_train_batch_size 32 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 2.0 \
|
||||
--max_seq_length 128 \
|
||||
--output_dir /tmp/debug_xnli/ \
|
||||
--save_steps -1
|
||||
```
|
||||
|
||||
Training with the previously defined hyper-parameters yields the following results on the **test** set:
|
||||
|
||||
```bash
|
||||
acc = 0.7093812375249501
|
||||
```
|
||||
|
||||
48
examples/contrib/run_camembert.py
Normal file
48
examples/contrib/run_camembert.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from pathlib import Path
|
||||
import tarfile
|
||||
import urllib.request
|
||||
|
||||
import torch
|
||||
|
||||
from transformers.tokenization_camembert import CamembertTokenizer
|
||||
from transformers.modeling_camembert import CamembertForMaskedLM
|
||||
|
||||
|
||||
def fill_mask(masked_input, model, tokenizer, topk=5):
|
||||
# Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
|
||||
assert masked_input.count('<mask>') == 1
|
||||
input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
logits = model(input_ids)[0] # The last hidden-state is the first element of the output tuple
|
||||
masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
|
||||
logits = logits[0, masked_index, :]
|
||||
prob = logits.softmax(dim=0)
|
||||
values, indices = prob.topk(k=topk, dim=0)
|
||||
topk_predicted_token_bpe = ' '.join([tokenizer.convert_ids_to_tokens(indices[i].item())
|
||||
for i in range(len(indices))])
|
||||
masked_token = tokenizer.mask_token
|
||||
topk_filled_outputs = []
|
||||
for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')):
|
||||
predicted_token = predicted_token_bpe.replace('\u2581', ' ')
|
||||
if " {0}".format(masked_token) in masked_input:
|
||||
topk_filled_outputs.append((
|
||||
masked_input.replace(
|
||||
' {0}'.format(masked_token), predicted_token
|
||||
),
|
||||
values[index].item(),
|
||||
predicted_token,
|
||||
))
|
||||
else:
|
||||
topk_filled_outputs.append((
|
||||
masked_input.replace(masked_token, predicted_token),
|
||||
values[index].item(),
|
||||
predicted_token,
|
||||
))
|
||||
return topk_filled_outputs
|
||||
|
||||
|
||||
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
|
||||
model = CamembertForMaskedLM.from_pretrained('camembert-base')
|
||||
model.eval()
|
||||
|
||||
masked_input = "Le camembert est <mask> :)"
|
||||
print(fill_mask(masked_input, model, tokenizer, topk=3))
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
This folder contains the original code used to train Distil* as well as examples showcasing how to use DistilBERT, DistilRoBERTa and DistilGPT2.
|
||||
|
||||
**December 6th, 2019 - Update** We release **DistilmBERT**: 92% of `bert-base-multilingual-cased` on XNLI. The model supports 104 different languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages).
|
||||
|
||||
**November 19th, 2019 - Update** We release German **DistilBERT**: 98.8% of `bert-base-german-dbmdz-cased` on NER tasks.
|
||||
|
||||
**October 23rd, 2019 - Update** We release **DistilRoBERTa**: 95% of `RoBERTa-base`'s performance on GLUE, twice as fast as RoBERTa while being 35% smaller.
|
||||
|
||||
**October 3rd, 2019 - Update** We release our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108) explaining our approach on **DistilBERT**. It includes updated results and further experiments. We applied the same method to GPT2 and release the weights of **DistilGPT2**. DistilGPT2 is two times faster and 33% smaller than GPT2. **The paper superseeds our [previous blogpost](https://medium.com/huggingface/distilbert-8cf3380435b5) with a different distillation loss and better performances. Please use the paper as a reference when comparing/reporting results on DistilBERT.**
|
||||
@@ -15,8 +19,9 @@ Distil* is a class of compressed models that started with DistilBERT. DistilBERT
|
||||
|
||||
We have applied the same method to other Transformer architectures and released the weights:
|
||||
- GPT2: on the [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) benchmark, GPT2 reaches a perplexity on the test set of 15.0 compared to 18.5 for **DistilGPT2** (after fine-tuning on the train set).
|
||||
- RoBERTa: **DistilRoBERTa** reaches 95% of `RoBERTa-base` performance on GLUE while being twice faster and 35% smaller.
|
||||
- and more to come! 🤗🤗🤗
|
||||
- RoBERTa: **DistilRoBERTa** reaches 95% of `RoBERTa-base`'s performance on GLUE while being twice faster and 35% smaller.
|
||||
- German BERT: **German DistilBERT** reaches 99% of `bert-base-german-dbmdz-cased`'s performance on German NER (CoNLL-2003).
|
||||
- Multilingual BERT: **DistilmBERT** reaches 92% of Multilingual BERT's performance on XNLI while being twice faster and 25% smaller. The model supports 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages).
|
||||
|
||||
For more information on DistilBERT, please refer to our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108).
|
||||
|
||||
@@ -27,7 +32,7 @@ Here are the results on the dev sets of GLUE:
|
||||
| BERT-base | **77.6** | 48.9 | 84.3 | 88.6 | 89.3 | 89.5 | 71.3 | 91.7 | 91.2 | 43.7 |
|
||||
| DistilBERT | **76.8** | 49.1 | 81.8 | 90.2 | 90.2 | 89.2 | 62.9 | 92.7 | 90.7 | 44.4 |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| RoBERTa-base (reported) | **83.2**/**86.4**<sup>2</sup> | 63.6 | 87.6 | 90.2 | 92.8 | 91.9 | 78.7 | 94.8 | 91.2 | 57.7<sup>3</sup> |
|
||||
| RoBERTa-base (reported) | **83.2**/**86.4**<sup>2</sup> | 63.6 | 87.6 | 90.2 | 92.8 | 91.9 | 78.7 | 94.8 | 91.2 | 57.7<sup>3</sup> |
|
||||
| DistilRoBERTa<sup>1</sup> | **79.0**/**82.3**<sup>2</sup> | 59.4 | 83.9 | 86.6 | 90.8 | 89.4 | 67.9 | 92.5 | 88.3 | 52.1 |
|
||||
|
||||
<sup>1</sup> We did not use the MNLI checkpoint for fine-tuning but directy perform transfer learning on the pre-trained DistilRoBERTa.
|
||||
@@ -36,6 +41,14 @@ Here are the results on the dev sets of GLUE:
|
||||
|
||||
<sup>3</sup> We compute this score ourselves for completeness.
|
||||
|
||||
Here are the results on the *test* sets for 6 of the languages available in XNLI. The results are computed in the zero shot setting (trained on the English portion and evaluated on the target language portion):
|
||||
|
||||
| Model | English | Spanish | Chinese | German | Arabic | Urdu |
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
|
||||
| mBERT base cased (computed) | 82.1 | 74.6 | 69.1 | 72.3 | 66.4 | 58.5 |
|
||||
| mBERT base uncased (reported)| 81.4 | 74.3 | 63.8 | 70.5 | 62.1 | 58.3 |
|
||||
| DistilmBERT | 78.2 | 69.1 | 64.0 | 66.3 | 59.1 | 54.7 |
|
||||
|
||||
## Setup
|
||||
|
||||
This part of the library has only be tested with Python3.6+. There are few specific dependencies to install before launching a distillation, you can install them with the command `pip install -r requirements.txt`.
|
||||
@@ -45,13 +58,14 @@ This part of the library has only be tested with Python3.6+. There are few speci
|
||||
|
||||
## How to use DistilBERT
|
||||
|
||||
Transformers includes two pre-trained Distil* models, currently only provided for English (we are investigating the possibility to train and release a multilingual version of DistilBERT):
|
||||
Transformers includes five pre-trained Distil* models, currently only provided for English and German (we are investigating the possibility to train and release a multilingual version of DistilBERT):
|
||||
|
||||
- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain Bert (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of Bert. The model has 6 layers, 768 dimension and 12 heads, totalizing 66M parameters.
|
||||
- `distilbert-base-uncased-distilled-squad`: A finetuned version of `distilbert-base-uncased` finetuned using (a second step of) knwoledge distillation on SQuAD 1.0. This model reaches a F1 score of 86.9 on the dev set (for comparison, Bert `bert-base-uncased` version reaches a 88.5 F1 score).
|
||||
- `distilbert-base-german-cased`: DistilBERT German language model pretrained on 1/2 of the data used to pretrain Bert using distillation with the supervision of the `bert-base-german-dbmdz-cased` version of German DBMDZ Bert. For NER tasks the model reaches a F1 score of 83.49 on the CoNLL-2003 test set (for comparison, `bert-base-german-dbmdz-cased` reaches a 84.52 F1 score), and a F1 score of 85.23 on the GermEval 2014 test set (`bert-base-german-dbmdz-cased` reaches a 86.89 F1 score).
|
||||
- `distilgpt2`: DistilGPT2 English language model pretrained with the supervision of `gpt2` (the smallest version of GPT2) on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset. The model has 6 layers, 768 dimension and 12 heads, totalizing 82M parameters (compared to 124M parameters for GPT2). On average, DistilGPT2 is two times faster than GPT2.
|
||||
- `distilroberta-base`: DistilRoBERTa English language model pretrained with the supervision of `roberta-base` solely on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset (it is ~4 times less training data than the teacher RoBERTa). The model has 6 layers, 768 dimension and 12 heads, totalizing 82M parameters (compared to 125M parameters for RoBERTa-base). On average DistilRoBERTa is twice as fast as Roberta-base.
|
||||
- and more to come! 🤗🤗🤗
|
||||
- `distilbert-base-multilingual-cased`: DistilmBERT multilingual model pretrained with the supervision of `bert-base-multilingual-cased` on the concatenation of Wikipedia in 104 different languages. The model supports the 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages). The model has 6 layers, 768 dimension and 12 heads, totalizing 134M parameters (compared to 177M parameters for mBERT-base). On average DistilmBERT is twice as fast as mBERT-base.
|
||||
|
||||
Using DistilBERT is very similar to using BERT. DistilBERT share the same tokenizer as BERT's `bert-base-uncased` even though we provide a link to this tokenizer under the `DistilBertTokenizer` name to have a consistent naming between the library models.
|
||||
|
||||
@@ -67,6 +81,7 @@ last_hidden_states = outputs[0] # The last hidden-state is the first element of
|
||||
Similarly, using the other Distil* models simply consists in calling the base classes with a different pretrained checkpoint:
|
||||
- DistilGPT2: `model = GPT2Model.from_pretrained('distilgpt2')`
|
||||
- DistilRoBERTa: `model = RobertaModel.from_pretrained('distilroberta-base')`
|
||||
- DistilmBERT: `model = DistilBertModel.from_pretrained('distilbert-base-multilingual-cased')`
|
||||
|
||||
|
||||
## How to train Distil*
|
||||
|
||||
@@ -21,7 +21,6 @@ import psutil
|
||||
import time
|
||||
from tqdm import trange, tqdm
|
||||
import numpy as np
|
||||
import psutil
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -3,4 +3,4 @@ tensorboard>=1.14.0
|
||||
tensorboardX==1.8
|
||||
psutil==5.6.3
|
||||
scipy==1.3.1
|
||||
transformers==2.0.0
|
||||
transformers
|
||||
|
||||
54
examples/pplm/README.md
Normal file
54
examples/pplm/README.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# Plug and Play Language Models: a Simple Approach to Controlled Text Generation
|
||||
|
||||
Authors: [Sumanth Dathathri](https://dathath.github.io/), [Andrea Madotto](https://andreamad8.github.io/), Janice Lan, Jane Hung, Eric Frank, [Piero Molino](https://w4nderlu.st/), [Jason Yosinski](http://yosinski.com/), and [Rosanne Liu](http://www.rosanneliu.com/)
|
||||
|
||||
This folder contains the original code used to run the Plug and Play Language Model (PPLM).
|
||||
|
||||
Paper link: https://arxiv.org/abs/1912.02164
|
||||
|
||||
Blog link: https://eng.uber.com/pplm
|
||||
|
||||
Please check out the repo under uber-research for more information: https://github.com/uber-research/PPLM
|
||||
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/transformers && cd transformers
|
||||
pip install [--editable] .
|
||||
pip install nltk torchtext # additional requirements.
|
||||
cd examples/pplm
|
||||
```
|
||||
|
||||
## PPLM-BoW
|
||||
|
||||
### Example command for bag-of-words control
|
||||
|
||||
```bash
|
||||
python run_pplm.py -B military --cond_text "The potato" --length 50 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.03 --window_length 5 --kl_scale 0.01 --gm_scale 0.99 --colorama --sample
|
||||
```
|
||||
|
||||
### Tuning hyperparameters for bag-of-words control
|
||||
|
||||
1. Increase `--stepsize` to intensify topic control, and decrease its value to soften the control. `--stepsize 0` recovers the original uncontrolled GPT-2 model.
|
||||
|
||||
2. If the language being generated is repetitive (For e.g. "science science experiment experiment"), there are several options to consider: </br>
|
||||
a) Reduce the `--stepsize` </br>
|
||||
b) Increase `--kl_scale` (the KL-loss coefficient) or decrease `--gm_scale` (the gm-scaling term) </br>
|
||||
c) Add `--grad-length xx` where xx is an (integer <= length, e.g. `--grad-length 30`).</br>
|
||||
|
||||
|
||||
## PPLM-Discrim
|
||||
|
||||
### Example command for discriminator based sentiment control
|
||||
|
||||
```bash
|
||||
python run_pplm.py -D sentiment --class_label 2 --cond_text "My dog died" --length 50 --gamma 1.0 --num_iterations 10 --num_samples 10 --stepsize 0.04 --kl_scale 0.01 --gm_scale 0.95 --sample
|
||||
```
|
||||
|
||||
### Tuning hyperparameters for discriminator control
|
||||
|
||||
1. Increase `--stepsize` to intensify topic control, and decrease its value to soften the control. `--stepsize 0` recovers the original uncontrolled GPT-2 model.
|
||||
|
||||
2. Use `--class_label 3` for negative, and `--class_label 2` for positive
|
||||
|
||||
BIN
examples/pplm/imgs/headfigure.png
Normal file
BIN
examples/pplm/imgs/headfigure.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 653 KiB |
BIN
examples/pplm/imgs/wooly.png
Normal file
BIN
examples/pplm/imgs/wooly.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 664 KiB |
18
examples/pplm/pplm_classification_head.py
Normal file
18
examples/pplm/pplm_classification_head.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
class ClassificationHead(torch.nn.Module):
|
||||
"""Classification Head for transformer encoders"""
|
||||
|
||||
def __init__(self, class_size, embed_size):
|
||||
super(ClassificationHead, self).__init__()
|
||||
self.class_size = class_size
|
||||
self.embed_size = embed_size
|
||||
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
|
||||
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
|
||||
self.mlp = torch.nn.Linear(embed_size, class_size)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
# hidden_state = F.relu(self.mlp1(hidden_state))
|
||||
# hidden_state = self.mlp2(hidden_state)
|
||||
logits = self.mlp(hidden_state)
|
||||
return logits
|
||||
879
examples/pplm/run_pplm.py
Normal file
879
examples/pplm/run_pplm.py
Normal file
@@ -0,0 +1,879 @@
|
||||
#! /usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
#Copyright (c) 2019 Uber Technologies, Inc.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
#http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
"""
|
||||
Example command with bag of words:
|
||||
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
|
||||
|
||||
Example command with discriminator:
|
||||
python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from operator import add
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
from tqdm import trange
|
||||
|
||||
from transformers import GPT2Tokenizer
|
||||
from transformers.file_utils import cached_path
|
||||
from transformers.modeling_gpt2 import GPT2LMHeadModel
|
||||
from pplm_classification_head import ClassificationHead
|
||||
|
||||
PPLM_BOW = 1
|
||||
PPLM_DISCRIM = 2
|
||||
PPLM_BOW_DISCRIM = 3
|
||||
SMALL_CONST = 1e-15
|
||||
BIG_CONST = 1e10
|
||||
|
||||
BAG_OF_WORDS_ARCHIVE_MAP = {
|
||||
'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
|
||||
'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
|
||||
'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
|
||||
'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
|
||||
'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
|
||||
'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
|
||||
'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
|
||||
}
|
||||
|
||||
DISCRIMINATOR_MODELS_PARAMS = {
|
||||
"clickbait": {
|
||||
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
|
||||
"class_size": 2,
|
||||
"embed_size": 1024,
|
||||
"class_vocab": {"non_clickbait": 0, "clickbait": 1},
|
||||
"default_class": 1,
|
||||
"pretrained_model": "gpt2-medium",
|
||||
},
|
||||
"sentiment": {
|
||||
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
|
||||
"class_size": 5,
|
||||
"embed_size": 1024,
|
||||
"class_vocab": {"very_positive": 2, "very_negative": 3},
|
||||
"default_class": 3,
|
||||
"pretrained_model": "gpt2-medium",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
|
||||
if torch.cuda.is_available() and device == 'cuda':
|
||||
x = x.cuda()
|
||||
elif device != 'cuda':
|
||||
x = x.to(device)
|
||||
return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
||||
|
||||
|
||||
def top_k_filter(logits, k, probs=False):
|
||||
"""
|
||||
Masks everything but the k top entries as -infinity (1e10).
|
||||
Used to mask logits such that e^-infinity -> 0 won't contribute to the
|
||||
sum of the denominator.
|
||||
"""
|
||||
if k == 0:
|
||||
return logits
|
||||
else:
|
||||
values = torch.topk(logits, k)[0]
|
||||
batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
|
||||
if probs:
|
||||
return torch.where(logits < batch_mins,
|
||||
torch.ones_like(logits) * 0.0, logits)
|
||||
return torch.where(logits < batch_mins,
|
||||
torch.ones_like(logits) * -BIG_CONST,
|
||||
logits)
|
||||
|
||||
|
||||
def perturb_past(
|
||||
past,
|
||||
model,
|
||||
last,
|
||||
unpert_past=None,
|
||||
unpert_logits=None,
|
||||
accumulated_hidden=None,
|
||||
grad_norms=None,
|
||||
stepsize=0.01,
|
||||
one_hot_bows_vectors=None,
|
||||
classifier=None,
|
||||
class_label=None,
|
||||
loss_type=0,
|
||||
num_iterations=3,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
kl_scale=0.01,
|
||||
device='cuda',
|
||||
):
|
||||
# Generate inital perturbed past
|
||||
grad_accumulator = [
|
||||
(np.zeros(p.shape).astype("float32"))
|
||||
for p in past
|
||||
]
|
||||
|
||||
if accumulated_hidden is None:
|
||||
accumulated_hidden = 0
|
||||
|
||||
if decay:
|
||||
decay_mask = torch.arange(
|
||||
0.,
|
||||
1.0 + SMALL_CONST,
|
||||
1.0 / (window_length)
|
||||
)[1:]
|
||||
else:
|
||||
decay_mask = 1.0
|
||||
|
||||
# TODO fix this comment (SUMANTH)
|
||||
# Generate a mask is gradient perturbated is based on a past window
|
||||
_, _, _, curr_length, _ = past[0].shape
|
||||
|
||||
if curr_length > window_length and window_length > 0:
|
||||
ones_key_val_shape = (
|
||||
tuple(past[0].shape[:-2])
|
||||
+ tuple([window_length])
|
||||
+ tuple(past[0].shape[-1:])
|
||||
)
|
||||
|
||||
zeros_key_val_shape = (
|
||||
tuple(past[0].shape[:-2])
|
||||
+ tuple([curr_length - window_length])
|
||||
+ tuple(past[0].shape[-1:])
|
||||
)
|
||||
|
||||
ones_mask = torch.ones(ones_key_val_shape)
|
||||
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
|
||||
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
|
||||
|
||||
window_mask = torch.cat(
|
||||
(ones_mask, torch.zeros(zeros_key_val_shape)),
|
||||
dim=-2
|
||||
).to(device)
|
||||
else:
|
||||
window_mask = torch.ones_like(past[0]).to(device)
|
||||
|
||||
# accumulate perturbations for num_iterations
|
||||
loss_per_iter = []
|
||||
new_accumulated_hidden = None
|
||||
for i in range(num_iterations):
|
||||
print("Iteration ", i + 1)
|
||||
curr_perturbation = [
|
||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
||||
for p_ in grad_accumulator
|
||||
]
|
||||
|
||||
# Compute hidden using perturbed past
|
||||
perturbed_past = list(map(add, past, curr_perturbation))
|
||||
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
||||
all_logits, _, all_hidden = model(last, past=perturbed_past)
|
||||
hidden = all_hidden[-1]
|
||||
new_accumulated_hidden = accumulated_hidden + torch.sum(
|
||||
hidden,
|
||||
dim=1
|
||||
).detach()
|
||||
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||
logits = all_logits[:, -1, :]
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
||||
loss = 0.0
|
||||
loss_list = []
|
||||
if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
|
||||
for one_hot_bow in one_hot_bows_vectors:
|
||||
bow_logits = torch.mm(probs, torch.t(one_hot_bow))
|
||||
bow_loss = -torch.log(torch.sum(bow_logits))
|
||||
loss += bow_loss
|
||||
loss_list.append(bow_loss)
|
||||
print(" pplm_bow_loss:", loss.data.cpu().numpy())
|
||||
|
||||
if loss_type == 2 or loss_type == 3:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
|
||||
curr_unpert_past = unpert_past
|
||||
curr_probs = torch.unsqueeze(probs, dim=1)
|
||||
wte = model.resize_token_embeddings()
|
||||
for _ in range(horizon_length):
|
||||
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
||||
_, curr_unpert_past, curr_all_hidden = model(
|
||||
past=curr_unpert_past,
|
||||
inputs_embeds=inputs_embeds
|
||||
)
|
||||
curr_hidden = curr_all_hidden[-1]
|
||||
new_accumulated_hidden = new_accumulated_hidden + torch.sum(
|
||||
curr_hidden, dim=1)
|
||||
|
||||
prediction = classifier(new_accumulated_hidden /
|
||||
(curr_length + 1 + horizon_length))
|
||||
|
||||
label = torch.tensor(prediction.shape[0] * [class_label],
|
||||
device=device,
|
||||
dtype=torch.long)
|
||||
discrim_loss = ce_loss(prediction, label)
|
||||
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
||||
loss += discrim_loss
|
||||
loss_list.append(discrim_loss)
|
||||
|
||||
kl_loss = 0.0
|
||||
if kl_scale > 0.0:
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
unpert_probs = (
|
||||
unpert_probs + SMALL_CONST *
|
||||
(unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||
)
|
||||
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
|
||||
device).detach()
|
||||
corrected_probs = probs + correction.detach()
|
||||
kl_loss = kl_scale * (
|
||||
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
|
||||
)
|
||||
print(' kl_loss', kl_loss.data.cpu().numpy())
|
||||
loss += kl_loss
|
||||
|
||||
loss_per_iter.append(loss.data.cpu().numpy())
|
||||
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
|
||||
|
||||
# compute gradients
|
||||
loss.backward()
|
||||
|
||||
# calculate gradient norms
|
||||
if grad_norms is not None and loss_type == PPLM_BOW:
|
||||
grad_norms = [
|
||||
torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
|
||||
for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
else:
|
||||
grad_norms = [
|
||||
(torch.norm(p_.grad * window_mask) + SMALL_CONST)
|
||||
for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
|
||||
# normalize gradients
|
||||
grad = [
|
||||
-stepsize *
|
||||
(p_.grad * window_mask / grad_norms[
|
||||
index] ** gamma).data.cpu().numpy()
|
||||
for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
|
||||
# accumulate gradient
|
||||
grad_accumulator = list(map(add, grad, grad_accumulator))
|
||||
|
||||
# reset gradients, just to make sure
|
||||
for p_ in curr_perturbation:
|
||||
p_.grad.data.zero_()
|
||||
|
||||
# removing past from the graph
|
||||
new_past = []
|
||||
for p_ in past:
|
||||
new_past.append(p_.detach())
|
||||
past = new_past
|
||||
|
||||
# apply the accumulated perturbations to the past
|
||||
grad_accumulator = [
|
||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
||||
for p_ in grad_accumulator
|
||||
]
|
||||
pert_past = list(map(add, past, grad_accumulator))
|
||||
|
||||
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
||||
|
||||
|
||||
def get_classifier(
|
||||
name: Optional[str], class_label: Union[str, int],
|
||||
device: str
|
||||
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
||||
if name is None:
|
||||
return None, None
|
||||
|
||||
params = DISCRIMINATOR_MODELS_PARAMS[name]
|
||||
classifier = ClassificationHead(
|
||||
class_size=params['class_size'],
|
||||
embed_size=params['embed_size']
|
||||
).to(device)
|
||||
if "url" in params:
|
||||
resolved_archive_file = cached_path(params["url"])
|
||||
elif "path" in params:
|
||||
resolved_archive_file = params["path"]
|
||||
else:
|
||||
raise ValueError("Either url or path have to be specified "
|
||||
"in the discriminator model parameters")
|
||||
classifier.load_state_dict(
|
||||
torch.load(resolved_archive_file, map_location=device))
|
||||
classifier.eval()
|
||||
|
||||
if isinstance(class_label, str):
|
||||
if class_label in params["class_vocab"]:
|
||||
label_id = params["class_vocab"][class_label]
|
||||
else:
|
||||
label_id = params["default_class"]
|
||||
print("class_label {} not in class_vocab".format(class_label))
|
||||
print("available values are: {}".format(params["class_vocab"]))
|
||||
print("using default class {}".format(label_id))
|
||||
|
||||
elif isinstance(class_label, int):
|
||||
if class_label in set(params["class_vocab"].values()):
|
||||
label_id = class_label
|
||||
else:
|
||||
label_id = params["default_class"]
|
||||
print("class_label {} not in class_vocab".format(class_label))
|
||||
print("available values are: {}".format(params["class_vocab"]))
|
||||
print("using default class {}".format(label_id))
|
||||
|
||||
else:
|
||||
label_id = params["default_class"]
|
||||
|
||||
return classifier, label_id
|
||||
|
||||
|
||||
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
|
||||
List[List[List[int]]]:
|
||||
bow_indices = []
|
||||
for id_or_path in bag_of_words_ids_or_paths:
|
||||
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
|
||||
filepath = cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
|
||||
else:
|
||||
filepath = id_or_path
|
||||
with open(filepath, "r") as f:
|
||||
words = f.read().strip().split("\n")
|
||||
bow_indices.append(
|
||||
[tokenizer.encode(word.strip(), add_prefix_space=True) for word in
|
||||
words])
|
||||
return bow_indices
|
||||
|
||||
|
||||
def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
|
||||
if bow_indices is None:
|
||||
return None
|
||||
|
||||
one_hot_bows_vectors = []
|
||||
for single_bow in bow_indices:
|
||||
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
|
||||
single_bow = torch.tensor(single_bow).to(device)
|
||||
num_words = single_bow.shape[0]
|
||||
one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
|
||||
one_hot_bow.scatter_(1, single_bow, 1)
|
||||
one_hot_bows_vectors.append(one_hot_bow)
|
||||
return one_hot_bows_vectors
|
||||
|
||||
|
||||
def full_text_generation(
|
||||
model,
|
||||
tokenizer,
|
||||
context=None,
|
||||
num_samples=1,
|
||||
device="cuda",
|
||||
bag_of_words=None,
|
||||
discrim=None,
|
||||
class_label=None,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
**kwargs
|
||||
):
|
||||
classifier, class_id = get_classifier(
|
||||
discrim,
|
||||
class_label,
|
||||
device
|
||||
)
|
||||
|
||||
bow_indices = []
|
||||
if bag_of_words:
|
||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
|
||||
tokenizer)
|
||||
|
||||
if bag_of_words and classifier:
|
||||
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
|
||||
loss_type = PPLM_BOW_DISCRIM
|
||||
|
||||
elif bag_of_words:
|
||||
loss_type = PPLM_BOW
|
||||
print("Using PPLM-BoW")
|
||||
|
||||
elif classifier is not None:
|
||||
loss_type = PPLM_DISCRIM
|
||||
print("Using PPLM-Discrim")
|
||||
|
||||
else:
|
||||
raise Exception("Specify either a bag of words or a discriminator")
|
||||
|
||||
unpert_gen_tok_text, _, _ = generate_text_pplm(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
context=context,
|
||||
device=device,
|
||||
length=length,
|
||||
sample=sample,
|
||||
perturb=False
|
||||
)
|
||||
if device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
pert_gen_tok_texts = []
|
||||
discrim_losses = []
|
||||
losses_in_time = []
|
||||
|
||||
for i in range(num_samples):
|
||||
pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
context=context,
|
||||
device=device,
|
||||
perturb=True,
|
||||
bow_indices=bow_indices,
|
||||
classifier=classifier,
|
||||
class_label=class_id,
|
||||
loss_type=loss_type,
|
||||
length=length,
|
||||
stepsize=stepsize,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
sample=sample,
|
||||
num_iterations=num_iterations,
|
||||
grad_length=grad_length,
|
||||
horizon_length=horizon_length,
|
||||
window_length=window_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
gm_scale=gm_scale,
|
||||
kl_scale=kl_scale,
|
||||
)
|
||||
pert_gen_tok_texts.append(pert_gen_tok_text)
|
||||
if classifier is not None:
|
||||
discrim_losses.append(discrim_loss.data.cpu().numpy())
|
||||
losses_in_time.append(loss_in_time)
|
||||
|
||||
if device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
||||
|
||||
|
||||
def generate_text_pplm(
|
||||
model,
|
||||
tokenizer,
|
||||
context=None,
|
||||
past=None,
|
||||
device="cuda",
|
||||
perturb=True,
|
||||
bow_indices=None,
|
||||
classifier=None,
|
||||
class_label=None,
|
||||
loss_type=0,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
):
|
||||
output_so_far = None
|
||||
if context:
|
||||
context_t = torch.tensor(context, device=device, dtype=torch.long)
|
||||
while len(context_t.shape) < 2:
|
||||
context_t = context_t.unsqueeze(0)
|
||||
output_so_far = context_t
|
||||
|
||||
# collect one hot vectors for bags of words
|
||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
|
||||
device)
|
||||
|
||||
grad_norms = None
|
||||
last = None
|
||||
unpert_discrim_loss = 0
|
||||
loss_in_time = []
|
||||
for i in trange(length, ascii=True):
|
||||
|
||||
# Get past/probs for current output, except for last word
|
||||
# Note that GPT takes 2 inputs: past + current_token
|
||||
|
||||
# run model forward to obtain unperturbed
|
||||
if past is None and output_so_far is not None:
|
||||
last = output_so_far[:, -1:]
|
||||
if output_so_far.shape[1] > 1:
|
||||
_, past, _ = model(output_so_far[:, :-1])
|
||||
|
||||
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
|
||||
unpert_last_hidden = unpert_all_hidden[-1]
|
||||
|
||||
# check if we are abowe grad max length
|
||||
if i >= grad_length:
|
||||
current_stepsize = stepsize * 0
|
||||
else:
|
||||
current_stepsize = stepsize
|
||||
|
||||
# modify the past if necessary
|
||||
if not perturb or num_iterations == 0:
|
||||
pert_past = past
|
||||
|
||||
else:
|
||||
accumulated_hidden = unpert_last_hidden[:, :-1, :]
|
||||
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
|
||||
|
||||
if past is not None:
|
||||
pert_past, _, grad_norms, loss_this_iter = perturb_past(
|
||||
past,
|
||||
model,
|
||||
last,
|
||||
unpert_past=unpert_past,
|
||||
unpert_logits=unpert_logits,
|
||||
accumulated_hidden=accumulated_hidden,
|
||||
grad_norms=grad_norms,
|
||||
stepsize=current_stepsize,
|
||||
one_hot_bows_vectors=one_hot_bows_vectors,
|
||||
classifier=classifier,
|
||||
class_label=class_label,
|
||||
loss_type=loss_type,
|
||||
num_iterations=num_iterations,
|
||||
horizon_length=horizon_length,
|
||||
window_length=window_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
kl_scale=kl_scale,
|
||||
device=device,
|
||||
)
|
||||
loss_in_time.append(loss_this_iter)
|
||||
else:
|
||||
pert_past = past
|
||||
|
||||
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
|
||||
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
|
||||
if classifier is not None:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||
label = torch.tensor([class_label], device=device,
|
||||
dtype=torch.long)
|
||||
unpert_discrim_loss = ce_loss(prediction, label)
|
||||
print(
|
||||
"unperturbed discrim loss",
|
||||
unpert_discrim_loss.data.cpu().numpy()
|
||||
)
|
||||
else:
|
||||
unpert_discrim_loss = 0
|
||||
|
||||
# Fuse the modified model and original model
|
||||
if perturb:
|
||||
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
|
||||
pert_probs = ((pert_probs ** gm_scale) * (
|
||||
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST
|
||||
pert_probs = top_k_filter(pert_probs, k=top_k,
|
||||
probs=True) # + SMALL_CONST
|
||||
|
||||
# rescale
|
||||
if torch.sum(pert_probs) <= 1:
|
||||
pert_probs = pert_probs / torch.sum(pert_probs)
|
||||
|
||||
else:
|
||||
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
|
||||
# sample or greedy
|
||||
if sample:
|
||||
last = torch.multinomial(pert_probs, num_samples=1)
|
||||
|
||||
else:
|
||||
_, last = torch.topk(pert_probs, k=1, dim=-1)
|
||||
|
||||
# update context/output_so_far appending the new token
|
||||
output_so_far = (
|
||||
last if output_so_far is None
|
||||
else torch.cat((output_so_far, last), dim=1)
|
||||
)
|
||||
|
||||
print(tokenizer.decode(output_so_far.tolist()[0]))
|
||||
|
||||
return output_so_far, unpert_discrim_loss, loss_in_time
|
||||
|
||||
|
||||
def set_generic_model_params(discrim_weights, discrim_meta):
|
||||
if discrim_weights is None:
|
||||
raise ValueError('When using a generic discriminator, '
|
||||
'discrim_weights need to be specified')
|
||||
if discrim_meta is None:
|
||||
raise ValueError('When using a generic discriminator, '
|
||||
'discrim_meta need to be specified')
|
||||
|
||||
with open(discrim_meta, 'r') as discrim_meta_file:
|
||||
meta = json.load(discrim_meta_file)
|
||||
meta['path'] = discrim_weights
|
||||
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
|
||||
|
||||
|
||||
def run_pplm_example(
|
||||
pretrained_model="gpt2-medium",
|
||||
cond_text="",
|
||||
uncond=False,
|
||||
num_samples=1,
|
||||
bag_of_words=None,
|
||||
discrim=None,
|
||||
discrim_weights=None,
|
||||
discrim_meta=None,
|
||||
class_label=-1,
|
||||
length=100,
|
||||
stepsize=0.02,
|
||||
temperature=1.0,
|
||||
top_k=10,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
seed=0,
|
||||
no_cuda=False,
|
||||
colorama=False
|
||||
):
|
||||
# set Random seed
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
# set the device
|
||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||
|
||||
if discrim == 'generic':
|
||||
set_generic_model_params(discrim_weights, discrim_meta)
|
||||
|
||||
if discrim is not None:
|
||||
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
|
||||
"pretrained_model"
|
||||
]
|
||||
print("discrim = {}, pretrained_model set "
|
||||
"to discriminator's = {}".format(discrim, pretrained_model))
|
||||
|
||||
# load pretrained model
|
||||
model = GPT2LMHeadModel.from_pretrained(
|
||||
pretrained_model,
|
||||
output_hidden_states=True
|
||||
)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# load tokenizer
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
||||
|
||||
# Freeze GPT-2 weights
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# figure out conditioning text
|
||||
if uncond:
|
||||
tokenized_cond_text = tokenizer.encode(
|
||||
[tokenizer.bos_token]
|
||||
)
|
||||
else:
|
||||
raw_text = cond_text
|
||||
while not raw_text:
|
||||
print("Did you forget to add `--cond_text`? ")
|
||||
raw_text = input("Model prompt >>> ")
|
||||
tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)
|
||||
|
||||
print("= Prefix of sentence =")
|
||||
print(tokenizer.decode(tokenized_cond_text))
|
||||
print()
|
||||
|
||||
# generate unperturbed and perturbed texts
|
||||
|
||||
# full_text_generation returns:
|
||||
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
||||
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
context=tokenized_cond_text,
|
||||
device=device,
|
||||
num_samples=num_samples,
|
||||
bag_of_words=bag_of_words,
|
||||
discrim=discrim,
|
||||
class_label=class_label,
|
||||
length=length,
|
||||
stepsize=stepsize,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
sample=sample,
|
||||
num_iterations=num_iterations,
|
||||
grad_length=grad_length,
|
||||
horizon_length=horizon_length,
|
||||
window_length=window_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
gm_scale=gm_scale,
|
||||
kl_scale=kl_scale,
|
||||
)
|
||||
|
||||
# untokenize unperturbed text
|
||||
unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])
|
||||
|
||||
print("=" * 80)
|
||||
print("= Unperturbed generated text =")
|
||||
print(unpert_gen_text)
|
||||
print()
|
||||
|
||||
generated_texts = []
|
||||
|
||||
bow_word_ids = set()
|
||||
if bag_of_words and colorama:
|
||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
|
||||
tokenizer)
|
||||
for single_bow_list in bow_indices:
|
||||
# filtering all words in the list composed of more than 1 token
|
||||
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
|
||||
# w[0] because we are sure w has only 1 item because previous fitler
|
||||
bow_word_ids.update(w[0] for w in filtered)
|
||||
|
||||
# iterate through the perturbed texts
|
||||
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
|
||||
try:
|
||||
# untokenize unperturbed text
|
||||
if colorama:
|
||||
import colorama
|
||||
|
||||
pert_gen_text = ''
|
||||
for word_id in pert_gen_tok_text.tolist()[0]:
|
||||
if word_id in bow_word_ids:
|
||||
pert_gen_text += '{}{}{}'.format(
|
||||
colorama.Fore.RED,
|
||||
tokenizer.decode([word_id]),
|
||||
colorama.Style.RESET_ALL
|
||||
)
|
||||
else:
|
||||
pert_gen_text += tokenizer.decode([word_id])
|
||||
else:
|
||||
pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
|
||||
|
||||
print("= Perturbed generated text {} =".format(i + 1))
|
||||
print(pert_gen_text)
|
||||
print()
|
||||
except:
|
||||
pass
|
||||
|
||||
# keep the prefix, perturbed seq, original seq for each index
|
||||
generated_texts.append(
|
||||
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pretrained_model",
|
||||
"-M",
|
||||
type=str,
|
||||
default="gpt2-medium",
|
||||
help="pretrained model name or path to local checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cond_text", type=str, default="The lake",
|
||||
help="Prefix texts to condition on"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--uncond", action="store_true",
|
||||
help="Generate from end-of-text as prefix"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_samples",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of samples to generate from the modified latents",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bag_of_words",
|
||||
"-B",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Bags of words used for PPLM-BoW. "
|
||||
"Either a BOW id (see list in code) or a filepath. "
|
||||
"Multiple BoWs separated by ;",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim",
|
||||
"-D",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=("clickbait", "sentiment", "toxicity", "generic"),
|
||||
help="Discriminator to use",
|
||||
)
|
||||
parser.add_argument('--discrim_weights', type=str, default=None,
|
||||
help='Weights for the generic discriminator')
|
||||
parser.add_argument('--discrim_meta', type=str, default=None,
|
||||
help='Meta information for the generic discriminator')
|
||||
parser.add_argument(
|
||||
"--class_label",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Class label used for the discriminator",
|
||||
)
|
||||
parser.add_argument("--length", type=int, default=100)
|
||||
parser.add_argument("--stepsize", type=float, default=0.02)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--top_k", type=int, default=10)
|
||||
parser.add_argument(
|
||||
"--sample", action="store_true",
|
||||
help="Generate from end-of-text as prefix"
|
||||
)
|
||||
parser.add_argument("--num_iterations", type=int, default=3)
|
||||
parser.add_argument("--grad_length", type=int, default=10000)
|
||||
parser.add_argument(
|
||||
"--window_length",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Length of past which is being optimized; "
|
||||
"0 corresponds to infinite window length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--horizon_length",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Length of future to optimize over",
|
||||
)
|
||||
parser.add_argument("--decay", action="store_true",
|
||||
help="whether to decay or not")
|
||||
parser.add_argument("--gamma", type=float, default=1.5)
|
||||
parser.add_argument("--gm_scale", type=float, default=0.9)
|
||||
parser.add_argument("--kl_scale", type=float, default=0.01)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
||||
parser.add_argument("--colorama", action="store_true",
|
||||
help="colors keywords")
|
||||
|
||||
args = parser.parse_args()
|
||||
run_pplm_example(**vars(args))
|
||||
588
examples/pplm/run_pplm_discrim_train.py
Normal file
588
examples/pplm/run_pplm_discrim_train.py
Normal file
@@ -0,0 +1,588 @@
|
||||
#! /usr/bin/env python3
|
||||
# coding=utf-8
|
||||
|
||||
#Copyright (c) 2019 Uber Technologies, Inc.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
#http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim
|
||||
import torch.optim as optim
|
||||
import torch.utils.data as data
|
||||
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
||||
from torchtext import data as torchtext_data
|
||||
from torchtext import datasets
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
from pplm_classification_head import ClassificationHead
|
||||
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
EPSILON = 1e-10
|
||||
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
|
||||
max_length_seq = 100
|
||||
|
||||
|
||||
|
||||
|
||||
class Discriminator(torch.nn.Module):
|
||||
"""Transformer encoder followed by a Classification Head"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
class_size,
|
||||
pretrained_model="gpt2-medium",
|
||||
cached_mode=False,
|
||||
device='cpu'
|
||||
):
|
||||
super(Discriminator, self).__init__()
|
||||
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
||||
self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
|
||||
self.embed_size = self.encoder.transformer.config.hidden_size
|
||||
self.classifier_head = ClassificationHead(
|
||||
class_size=class_size,
|
||||
embed_size=self.embed_size
|
||||
)
|
||||
self.cached_mode = cached_mode
|
||||
self.device = device
|
||||
|
||||
def get_classifier(self):
|
||||
return self.classifier_head
|
||||
|
||||
def train_custom(self):
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
self.classifier_head.train()
|
||||
|
||||
def avg_representation(self, x):
|
||||
mask = x.ne(0).unsqueeze(2).repeat(
|
||||
1, 1, self.embed_size
|
||||
).float().to(self.device).detach()
|
||||
hidden, _ = self.encoder.transformer(x)
|
||||
masked_hidden = hidden * mask
|
||||
avg_hidden = torch.sum(masked_hidden, dim=1) / (
|
||||
torch.sum(mask, dim=1).detach() + EPSILON
|
||||
)
|
||||
return avg_hidden
|
||||
|
||||
def forward(self, x):
|
||||
if self.cached_mode:
|
||||
avg_hidden = x.to(self.device)
|
||||
else:
|
||||
avg_hidden = self.avg_representation(x.to(self.device))
|
||||
|
||||
logits = self.classifier_head(avg_hidden)
|
||||
probs = F.log_softmax(logits, dim=-1)
|
||||
|
||||
return probs
|
||||
|
||||
|
||||
class Dataset(data.Dataset):
|
||||
def __init__(self, X, y):
|
||||
"""Reads source and target sequences from txt files."""
|
||||
self.X = X
|
||||
self.y = y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.X)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Returns one data pair (source and target)."""
|
||||
data = {}
|
||||
data["X"] = self.X[index]
|
||||
data["y"] = self.y[index]
|
||||
return data
|
||||
|
||||
|
||||
def collate_fn(data):
|
||||
def pad_sequences(sequences):
|
||||
lengths = [len(seq) for seq in sequences]
|
||||
|
||||
padded_sequences = torch.zeros(
|
||||
len(sequences),
|
||||
max(lengths)
|
||||
).long() # padding value = 0
|
||||
|
||||
for i, seq in enumerate(sequences):
|
||||
end = lengths[i]
|
||||
padded_sequences[i, :end] = seq[:end]
|
||||
|
||||
return padded_sequences, lengths
|
||||
|
||||
item_info = {}
|
||||
for key in data[0].keys():
|
||||
item_info[key] = [d[key] for d in data]
|
||||
|
||||
x_batch, _ = pad_sequences(item_info["X"])
|
||||
y_batch = torch.tensor(item_info["y"], dtype=torch.long)
|
||||
|
||||
return x_batch, y_batch
|
||||
|
||||
|
||||
def cached_collate_fn(data):
|
||||
item_info = {}
|
||||
for key in data[0].keys():
|
||||
item_info[key] = [d[key] for d in data]
|
||||
|
||||
x_batch = torch.cat(item_info["X"], 0)
|
||||
y_batch = torch.tensor(item_info["y"], dtype=torch.long)
|
||||
|
||||
return x_batch, y_batch
|
||||
|
||||
|
||||
def train_epoch(data_loader, discriminator, optimizer,
|
||||
epoch=0, log_interval=10, device='cpu'):
|
||||
samples_so_far = 0
|
||||
discriminator.train_custom()
|
||||
for batch_idx, (input_t, target_t) in enumerate(data_loader):
|
||||
input_t, target_t = input_t.to(device), target_t.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
output_t = discriminator(input_t)
|
||||
loss = F.nll_loss(output_t, target_t)
|
||||
loss.backward(retain_graph=True)
|
||||
optimizer.step()
|
||||
|
||||
samples_so_far += len(input_t)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
print(
|
||||
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
|
||||
epoch + 1,
|
||||
samples_so_far, len(data_loader.dataset),
|
||||
100 * samples_so_far / len(data_loader.dataset), loss.item()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def evaluate_performance(data_loader, discriminator, device='cpu'):
|
||||
discriminator.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for input_t, target_t in data_loader:
|
||||
input_t, target_t = input_t.to(device), target_t.to(device)
|
||||
output_t = discriminator(input_t)
|
||||
# sum up batch loss
|
||||
test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
|
||||
# get the index of the max log-probability
|
||||
pred_t = output_t.argmax(dim=1, keepdim=True)
|
||||
correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
|
||||
|
||||
test_loss /= len(data_loader.dataset)
|
||||
|
||||
print(
|
||||
"Performance on test set: "
|
||||
"Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
|
||||
test_loss, correct, len(data_loader.dataset),
|
||||
100. * correct / len(data_loader.dataset)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def predict(input_sentence, model, classes, cached=False, device='cpu'):
|
||||
input_t = model.tokenizer.encode(input_sentence)
|
||||
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
|
||||
if cached:
|
||||
input_t = model.avg_representation(input_t)
|
||||
|
||||
log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
|
||||
print("Input sentence:", input_sentence)
|
||||
print("Predictions:", ", ".join(
|
||||
"{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
|
||||
zip(classes, log_probs)
|
||||
))
|
||||
|
||||
|
||||
def get_cached_data_loader(dataset, batch_size, discriminator,
|
||||
shuffle=False, device='cpu'):
|
||||
data_loader = torch.utils.data.DataLoader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
xs = []
|
||||
ys = []
|
||||
for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
|
||||
with torch.no_grad():
|
||||
x = x.to(device)
|
||||
avg_rep = discriminator.avg_representation(x).cpu().detach()
|
||||
avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
|
||||
xs += avg_rep_list
|
||||
ys += y.cpu().numpy().tolist()
|
||||
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset=Dataset(xs, ys),
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
collate_fn=cached_collate_fn)
|
||||
|
||||
return data_loader
|
||||
|
||||
|
||||
def train_discriminator(
|
||||
dataset, dataset_fp=None, pretrained_model="gpt2-medium",
|
||||
epochs=10, batch_size=64, log_interval=10,
|
||||
save_model=False, cached=False, no_cuda=False):
|
||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||
|
||||
print("Preprocessing {} dataset...".format(dataset))
|
||||
start = time.time()
|
||||
|
||||
if dataset == "SST":
|
||||
idx2class = ["positive", "negative", "very positive", "very negative",
|
||||
"neutral"]
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
).to(device)
|
||||
|
||||
text = torchtext_data.Field()
|
||||
label = torchtext_data.Field(sequential=False)
|
||||
train_data, val_data, test_data = datasets.SST.splits(
|
||||
text,
|
||||
label,
|
||||
fine_grained=True,
|
||||
train_subtrees=True,
|
||||
)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
for i in trange(len(train_data), ascii=True):
|
||||
seq = TreebankWordDetokenizer().detokenize(
|
||||
vars(train_data[i])["text"]
|
||||
)
|
||||
seq = discriminator.tokenizer.encode(seq)
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
x.append(seq)
|
||||
y.append(class2idx[vars(train_data[i])["label"]])
|
||||
train_dataset = Dataset(x, y)
|
||||
|
||||
test_x = []
|
||||
test_y = []
|
||||
for i in trange(len(test_data), ascii=True):
|
||||
seq = TreebankWordDetokenizer().detokenize(
|
||||
vars(test_data[i])["text"]
|
||||
)
|
||||
seq = discriminator.tokenizer.encode(seq)
|
||||
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
||||
test_x.append(seq)
|
||||
test_y.append(class2idx[vars(test_data[i])["label"]])
|
||||
test_dataset = Dataset(test_x, test_y)
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
"embed_size": discriminator.embed_size,
|
||||
"pretrained_model": pretrained_model,
|
||||
"class_vocab": class2idx,
|
||||
"default_class": 2,
|
||||
}
|
||||
|
||||
elif dataset == "clickbait":
|
||||
idx2class = ["non_clickbait", "clickbait"]
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
).to(device)
|
||||
|
||||
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
||||
data = []
|
||||
for i, line in enumerate(f):
|
||||
try:
|
||||
data.append(eval(line))
|
||||
except:
|
||||
print("Error evaluating line {}: {}".format(
|
||||
i, line
|
||||
))
|
||||
continue
|
||||
x = []
|
||||
y = []
|
||||
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
||||
for i, line in enumerate(tqdm(f, ascii=True)):
|
||||
try:
|
||||
d = eval(line)
|
||||
seq = discriminator.tokenizer.encode(d["text"])
|
||||
|
||||
if len(seq) < max_length_seq:
|
||||
seq = torch.tensor(
|
||||
[50256] + seq, device=device, dtype=torch.long
|
||||
)
|
||||
else:
|
||||
print("Line {} is longer than maximum length {}".format(
|
||||
i, max_length_seq
|
||||
))
|
||||
continue
|
||||
x.append(seq)
|
||||
y.append(d["label"])
|
||||
except:
|
||||
print("Error evaluating / tokenizing"
|
||||
" line {}, skipping it".format(i))
|
||||
pass
|
||||
|
||||
full_dataset = Dataset(x, y)
|
||||
train_size = int(0.9 * len(full_dataset))
|
||||
test_size = len(full_dataset) - train_size
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(
|
||||
full_dataset, [train_size, test_size]
|
||||
)
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
"embed_size": discriminator.embed_size,
|
||||
"pretrained_model": pretrained_model,
|
||||
"class_vocab": class2idx,
|
||||
"default_class": 1,
|
||||
}
|
||||
|
||||
elif dataset == "toxic":
|
||||
idx2class = ["non_toxic", "toxic"]
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
).to(device)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
with open("datasets/toxic/toxic_train.txt") as f:
|
||||
for i, line in enumerate(tqdm(f, ascii=True)):
|
||||
try:
|
||||
d = eval(line)
|
||||
seq = discriminator.tokenizer.encode(d["text"])
|
||||
|
||||
if len(seq) < max_length_seq:
|
||||
seq = torch.tensor(
|
||||
[50256] + seq, device=device, dtype=torch.long
|
||||
)
|
||||
else:
|
||||
print("Line {} is longer than maximum length {}".format(
|
||||
i, max_length_seq
|
||||
))
|
||||
continue
|
||||
x.append(seq)
|
||||
y.append(int(np.sum(d["label"]) > 0))
|
||||
except:
|
||||
print("Error evaluating / tokenizing"
|
||||
" line {}, skipping it".format(i))
|
||||
pass
|
||||
|
||||
full_dataset = Dataset(x, y)
|
||||
train_size = int(0.9 * len(full_dataset))
|
||||
test_size = len(full_dataset) - train_size
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(
|
||||
full_dataset, [train_size, test_size]
|
||||
)
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
"embed_size": discriminator.embed_size,
|
||||
"pretrained_model": pretrained_model,
|
||||
"class_vocab": class2idx,
|
||||
"default_class": 0,
|
||||
}
|
||||
|
||||
else: # if dataset == "generic":
|
||||
# This assumes the input dataset is a TSV with the following structure:
|
||||
# class \t text
|
||||
|
||||
if dataset_fp is None:
|
||||
raise ValueError("When generic dataset is selected, "
|
||||
"dataset_fp needs to be specified aswell.")
|
||||
|
||||
classes = set()
|
||||
with open(dataset_fp) as f:
|
||||
csv_reader = csv.reader(f, delimiter="\t")
|
||||
for row in tqdm(csv_reader, ascii=True):
|
||||
if row:
|
||||
classes.add(row[0])
|
||||
|
||||
idx2class = sorted(classes)
|
||||
class2idx = {c: i for i, c in enumerate(idx2class)}
|
||||
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
).to(device)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
with open(dataset_fp) as f:
|
||||
csv_reader = csv.reader(f, delimiter="\t")
|
||||
for i, row in enumerate(tqdm(csv_reader, ascii=True)):
|
||||
if row:
|
||||
label = row[0]
|
||||
text = row[1]
|
||||
|
||||
try:
|
||||
seq = discriminator.tokenizer.encode(text)
|
||||
if (len(seq) < max_length_seq):
|
||||
seq = torch.tensor(
|
||||
[50256] + seq,
|
||||
device=device,
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
else:
|
||||
print(
|
||||
"Line {} is longer than maximum length {}".format(
|
||||
i, max_length_seq
|
||||
))
|
||||
continue
|
||||
|
||||
x.append(seq)
|
||||
y.append(class2idx[label])
|
||||
|
||||
except:
|
||||
print("Error tokenizing line {}, skipping it".format(i))
|
||||
pass
|
||||
|
||||
full_dataset = Dataset(x, y)
|
||||
train_size = int(0.9 * len(full_dataset))
|
||||
test_size = len(full_dataset) - train_size
|
||||
train_dataset, test_dataset = torch.utils.data.random_split(
|
||||
full_dataset,
|
||||
[train_size, test_size]
|
||||
)
|
||||
|
||||
discriminator_meta = {
|
||||
"class_size": len(idx2class),
|
||||
"embed_size": discriminator.embed_size,
|
||||
"pretrained_model": pretrained_model,
|
||||
"class_vocab": class2idx,
|
||||
"default_class": 0,
|
||||
}
|
||||
|
||||
end = time.time()
|
||||
print("Preprocessed {} data points".format(
|
||||
len(train_dataset) + len(test_dataset))
|
||||
)
|
||||
print("Data preprocessing took: {:.3f}s".format(end - start))
|
||||
|
||||
if cached:
|
||||
print("Building representation cache...")
|
||||
|
||||
start = time.time()
|
||||
|
||||
train_loader = get_cached_data_loader(
|
||||
train_dataset, batch_size, discriminator,
|
||||
shuffle=True, device=device
|
||||
)
|
||||
|
||||
test_loader = get_cached_data_loader(
|
||||
test_dataset, batch_size, discriminator, device=device
|
||||
)
|
||||
|
||||
end = time.time()
|
||||
print("Building representation cache took: {:.3f}s".format(end - start))
|
||||
|
||||
else:
|
||||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn)
|
||||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
if save_model:
|
||||
with open("{}_classifier_head_meta.json".format(dataset),
|
||||
"w") as meta_file:
|
||||
json.dump(discriminator_meta, meta_file)
|
||||
|
||||
optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
|
||||
|
||||
for epoch in range(epochs):
|
||||
start = time.time()
|
||||
print("\nEpoch", epoch + 1)
|
||||
|
||||
train_epoch(
|
||||
discriminator=discriminator,
|
||||
data_loader=train_loader,
|
||||
optimizer=optimizer,
|
||||
epoch=epoch,
|
||||
log_interval=log_interval,
|
||||
device=device
|
||||
)
|
||||
evaluate_performance(
|
||||
data_loader=test_loader,
|
||||
discriminator=discriminator,
|
||||
device=device
|
||||
)
|
||||
|
||||
end = time.time()
|
||||
print("Epoch took: {:.3f}s".format(end - start))
|
||||
|
||||
print("\nExample prediction")
|
||||
predict(example_sentence, discriminator, idx2class,
|
||||
cached=cached, device=device)
|
||||
|
||||
if save_model:
|
||||
# torch.save(discriminator.state_dict(),
|
||||
# "{}_discriminator_{}.pt".format(
|
||||
# args.dataset, epoch + 1
|
||||
# ))
|
||||
torch.save(discriminator.get_classifier().state_dict(),
|
||||
"{}_classifier_head_epoch_{}.pt".format(dataset,
|
||||
epoch + 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train a discriminator on top of GPT-2 representations")
|
||||
parser.add_argument("--dataset", type=str, default="SST",
|
||||
choices=("SST", "clickbait", "toxic", "generic"),
|
||||
help="dataset to train the discriminator on."
|
||||
"In case of generic, the dataset is expected"
|
||||
"to be a TSBV file with structure: class \\t text")
|
||||
parser.add_argument("--dataset_fp", type=str, default="",
|
||||
help="File path of the dataset to use. "
|
||||
"Needed only in case of generic datadset")
|
||||
parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
|
||||
help="Pretrained model to use as encoder")
|
||||
parser.add_argument("--epochs", type=int, default=10, metavar="N",
|
||||
help="Number of training epochs")
|
||||
parser.add_argument("--batch_size", type=int, default=64, metavar="N",
|
||||
help="input batch size for training (default: 64)")
|
||||
parser.add_argument("--log_interval", type=int, default=10, metavar="N",
|
||||
help="how many batches to wait before logging training status")
|
||||
parser.add_argument("--save_model", action="store_true",
|
||||
help="whether to save the model")
|
||||
parser.add_argument("--cached", action="store_true",
|
||||
help="whether to cache the input representations")
|
||||
parser.add_argument("--no_cuda", action="store_true",
|
||||
help="use to turn off cuda")
|
||||
args = parser.parse_args()
|
||||
|
||||
train_discriminator(**(vars(args)))
|
||||
@@ -247,7 +247,11 @@ def main():
|
||||
out = out[:, len(context_tokens):].tolist()
|
||||
for o in out:
|
||||
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
|
||||
text = text[: text.find(args.stop_token) if args.stop_token else None]
|
||||
if args.stop_token:
|
||||
index = text.find(args.stop_token)
|
||||
if index == -1:
|
||||
index = None
|
||||
text = text[:index]
|
||||
|
||||
print(text)
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -47,7 +48,14 @@ from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertTokenizer)
|
||||
DistilBertTokenizer,
|
||||
AlbertConfig,
|
||||
AlbertForSequenceClassification,
|
||||
AlbertTokenizer,
|
||||
XLMRobertaConfig,
|
||||
XLMRobertaForSequenceClassification,
|
||||
XLMRobertaTokenizer,
|
||||
)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
@@ -66,7 +74,9 @@ MODEL_CLASSES = {
|
||||
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
|
||||
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
||||
'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
|
||||
'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
|
||||
}
|
||||
|
||||
|
||||
@@ -99,6 +109,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
if args.fp16:
|
||||
@@ -158,7 +169,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
@@ -170,15 +181,23 @@ def train(args, train_dataset, model, tokenizer):
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
logs = {}
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
eval_key = 'eval_{}'.format(key)
|
||||
logs[eval_key] = value
|
||||
|
||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||
learning_rate_scalar = scheduler.get_lr()[0]
|
||||
logs['learning_rate'] = learning_rate_scalar
|
||||
logs['loss'] = loss_scalar
|
||||
logging_loss = tr_loss
|
||||
|
||||
for key, value in logs.items():
|
||||
tb_writer.add_scalar(key, value, global_step)
|
||||
print(json.dumps({**logs, **{'step': global_step}}))
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
@@ -189,11 +208,6 @@ def train(args, train_dataset, model, tokenizer):
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.tpu:
|
||||
args.xla_model.optimizer_step(optimizer, barrier=True)
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
@@ -221,7 +235,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu eval
|
||||
@@ -294,9 +308,9 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
|
||||
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']:
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
features = convert_examples_to_features(examples,
|
||||
tokenizer,
|
||||
@@ -322,7 +336,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||
return dataset
|
||||
|
||||
@@ -366,11 +380,11 @@ def main():
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
@@ -397,15 +411,6 @@ def main():
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
|
||||
parser.add_argument('--tpu', action='store_true',
|
||||
help="Whether to run on the TPU defined in the environment variables")
|
||||
parser.add_argument('--tpu_ip_address', type=str, default='',
|
||||
help="TPU IP address if none are set in the environment variables")
|
||||
parser.add_argument('--tpu_name', type=str, default='',
|
||||
help="TPU name if none are set in the environment variables")
|
||||
parser.add_argument('--xrt_tpu_config', type=str, default='',
|
||||
help="XRT TPU config if none are set in the environment variables")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
@@ -439,23 +444,6 @@ def main():
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
if args.tpu:
|
||||
if args.tpu_ip_address:
|
||||
os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
|
||||
if args.tpu_name:
|
||||
os.environ["TPU_NAME"] = args.tpu_name
|
||||
if args.xrt_tpu_config:
|
||||
os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
|
||||
|
||||
assert "TPU_IP_ADDRESS" in os.environ
|
||||
assert "TPU_NAME" in os.environ
|
||||
assert "XRT_TPU_CONFIG" in os.environ
|
||||
|
||||
import torch_xla
|
||||
import torch_xla.core.xla_model as xm
|
||||
args.device = xm.xla_device()
|
||||
args.xla_model = xm
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
@@ -509,7 +497,7 @@ def main():
|
||||
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0) and not args.tpu:
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
@@ -47,7 +47,8 @@ from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
|
||||
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
||||
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
||||
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
|
||||
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
|
||||
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer,
|
||||
CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -58,7 +59,8 @@ MODEL_CLASSES = {
|
||||
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
|
||||
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||
'camembert': (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)
|
||||
}
|
||||
|
||||
|
||||
@@ -68,7 +70,7 @@ class TextDataset(Dataset):
|
||||
directory, filename = os.path.split(file_path)
|
||||
cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_' + filename)
|
||||
|
||||
if os.path.exists(cached_features_file):
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
with open(cached_features_file, 'rb') as handle:
|
||||
self.examples = pickle.load(handle)
|
||||
@@ -186,6 +188,13 @@ def train(args, train_dataset, model, tokenizer):
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
||||
# Load in optimizer and scheduler states
|
||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
@@ -214,13 +223,37 @@ def train(args, train_dataset, model, tokenizer):
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
if os.path.exists(args.model_name_or_path):
|
||||
# set global_step to gobal_step of last saved checkpoint from model path
|
||||
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
|
||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", global_step)
|
||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
|
||||
model_to_resize = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_resize.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
model.zero_grad()
|
||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
|
||||
inputs = inputs.to(args.device)
|
||||
labels = labels.to(args.device)
|
||||
@@ -268,11 +301,17 @@ def train(args, train_dataset, model, tokenizer):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
_rotate_checkpoints(args, checkpoint_prefix)
|
||||
|
||||
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
|
||||
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
|
||||
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
@@ -297,7 +336,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu evaluate
|
||||
@@ -391,7 +430,7 @@ def main():
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
@@ -431,7 +470,7 @@ def main():
|
||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm:
|
||||
if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
|
||||
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
||||
"flag (masked language modeling).")
|
||||
if args.eval_data_file is None and args.do_eval:
|
||||
|
||||
@@ -226,7 +226,7 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu evaluate
|
||||
|
||||
@@ -37,17 +37,22 @@ from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
|
||||
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
|
||||
from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
|
||||
from transformers import CamembertConfig, CamembertForTokenClassification, CamembertTokenizer
|
||||
from transformers import XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig,
|
||||
CamembertConfig, XLMRobertaConfig)),
|
||||
())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
|
||||
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
|
||||
"camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
|
||||
"xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
|
||||
}
|
||||
|
||||
|
||||
@@ -125,7 +130,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
||||
"attention_mask": batch[1],
|
||||
"labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||
@@ -215,7 +220,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
|
||||
"attention_mask": batch[1],
|
||||
"labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
|
||||
from transformers.data.metrics.squad_metrics import compute_predictions_logits, compute_predictions_log_probs, squad_evaluate
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
@@ -23,11 +25,9 @@ import os
|
||||
import random
|
||||
import glob
|
||||
import timeit
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
try:
|
||||
@@ -43,18 +43,12 @@ from transformers import (WEIGHTS_NAME, BertConfig,
|
||||
XLMTokenizer, XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
||||
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer,
|
||||
AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer,
|
||||
XLMConfig, XLMForQuestionAnswering, XLMTokenizer,
|
||||
)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from utils_squad import (read_squad_examples, convert_examples_to_features,
|
||||
RawResult, write_predictions,
|
||||
RawResultExtended, write_predictions_extended)
|
||||
|
||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||
# You can remove it from the dependencies if you are using this script outside of the library
|
||||
# We've added it here for automated tests (see examples/test_examples.py file)
|
||||
from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -65,7 +59,8 @@ MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
|
||||
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
|
||||
'albert': (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
|
||||
}
|
||||
|
||||
def set_seed(args):
|
||||
@@ -98,14 +93,16 @@ def train(args, train_dataset, model, tokenizer):
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
@@ -128,22 +125,28 @@ def train(args, train_dataset, model, tokenizer):
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
global_step = 1
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'start_positions': batch[3],
|
||||
'end_positions': batch[4]}
|
||||
|
||||
inputs = {
|
||||
'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'start_positions': batch[3],
|
||||
'end_positions': batch[4]
|
||||
}
|
||||
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
|
||||
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[5],
|
||||
'p_mask': batch[6]})
|
||||
@@ -175,8 +178,8 @@ def train(args, train_dataset, model, tokenizer):
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
# Log metrics
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
@@ -185,8 +188,8 @@ def train(args, train_dataset, model, tokenizer):
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
# Save model checkpoint
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
@@ -215,50 +218,72 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
|
||||
eval_sampler = SequentialSampler(dataset)
|
||||
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu evaluate
|
||||
if args.n_gpu > 1:
|
||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
|
||||
all_results = []
|
||||
start_time = timeit.default_timer()
|
||||
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1]
|
||||
}
|
||||
inputs = {
|
||||
'input_ids': batch[0],
|
||||
'attention_mask': batch[1]
|
||||
}
|
||||
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
||||
|
||||
example_indices = batch[3]
|
||||
|
||||
# XLNet and XLM use more arguments for their predictions
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[4],
|
||||
'p_mask': batch[5]})
|
||||
inputs.update({'cls_index': batch[4], 'p_mask': batch[5]})
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
result = RawResultExtended(unique_id = unique_id,
|
||||
start_top_log_probs = to_list(outputs[0][i]),
|
||||
start_top_index = to_list(outputs[1][i]),
|
||||
end_top_log_probs = to_list(outputs[2][i]),
|
||||
end_top_index = to_list(outputs[3][i]),
|
||||
cls_logits = to_list(outputs[4][i]))
|
||||
|
||||
output = [to_list(output[i]) for output in outputs]
|
||||
|
||||
# Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
|
||||
# models only use two.
|
||||
if len(output) >= 5:
|
||||
start_logits = output[0]
|
||||
start_top_index = output[1]
|
||||
end_logits = output[2]
|
||||
end_top_index = output[3]
|
||||
cls_logits = output[4]
|
||||
|
||||
result = SquadResult(
|
||||
unique_id, start_logits, end_logits,
|
||||
start_top_index=start_top_index,
|
||||
end_top_index=end_top_index,
|
||||
cls_logits=cls_logits
|
||||
)
|
||||
|
||||
else:
|
||||
result = RawResult(unique_id = unique_id,
|
||||
start_logits = to_list(outputs[0][i]),
|
||||
end_logits = to_list(outputs[1][i]))
|
||||
start_logits, end_logits = output
|
||||
result = SquadResult(
|
||||
unique_id, start_logits, end_logits
|
||||
)
|
||||
|
||||
all_results.append(result)
|
||||
|
||||
evalTime = timeit.default_timer() - start_time
|
||||
@@ -267,63 +292,84 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
# Compute predictions
|
||||
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||
|
||||
if args.version_2_with_negative:
|
||||
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||
else:
|
||||
output_null_log_odds_file = None
|
||||
|
||||
# XLNet and XLM use a more complex post-processing procedure
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
||||
start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
|
||||
end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top
|
||||
|
||||
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||
model.config.start_n_top, model.config.end_n_top,
|
||||
output_nbest_file, output_null_log_odds_file,
|
||||
start_n_top, end_n_top,
|
||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||
else:
|
||||
write_predictions(examples, features, all_results, args.n_best_size,
|
||||
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||
|
||||
# Evaluate with the official SQuAD script
|
||||
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
|
||||
pred_file=output_prediction_file,
|
||||
na_prob_file=output_null_log_odds_file)
|
||||
results = evaluate_on_squad(evaluate_options)
|
||||
# Compute the F1 and exact scores.
|
||||
results = squad_evaluate(examples, predictions)
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
input_file = args.predict_file if evaluate else args.train_file
|
||||
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
|
||||
input_dir = args.data_dir if args.data_dir else "."
|
||||
cached_features_file = os.path.join(input_dir, 'cached_{}_{}_{}'.format(
|
||||
'dev' if evaluate else 'train',
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length)))
|
||||
str(args.max_seq_length))
|
||||
)
|
||||
|
||||
# Init features and dataset from cache if it exists
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
features_and_dataset = torch.load(cached_features_file)
|
||||
features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", input_file)
|
||||
examples = read_squad_examples(input_file=input_file,
|
||||
is_training=not evaluate,
|
||||
version_2_with_negative=args.version_2_with_negative)
|
||||
features = convert_examples_to_features(examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate,
|
||||
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
|
||||
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
|
||||
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
|
||||
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
|
||||
logger.info("Creating features from dataset file at %s", input_dir)
|
||||
|
||||
if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
|
||||
try:
|
||||
import tensorflow_datasets as tfds
|
||||
except ImportError:
|
||||
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
|
||||
|
||||
if args.version_2_with_negative:
|
||||
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD.")
|
||||
|
||||
tfds_examples = tfds.load("squad")
|
||||
examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
|
||||
else:
|
||||
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
||||
|
||||
if evaluate:
|
||||
examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
|
||||
else:
|
||||
examples = processor.get_train_examples(args.data_dir, filename=args.train_file)
|
||||
|
||||
features, dataset = squad_convert_examples_to_features(
|
||||
examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate,
|
||||
return_dataset='pt'
|
||||
)
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
torch.save({"features": features, "dataset": dataset}, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
@@ -355,10 +401,6 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--train_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for training. E.g., train-v1.1.json")
|
||||
parser.add_argument("--predict_file", default=None, type=str, required=True,
|
||||
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
@@ -367,6 +409,15 @@ def main():
|
||||
help="The output directory where the model checkpoints and predictions will be written.")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str,
|
||||
help="The input data dir. Should contain the .json files for the task." +
|
||||
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
||||
parser.add_argument("--train_file", default=None, type=str,
|
||||
help="The input training file. If a data dir is specified, will look for the file there" +
|
||||
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
||||
parser.add_argument("--predict_file", default=None, type=str,
|
||||
help="The input evaluation file. If a data dir is specified, will look for the file there" +
|
||||
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
@@ -405,7 +456,7 @@ def main():
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
@@ -540,7 +591,7 @@ def main():
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
model = model_class.from_pretrained(args.output_dir, force_download=True)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
@@ -548,17 +599,23 @@ def main():
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||
|
||||
if args.do_train:
|
||||
logger.info("Loading checkpoints saved during training for evaluation")
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||
else:
|
||||
logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
|
||||
checkpoints = [args.model_name_or_path]
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model = model_class.from_pretrained(checkpoint, force_download=True)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
|
||||
@@ -1,492 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2019 The HuggingFace Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning seq2seq models for sequence generation."""
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm, trange
|
||||
import torch
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BertForMaskedLM,
|
||||
BertConfig,
|
||||
PreTrainedEncoderDecoder,
|
||||
Model2Model,
|
||||
)
|
||||
|
||||
from utils_summarization import (
|
||||
CNNDailyMailDataset,
|
||||
encode_for_summarization,
|
||||
fit_to_block_size,
|
||||
build_lm_labels,
|
||||
build_mask,
|
||||
compute_token_type_ids,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
|
||||
# ------------
|
||||
# Load dataset
|
||||
# ------------
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer):
|
||||
dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir)
|
||||
return dataset
|
||||
|
||||
|
||||
def collate(data, tokenizer, block_size):
|
||||
""" List of tuple as an input. """
|
||||
# remove the files with empty an story/summary, encode and fit to block
|
||||
data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data)
|
||||
data = [
|
||||
encode_for_summarization(story, summary, tokenizer) for story, summary in data
|
||||
]
|
||||
data = [
|
||||
(
|
||||
fit_to_block_size(story, block_size, tokenizer.pad_token_id),
|
||||
fit_to_block_size(summary, block_size, tokenizer.pad_token_id),
|
||||
)
|
||||
for story, summary in data
|
||||
]
|
||||
|
||||
stories = torch.tensor([story for story, summary in data])
|
||||
summaries = torch.tensor([summary for story, summary in data])
|
||||
encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
|
||||
encoder_mask = build_mask(stories, tokenizer.pad_token_id)
|
||||
decoder_mask = build_mask(summaries, tokenizer.pad_token_id)
|
||||
lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id)
|
||||
|
||||
return (
|
||||
stories,
|
||||
summaries,
|
||||
encoder_token_type_ids,
|
||||
encoder_mask,
|
||||
decoder_mask,
|
||||
lm_labels,
|
||||
)
|
||||
|
||||
|
||||
# ----------
|
||||
# Optimizers
|
||||
# ----------
|
||||
|
||||
|
||||
class BertSumOptimizer(object):
|
||||
""" Specific optimizer for BertSum.
|
||||
|
||||
As described in [1], the authors fine-tune BertSum for abstractive
|
||||
summarization using two Adam Optimizers with different warm-up steps and
|
||||
learning rate. They also use a custom learning rate scheduler.
|
||||
|
||||
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
|
||||
arXiv preprint arXiv:1908.08345 (2019).
|
||||
"""
|
||||
|
||||
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
|
||||
self.encoder = model.encoder
|
||||
self.decoder = model.decoder
|
||||
self.lr = lr
|
||||
self.warmup_steps = warmup_steps
|
||||
|
||||
self.optimizers = {
|
||||
"encoder": Adam(
|
||||
model.encoder.parameters(),
|
||||
lr=lr["encoder"],
|
||||
betas=(beta_1, beta_2),
|
||||
eps=eps,
|
||||
),
|
||||
"decoder": Adam(
|
||||
model.decoder.parameters(),
|
||||
lr=lr["decoder"],
|
||||
betas=(beta_1, beta_2),
|
||||
eps=eps,
|
||||
),
|
||||
}
|
||||
|
||||
self._step = 0
|
||||
|
||||
def _update_rate(self, stack):
|
||||
return self.lr[stack] * min(
|
||||
self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-0.5)
|
||||
)
|
||||
|
||||
def zero_grad(self):
|
||||
self.optimizer_decoder.zero_grad()
|
||||
self.optimizer_encoder.zero_grad()
|
||||
|
||||
def step(self):
|
||||
self._step += 1
|
||||
for stack, optimizer in self.optimizers.items():
|
||||
new_rate = self._update_rate(stack)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = new_rate
|
||||
optimizer.step()
|
||||
|
||||
|
||||
# ------------
|
||||
# Train
|
||||
# ------------
|
||||
|
||||
|
||||
def train(args, model, tokenizer):
|
||||
""" Fine-tune the pretrained model on the corpus. """
|
||||
set_seed(args)
|
||||
|
||||
# Load the data
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_dataset = load_and_cache_examples(args, tokenizer)
|
||||
train_sampler = RandomSampler(train_dataset)
|
||||
model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
sampler=train_sampler,
|
||||
batch_size=args.train_batch_size,
|
||||
collate_fn=model_collate_fn,
|
||||
)
|
||||
|
||||
# Training schedule
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = t_total // (
|
||||
len(train_dataloader) // args.gradient_accumulation_steps + 1
|
||||
)
|
||||
else:
|
||||
t_total = (
|
||||
len(train_dataloader)
|
||||
// args.gradient_accumulation_steps
|
||||
* args.num_train_epochs
|
||||
)
|
||||
|
||||
# Prepare the optimizer
|
||||
lr = {"encoder": 0.002, "decoder": 0.2}
|
||||
warmup_steps = {"encoder": 20000, "decoder": 10000}
|
||||
optimizer = BertSumOptimizer(model, lr, warmup_steps)
|
||||
|
||||
# Train
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(
|
||||
" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
|
||||
)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps
|
||||
# * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
model.zero_grad()
|
||||
train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)
|
||||
|
||||
global_step = 0
|
||||
tr_loss = 0.0
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
|
||||
|
||||
source = source.to(args.device)
|
||||
target = target.to(args.device)
|
||||
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
|
||||
encoder_mask = encoder_mask.to(args.device)
|
||||
decoder_mask = decoder_mask.to(args.device)
|
||||
lm_labels = lm_labels.to(args.device)
|
||||
|
||||
model.train()
|
||||
outputs = model(
|
||||
source,
|
||||
target,
|
||||
encoder_token_type_ids=encoder_token_type_ids,
|
||||
encoder_attention_mask=encoder_mask,
|
||||
decoder_attention_mask=decoder_mask,
|
||||
decoder_lm_labels=lm_labels,
|
||||
)
|
||||
|
||||
loss = outputs[0]
|
||||
print(loss)
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss /= args.gradient_accumulation_steps
|
||||
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
optimizer.step()
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
# ------------
|
||||
# Train
|
||||
# ------------
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
set_seed(args)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size
|
||||
)
|
||||
|
||||
# multi-gpu evaluate
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
model.eval()
|
||||
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
|
||||
|
||||
source = source.to(args.device)
|
||||
target = target.to(args.device)
|
||||
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
|
||||
encoder_mask = encoder_mask.to(args.device)
|
||||
decoder_mask = decoder_mask.to(args.device)
|
||||
lm_labels = lm_labels.to(args.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
source,
|
||||
target,
|
||||
encoder_token_type_ids=encoder_token_type_ids,
|
||||
encoder_attention_mask=encoder_mask,
|
||||
decoder_attention_mask=decoder_mask,
|
||||
decoder_lm_labels=lm_labels,
|
||||
)
|
||||
lm_loss = outputs[0]
|
||||
eval_loss += lm_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
perplexity = torch.exp(torch.tensor(eval_loss))
|
||||
|
||||
result = {"perplexity": perplexity}
|
||||
|
||||
# Save the evaluation's results
|
||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input training data file (a text file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
# Optional parameters
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_evaluate",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Run model evaluation on out-of-sample data.",
|
||||
)
|
||||
parser.add_argument("--do_train", type=bool, default=False, help="Run training.")
|
||||
parser.add_argument(
|
||||
"--do_overwrite_output_dir",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to overwrite the output dir.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default="bert-base-cased",
|
||||
type=str,
|
||||
help="The model checkpoint to initialize the encoder and decoder's weights with.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default="bert",
|
||||
type=str,
|
||||
help="The decoder architecture to be fine-tuned.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to_cpu", default=False, type=bool, help="Whether to force training on CPU."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_train_epochs",
|
||||
default=10,
|
||||
type=int,
|
||||
help="Total number of training epochs to perform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_gpu_train_batch_size",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
parser.add_argument("--seed", default=42, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and os.listdir(args.output_dir)
|
||||
and args.do_train
|
||||
and not args.do_overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format(
|
||||
args.output_dir
|
||||
)
|
||||
)
|
||||
|
||||
# Set up training device
|
||||
if args.to_cpu or not torch.cuda.is_available():
|
||||
args.device = torch.device("cpu")
|
||||
args.n_gpu = 0
|
||||
else:
|
||||
args.device = torch.device("cuda")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
|
||||
# Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
config = BertConfig.from_pretrained(args.model_name_or_path)
|
||||
decoder_model = BertForMaskedLM(config)
|
||||
model = Model2Model.from_pretrained(
|
||||
args.model_name_or_path, decoder_model=decoder_model
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
0,
|
||||
args.device,
|
||||
args.n_gpu,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Train the model
|
||||
model.to(args.device)
|
||||
if args.do_train:
|
||||
global_step, tr_loss = train(args, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
|
||||
|
||||
# Evaluate the model
|
||||
results = {}
|
||||
if args.do_evaluate:
|
||||
checkpoints = []
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
encoder_checkpoint = os.path.join(checkpoint, "encoder")
|
||||
decoder_checkpoint = os.path.join(checkpoint, "decoder")
|
||||
model = PreTrainedEncoderDecoder.from_pretrained(
|
||||
encoder_checkpoint, decoder_checkpoint
|
||||
)
|
||||
model.to(args.device)
|
||||
results = "placeholder"
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
615
examples/run_tf_ner.py
Normal file
615
examples/run_tf_ner.py
Normal file
@@ -0,0 +1,615 @@
|
||||
# coding=utf-8
|
||||
import datetime
|
||||
import os
|
||||
import math
|
||||
import glob
|
||||
import re
|
||||
import tensorflow as tf
|
||||
import collections
|
||||
import numpy as np
|
||||
from seqeval import metrics
|
||||
import _pickle as pickle
|
||||
from absl import logging
|
||||
from transformers import TF2_WEIGHTS_NAME, BertConfig, BertTokenizer, TFBertForTokenClassification
|
||||
from transformers import RobertaConfig, RobertaTokenizer, TFRobertaForTokenClassification
|
||||
from transformers import DistilBertConfig, DistilBertTokenizer, TFDistilBertForTokenClassification
|
||||
from transformers import create_optimizer, GradientAccumulator
|
||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||
from fastprogress import master_bar, progress_bar
|
||||
from absl import flags
|
||||
from absl import app
|
||||
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
|
||||
())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer)
|
||||
}
|
||||
|
||||
|
||||
flags.DEFINE_string(
|
||||
"data_dir", None,
|
||||
"The input data dir. Should contain the .conll files (or other data files) "
|
||||
"for the task.")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"model_type", None,
|
||||
"Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
|
||||
flags.DEFINE_string(
|
||||
"model_name_or_path", None,
|
||||
"Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
|
||||
flags.DEFINE_string(
|
||||
"output_dir", None,
|
||||
"The output directory where the model checkpoints will be written.")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"labels", "",
|
||||
"Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"config_name", "",
|
||||
"Pretrained config name or path if not the same as model_name")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"tokenizer_name", "",
|
||||
"Pretrained tokenizer name or path if not the same as model_name")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"cache_dir", "",
|
||||
"Where do you want to store the pre-trained models downloaded from s3")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"max_seq_length", 128,
|
||||
"The maximum total input sentence length after tokenization. "
|
||||
"Sequences longer than this will be truncated, sequences shorter "
|
||||
"will be padded.")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"tpu", None,
|
||||
"The Cloud TPU to use for training. This should be either the name "
|
||||
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
|
||||
"url.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"num_tpu_cores", 8,
|
||||
"Total number of TPU cores to use.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"do_train", False,
|
||||
"Whether to run training.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"do_eval", False,
|
||||
"Whether to run eval on the dev set.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"do_predict", False,
|
||||
"Whether to run predictions on the test set.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"evaluate_during_training", False,
|
||||
"Whether to run evaluation during training at each logging step.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"do_lower_case", False,
|
||||
"Set this flag if you are using an uncased model.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"per_device_train_batch_size", 8,
|
||||
"Batch size per GPU/CPU/TPU for training.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"per_device_eval_batch_size", 8,
|
||||
"Batch size per GPU/CPU/TPU for evaluation.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"gradient_accumulation_steps", 1,
|
||||
"Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"learning_rate", 5e-5,
|
||||
"The initial learning rate for Adam.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"weight_decay", 0.0,
|
||||
"Weight decay if we apply some.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"adam_epsilon", 1e-8,
|
||||
"Epsilon for Adam optimizer.")
|
||||
|
||||
flags.DEFINE_float(
|
||||
"max_grad_norm", 1.0,
|
||||
"Max gradient norm.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"num_train_epochs", 3,
|
||||
"Total number of training epochs to perform.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"max_steps", -1,
|
||||
"If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"warmup_steps", 0,
|
||||
"Linear warmup over warmup_steps.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"logging_steps", 50,
|
||||
"Log every X updates steps.")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"save_steps", 50,
|
||||
"Save checkpoint every X updates steps.")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"eval_all_checkpoints", False,
|
||||
"Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"no_cuda", False,
|
||||
"Avoid using CUDA when available")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"overwrite_output_dir", False,
|
||||
"Overwrite the content of the output directory")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"overwrite_cache", False,
|
||||
"Overwrite the cached training and evaluation sets")
|
||||
|
||||
flags.DEFINE_integer(
|
||||
"seed", 42,
|
||||
"random seed for initialization")
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
"fp16", False,
|
||||
"Whether to use 16-bit (mixed) precision instead of 32-bit")
|
||||
|
||||
flags.DEFINE_string(
|
||||
"gpus", "0",
|
||||
"Comma separated list of gpus devices. If only one, switch to single "
|
||||
"gpu strategy, if None takes all the gpus available.")
|
||||
|
||||
|
||||
def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id):
|
||||
if args['max_steps'] > 0:
|
||||
num_train_steps = args['max_steps'] * args['gradient_accumulation_steps']
|
||||
args['num_train_epochs'] = 1
|
||||
else:
|
||||
num_train_steps = math.ceil(num_train_examples / train_batch_size) // args['gradient_accumulation_steps'] * args['num_train_epochs']
|
||||
|
||||
writer = tf.summary.create_file_writer("/tmp/mylogs")
|
||||
|
||||
with strategy.scope():
|
||||
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
|
||||
optimizer = create_optimizer(args['learning_rate'], num_train_steps, args['warmup_steps'])
|
||||
|
||||
if args['fp16']:
|
||||
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, 'dynamic')
|
||||
|
||||
loss_metric = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
|
||||
gradient_accumulator = GradientAccumulator()
|
||||
|
||||
logging.info("***** Running training *****")
|
||||
logging.info(" Num examples = %d", num_train_examples)
|
||||
logging.info(" Num Epochs = %d", args['num_train_epochs'])
|
||||
logging.info(" Instantaneous batch size per device = %d", args['per_device_train_batch_size'])
|
||||
logging.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
train_batch_size * args['gradient_accumulation_steps'])
|
||||
logging.info(" Gradient Accumulation steps = %d", args['gradient_accumulation_steps'])
|
||||
logging.info(" Total training steps = %d", num_train_steps)
|
||||
|
||||
model.summary()
|
||||
|
||||
@tf.function
|
||||
def apply_gradients():
|
||||
grads_and_vars = []
|
||||
|
||||
for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
|
||||
if gradient is not None:
|
||||
scaled_gradient = gradient / (args['n_device'] * args['gradient_accumulation_steps'])
|
||||
grads_and_vars.append((scaled_gradient, variable))
|
||||
else:
|
||||
grads_and_vars.append((gradient, variable))
|
||||
|
||||
optimizer.apply_gradients(grads_and_vars, args['max_grad_norm'])
|
||||
gradient_accumulator.reset()
|
||||
|
||||
@tf.function
|
||||
def train_step(train_features, train_labels):
|
||||
def step_fn(train_features, train_labels):
|
||||
inputs = {'attention_mask': train_features['input_mask'], 'training': True}
|
||||
|
||||
if args['model_type'] != "distilbert":
|
||||
inputs["token_type_ids"] = train_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
logits = model(train_features['input_ids'], **inputs)[0]
|
||||
logits = tf.reshape(logits, (-1, len(labels) + 1))
|
||||
active_loss = tf.reshape(train_features['input_mask'], (-1,))
|
||||
active_logits = tf.boolean_mask(logits, active_loss)
|
||||
train_labels = tf.reshape(train_labels, (-1,))
|
||||
active_labels = tf.boolean_mask(train_labels, active_loss)
|
||||
cross_entropy = loss_fct(active_labels, active_logits)
|
||||
loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
|
||||
grads = tape.gradient(loss, model.trainable_variables)
|
||||
|
||||
gradient_accumulator(grads)
|
||||
|
||||
return cross_entropy
|
||||
|
||||
per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
|
||||
mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)
|
||||
|
||||
return mean_loss
|
||||
|
||||
current_time = datetime.datetime.now()
|
||||
train_iterator = master_bar(range(args['num_train_epochs']))
|
||||
global_step = 0
|
||||
logging_loss = 0.0
|
||||
|
||||
for epoch in train_iterator:
|
||||
epoch_iterator = progress_bar(train_dataset, total=num_train_steps, parent=train_iterator, display=args['n_device'] > 1)
|
||||
step = 1
|
||||
|
||||
with strategy.scope():
|
||||
for train_features, train_labels in epoch_iterator:
|
||||
loss = train_step(train_features, train_labels)
|
||||
|
||||
if step % args['gradient_accumulation_steps'] == 0:
|
||||
strategy.experimental_run_v2(apply_gradients)
|
||||
|
||||
loss_metric(loss)
|
||||
|
||||
global_step += 1
|
||||
|
||||
if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:
|
||||
# Log metrics
|
||||
if args['n_device'] == 1 and args['evaluate_during_training']: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
logging.info("Eval at step " + str(global_step) + "\n" + report)
|
||||
logging.info("eval_loss: " + str(eval_loss))
|
||||
|
||||
precision = metrics.precision_score(y_true, y_pred)
|
||||
recall = metrics.recall_score(y_true, y_pred)
|
||||
f1 = metrics.f1_score(y_true, y_pred)
|
||||
|
||||
with writer.as_default():
|
||||
tf.summary.scalar("eval_loss", eval_loss, global_step)
|
||||
tf.summary.scalar("precision", precision, global_step)
|
||||
tf.summary.scalar("recall", recall, global_step)
|
||||
tf.summary.scalar("f1", f1, global_step)
|
||||
|
||||
lr = optimizer.learning_rate
|
||||
learning_rate = lr(step)
|
||||
|
||||
with writer.as_default():
|
||||
tf.summary.scalar("lr", learning_rate, global_step)
|
||||
tf.summary.scalar("loss", (loss_metric.result() - logging_loss) / args['logging_steps'], global_step)
|
||||
|
||||
logging_loss = loss_metric.result()
|
||||
|
||||
with writer.as_default():
|
||||
tf.summary.scalar("loss", loss_metric.result(), step=step)
|
||||
|
||||
if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args['output_dir'], "checkpoint-{}".format(global_step))
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
model.save_pretrained(output_dir)
|
||||
logging.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
train_iterator.child.comment = f'loss : {loss_metric.result()}'
|
||||
step += 1
|
||||
|
||||
train_iterator.write(f'loss epoch {epoch + 1}: {loss_metric.result()}')
|
||||
|
||||
loss_metric.reset_states()
|
||||
|
||||
logging.info(" Training took time = {}".format(datetime.datetime.now() - current_time))
|
||||
|
||||
|
||||
def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
|
||||
eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
|
||||
eval_dataset, size = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode)
|
||||
eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
|
||||
preds = None
|
||||
num_eval_steps = math.ceil(size / eval_batch_size)
|
||||
master = master_bar(range(1))
|
||||
eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args['n_device'] > 1)
|
||||
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
|
||||
loss = 0.0
|
||||
|
||||
logging.info("***** Running evaluation *****")
|
||||
logging.info(" Num examples = %d", size)
|
||||
logging.info(" Batch size = %d", eval_batch_size)
|
||||
|
||||
for eval_features, eval_labels in eval_iterator:
|
||||
inputs = {'attention_mask': eval_features['input_mask'], 'training': False}
|
||||
|
||||
if args['model_type'] != "distilbert":
|
||||
inputs["token_type_ids"] = eval_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None
|
||||
|
||||
with strategy.scope():
|
||||
logits = model(eval_features['input_ids'], **inputs)[0]
|
||||
tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
|
||||
active_loss = tf.reshape(eval_features['input_mask'], (-1,))
|
||||
active_logits = tf.boolean_mask(tmp_logits, active_loss)
|
||||
tmp_eval_labels = tf.reshape(eval_labels, (-1,))
|
||||
active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
|
||||
cross_entropy = loss_fct(active_labels, active_logits)
|
||||
loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)
|
||||
|
||||
if preds is None:
|
||||
preds = logits.numpy()
|
||||
label_ids = eval_labels.numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.numpy(), axis=0)
|
||||
label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)
|
||||
|
||||
preds = np.argmax(preds, axis=2)
|
||||
y_pred = [[] for _ in range(label_ids.shape[0])]
|
||||
y_true = [[] for _ in range(label_ids.shape[0])]
|
||||
loss = loss / num_eval_steps
|
||||
|
||||
for i in range(label_ids.shape[0]):
|
||||
for j in range(label_ids.shape[1]):
|
||||
if label_ids[i, j] != pad_token_label_id:
|
||||
y_pred[i].append(labels[preds[i, j] - 1])
|
||||
y_true[i].append(labels[label_ids[i, j] - 1])
|
||||
|
||||
return y_true, y_pred, loss.numpy()
|
||||
|
||||
|
||||
def load_cache(cached_file, max_seq_length):
|
||||
name_to_features = {
|
||||
"input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
|
||||
"label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
|
||||
}
|
||||
|
||||
def _decode_record(record):
|
||||
example = tf.io.parse_single_example(record, name_to_features)
|
||||
features = {}
|
||||
features['input_ids'] = example['input_ids']
|
||||
features['input_mask'] = example['input_mask']
|
||||
features['segment_ids'] = example['segment_ids']
|
||||
|
||||
return features, example['label_ids']
|
||||
|
||||
d = tf.data.TFRecordDataset(cached_file)
|
||||
d = d.map(_decode_record, num_parallel_calls=4)
|
||||
count = d.reduce(0, lambda x, _: x + 1)
|
||||
|
||||
return d, count.numpy()
|
||||
|
||||
|
||||
def save_cache(features, cached_features_file):
|
||||
writer = tf.io.TFRecordWriter(cached_features_file)
|
||||
|
||||
for (ex_index, feature) in enumerate(features):
|
||||
if ex_index % 5000 == 0:
|
||||
logging.info("Writing example %d of %d" % (ex_index, len(features)))
|
||||
|
||||
def create_int_feature(values):
|
||||
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||
return f
|
||||
|
||||
record_feature = collections.OrderedDict()
|
||||
record_feature["input_ids"] = create_int_feature(feature.input_ids)
|
||||
record_feature["input_mask"] = create_int_feature(feature.input_mask)
|
||||
record_feature["segment_ids"] = create_int_feature(feature.segment_ids)
|
||||
record_feature["label_ids"] = create_int_feature(feature.label_ids)
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))
|
||||
|
||||
writer.write(tf_example.SerializeToString())
|
||||
|
||||
writer.close()
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
|
||||
drop_remainder = True if args['tpu'] or mode == 'train' else False
|
||||
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(args['data_dir'], "cached_{}_{}_{}.tf_record".format(mode,
|
||||
list(filter(None, args['model_name_or_path'].split("/"))).pop(),
|
||||
str(args['max_seq_length'])))
|
||||
if os.path.exists(cached_features_file) and not args['overwrite_cache']:
|
||||
logging.info("Loading features from cached file %s", cached_features_file)
|
||||
dataset, size = load_cache(cached_features_file, args['max_seq_length'])
|
||||
else:
|
||||
logging.info("Creating features from dataset file at %s", args['data_dir'])
|
||||
examples = read_examples_from_file(args['data_dir'], mode)
|
||||
features = convert_examples_to_features(examples, labels, args['max_seq_length'], tokenizer,
|
||||
cls_token_at_end=bool(args['model_type'] in ["xlnet"]),
|
||||
# xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
cls_token_segment_id=2 if args['model_type'] in ["xlnet"] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(args['model_type'] in ["roberta"]),
|
||||
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||
pad_on_left=bool(args['model_type'] in ["xlnet"]),
|
||||
# pad on the left for xlnet
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=4 if args['model_type'] in ["xlnet"] else 0,
|
||||
pad_token_label_id=pad_token_label_id
|
||||
)
|
||||
logging.info("Saving features into cached file %s", cached_features_file)
|
||||
save_cache(features, cached_features_file)
|
||||
dataset, size = load_cache(cached_features_file, args['max_seq_length'])
|
||||
|
||||
if mode == 'train':
|
||||
dataset = dataset.repeat()
|
||||
dataset = dataset.shuffle(buffer_size=8192, seed=args['seed'])
|
||||
|
||||
dataset = dataset.batch(batch_size, drop_remainder)
|
||||
dataset = dataset.prefetch(buffer_size=batch_size)
|
||||
|
||||
return dataset, size
|
||||
|
||||
|
||||
def main(_):
|
||||
logging.set_verbosity(logging.INFO)
|
||||
args = flags.FLAGS.flag_values_dict()
|
||||
|
||||
if os.path.exists(args['output_dir']) and os.listdir(
|
||||
args['output_dir']) and args['do_train'] and not args['overwrite_output_dir']:
|
||||
raise ValueError(
|
||||
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
||||
args['output_dir']))
|
||||
|
||||
if args['fp16']:
|
||||
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
|
||||
|
||||
if args['tpu']:
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args['tpu'])
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
||||
args['n_device'] = args['num_tpu_cores']
|
||||
elif len(args['gpus'].split(',')) > 1:
|
||||
args['n_device'] = len([f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
|
||||
strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
|
||||
elif args['no_cuda']:
|
||||
args['n_device'] = 1
|
||||
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
||||
else:
|
||||
args['n_device'] = len(args['gpus'].split(','))
|
||||
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args['gpus'].split(',')[0])
|
||||
|
||||
logging.warning("n_device: %s, distributed training: %s, 16-bits training: %s",
|
||||
args['n_device'], bool(args['n_device'] > 1), args['fp16'])
|
||||
|
||||
labels = get_labels(args['labels'])
|
||||
num_labels = len(labels) + 1
|
||||
pad_token_label_id = 0
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args['model_type']]
|
||||
config = config_class.from_pretrained(args['config_name'] if args['config_name'] else args['model_name_or_path'],
|
||||
num_labels=num_labels,
|
||||
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
|
||||
|
||||
logging.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args['do_train']:
|
||||
tokenizer = tokenizer_class.from_pretrained(args['tokenizer_name'] if args['tokenizer_name'] else args['model_name_or_path'],
|
||||
do_lower_case=args['do_lower_case'],
|
||||
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
|
||||
|
||||
with strategy.scope():
|
||||
model = model_class.from_pretrained(args['model_name_or_path'],
|
||||
from_pt=bool(".bin" in args['model_name_or_path']),
|
||||
config=config,
|
||||
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
|
||||
model.layers[-1].activation = tf.keras.activations.softmax
|
||||
|
||||
train_batch_size = args['per_device_train_batch_size'] * args['n_device']
|
||||
train_dataset, num_train_examples = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train")
|
||||
train_dataset = strategy.experimental_distribute_dataset(train_dataset)
|
||||
train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id)
|
||||
|
||||
if not os.path.exists(args['output_dir']):
|
||||
os.makedirs(args['output_dir'])
|
||||
|
||||
logging.info("Saving model to %s", args['output_dir'])
|
||||
|
||||
model.save_pretrained(args['output_dir'])
|
||||
tokenizer.save_pretrained(args['output_dir'])
|
||||
|
||||
# Evaluation
|
||||
if args['do_eval']:
|
||||
tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
|
||||
checkpoints = []
|
||||
results = []
|
||||
|
||||
if args['eval_all_checkpoints']:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + "/**/" + TF2_WEIGHTS_NAME, recursive=True), key=lambda f: int(''.join(filter(str.isdigit, f)) or -1)))
|
||||
|
||||
logging.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
if len(checkpoints) == 0:
|
||||
checkpoints.append(args['output_dir'])
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
|
||||
|
||||
with strategy.scope():
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
|
||||
y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
if global_step:
|
||||
results.append({global_step + "_report": report, global_step + "_loss": eval_loss})
|
||||
|
||||
output_eval_file = os.path.join(args['output_dir'], "eval_results.txt")
|
||||
|
||||
with tf.io.gfile.GFile(output_eval_file, "w") as writer:
|
||||
for res in results:
|
||||
for key, val in res.items():
|
||||
if "loss" in key:
|
||||
logging.info(key + " = " + str(val))
|
||||
writer.write(key + " = " + str(val))
|
||||
writer.write("\n")
|
||||
else:
|
||||
logging.info(key)
|
||||
logging.info("\n" + report)
|
||||
writer.write(key + "\n")
|
||||
writer.write(report)
|
||||
writer.write("\n")
|
||||
|
||||
if args['do_predict']:
|
||||
tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
|
||||
model = model_class.from_pretrained(args['output_dir'])
|
||||
eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
|
||||
predict_dataset, _ = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test")
|
||||
y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
|
||||
output_test_results_file = os.path.join(args['output_dir'], "test_results.txt")
|
||||
output_test_predictions_file = os.path.join(args['output_dir'], "test_predictions.txt")
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
|
||||
report = metrics.classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
logging.info("\n" + report)
|
||||
|
||||
writer.write(report)
|
||||
writer.write("\n\nloss = " + str(pred_loss))
|
||||
|
||||
with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
|
||||
with tf.io.gfile.GFile(os.path.join(args['data_dir'], "test.txt"), "r") as f:
|
||||
example_id = 0
|
||||
|
||||
for line in f:
|
||||
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
|
||||
writer.write(line)
|
||||
|
||||
if not y_pred[example_id]:
|
||||
example_id += 1
|
||||
elif y_pred[example_id]:
|
||||
output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
|
||||
writer.write(output_line)
|
||||
else:
|
||||
logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
flags.mark_flag_as_required("data_dir")
|
||||
flags.mark_flag_as_required("output_dir")
|
||||
flags.mark_flag_as_required("model_name_or_path")
|
||||
flags.mark_flag_as_required("model_type")
|
||||
app.run(main)
|
||||
515
examples/run_xnli.py
Normal file
515
examples/run_xnli.py
Normal file
@@ -0,0 +1,515 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning multi-lingual models on XNLI (Bert, DistilBERT, XLM).
|
||||
Adapted from `examples/run_glue.py`"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (WEIGHTS_NAME,
|
||||
BertConfig, BertForSequenceClassification, BertTokenizer,
|
||||
XLMConfig, XLMForSequenceClassification, XLMTokenizer,
|
||||
DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from transformers import xnli_compute_metrics as compute_metrics
|
||||
from transformers import xnli_output_modes as output_modes
|
||||
from transformers import xnli_processors as processors
|
||||
|
||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
""" Train the model """
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer and schedule (linear warmup and decay)
|
||||
no_decay = ['bias', 'LayerNorm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
|
||||
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
|
||||
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'labels': batch[3]}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert'] else None # XLM and DistilBERT don't use segment_ids
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
model.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
# Log metrics
|
||||
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
|
||||
eval_task_names = (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir,)
|
||||
|
||||
results = {}
|
||||
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
||||
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
||||
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'labels': batch[3]}
|
||||
if args.model_type != 'distilbert':
|
||||
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert'] else None # XLM and DistilBERT don't use segment_ids
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs['labels'].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if args.output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
else:
|
||||
raise ValueError('No other `output_mode` for XNLI.')
|
||||
result = compute_metrics(eval_task, preds, out_label_ids)
|
||||
results.update(result)
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
if args.local_rank not in [-1, 0] and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
processor = processors[task](language=args.language, train_language=args.train_language)
|
||||
output_mode = output_modes[task]
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}_{}'.format(
|
||||
'test' if evaluate else 'train',
|
||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
||||
str(args.max_seq_length),
|
||||
str(task),
|
||||
str(args.train_language if (not evaluate and args.train_language is not None) else args.language)))
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
examples = processor.get_test_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
features = convert_examples_to_features(examples,
|
||||
tokenizer,
|
||||
label_list=label_list,
|
||||
max_length=args.max_seq_length,
|
||||
output_mode=output_mode,
|
||||
pad_on_left=False,
|
||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||
pad_token_segment_id=0,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||
if output_mode == "classification":
|
||||
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
else:
|
||||
raise ValueError('No other `output_mode` for XNLI.')
|
||||
|
||||
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--language", default=None, type=str, required=True,
|
||||
help="Evaluation language. Also train language if `train_language` is set to None.")
|
||||
parser.add_argument("--train_language", default=None, type=str,
|
||||
help="Train language if is different of the evaluation language.")
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name")
|
||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.")
|
||||
parser.add_argument("--do_train", action='store_true',
|
||||
help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true',
|
||||
help="Whether to run eval on the test set.")
|
||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||
help="Rul evaluation during training at each logging step.")
|
||||
parser.add_argument("--do_lower_case", action='store_true',
|
||||
help="Set this flag if you are using an uncased model.")
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float,
|
||||
help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
|
||||
help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||
help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int,
|
||||
help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=50,
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||
help="Overwrite the content of the output directory")
|
||||
parser.add_argument('--overwrite_cache', action='store_true',
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
parser.add_argument("--local_rank", type=int, default=-1,
|
||||
help="For distributed training: local_rank")
|
||||
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
|
||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Prepare XNLI task
|
||||
args.task_name = 'xnli'
|
||||
if args.task_name not in processors:
|
||||
raise ValueError("Task not found: %s" % (args.task_name))
|
||||
processor = processors[args.task_name](language=args.language, train_language=args.train_language)
|
||||
args.output_mode = output_modes[args.task_name]
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
61
examples/summarization/README.md
Normal file
61
examples/summarization/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Text Summarization with Pretrained Encoders
|
||||
|
||||
This folder contains part of the code necessary to reproduce the results on abstractive summarization from the article [Text Summarization with Pretrained Encoders](https://arxiv.org/pdf/1908.08345.pdf) by [Yang Liu](https://nlp-yang.github.io/) and [Mirella Lapata](https://homepages.inf.ed.ac.uk/mlap/). It can also be used to summarize any document.
|
||||
|
||||
The original code can be found on the Yang Liu's [github repository](https://github.com/nlpyang/PreSumm).
|
||||
|
||||
The model is loaded with the pre-trained weights for the abstractive summarization model trained on the CNN/Daily Mail dataset with an extractive and then abstractive tasks.
|
||||
|
||||
## Setup
|
||||
|
||||
```
|
||||
git clone https://github.com/huggingface/transformers && cd transformers
|
||||
pip install [--editable] .
|
||||
pip install nltk py-rouge
|
||||
cd examples/summarization
|
||||
```
|
||||
|
||||
## Reproduce the authors' results on ROUGE
|
||||
|
||||
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
|
||||
|
||||
```bash
|
||||
tar -xvf cnn_stories.tgz && tar -xvf dailymail_stories.tgz
|
||||
```
|
||||
|
||||
And move all the stories to the same folder. We will refer as `$DATA_PATH` the path to where you uncompressed both archive. Then run the following in the same folder as `run_summarization.py`:
|
||||
|
||||
```bash
|
||||
python run_summarization.py \
|
||||
--documents_dir $DATA_PATH \
|
||||
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
||||
--no_cuda false \
|
||||
--batch_size 4 \
|
||||
--min_length 50 \
|
||||
--max_length 200 \
|
||||
--beam_size 5 \
|
||||
--alpha 0.95 \
|
||||
--block_trigram true \
|
||||
--compute_rouge true
|
||||
```
|
||||
|
||||
The scripts executes on GPU if one is available and if `no_cuda` is not set to `true`. Inference on multiple GPUs is not suported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes 30 hours to compute with a single Tesla V100 GPU and a batch size of 10 (300,000 texts to summarize).
|
||||
|
||||
## Summarize any text
|
||||
|
||||
Put the documents that you would like to summarize in a folder (the path to which is referred to as `$DATA_PATH` below) and run the following in the same folder as `run_summarization.py`:
|
||||
|
||||
```bash
|
||||
python run_summarization.py \
|
||||
--documents_dir $DATA_PATH \
|
||||
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
||||
--no_cuda false \
|
||||
--batch_size 4 \
|
||||
--min_length 50 \
|
||||
--max_length 200 \
|
||||
--beam_size 5 \
|
||||
--alpha 0.95 \
|
||||
--block_trigram true \
|
||||
```
|
||||
|
||||
You may want to play around with `min_length`, `max_length` and `alpha` to suit your use case. If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries import in `utils_summarization.py` and tell it where to fetch the reference summaries.
|
||||
99
examples/summarization/configuration_bertabs.py
Normal file
99
examples/summarization/configuration_bertabs.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BertAbs configuration """
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BERTABS_FINETUNED_CONFIG_MAP = {
|
||||
"bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json",
|
||||
}
|
||||
|
||||
|
||||
class BertAbsConfig(PretrainedConfig):
|
||||
r""" Class to store the configuration of the BertAbs model.
|
||||
|
||||
Arguments:
|
||||
vocab_size: int
|
||||
Number of tokens in the vocabulary.
|
||||
max_pos: int
|
||||
The maximum sequence length that this model will be used with.
|
||||
enc_layer: int
|
||||
The numner of hidden layers in the Transformer encoder.
|
||||
enc_hidden_size: int
|
||||
The size of the encoder's layers.
|
||||
enc_heads: int
|
||||
The number of attention heads for each attention layer in the encoder.
|
||||
enc_ff_size: int
|
||||
The size of the encoder's feed-forward layers.
|
||||
enc_dropout: int
|
||||
The dropout probabilitiy for all fully connected layers in the
|
||||
embeddings, layers, pooler and also the attention probabilities in
|
||||
the encoder.
|
||||
dec_layer: int
|
||||
The numner of hidden layers in the decoder.
|
||||
dec_hidden_size: int
|
||||
The size of the decoder's layers.
|
||||
dec_heads: int
|
||||
The number of attention heads for each attention layer in the decoder.
|
||||
dec_ff_size: int
|
||||
The size of the decoder's feed-forward layers.
|
||||
dec_dropout: int
|
||||
The dropout probabilitiy for all fully connected layers in the
|
||||
embeddings, layers, pooler and also the attention probabilities in
|
||||
the decoder.
|
||||
"""
|
||||
|
||||
pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
max_pos=512,
|
||||
enc_layers=6,
|
||||
enc_hidden_size=512,
|
||||
enc_heads=8,
|
||||
enc_ff_size=512,
|
||||
enc_dropout=0.2,
|
||||
dec_layers=6,
|
||||
dec_hidden_size=768,
|
||||
dec_heads=8,
|
||||
dec_ff_size=2048,
|
||||
dec_dropout=0.2,
|
||||
**kwargs,
|
||||
):
|
||||
super(BertAbsConfig, self).__init__(**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_pos = max_pos
|
||||
|
||||
self.enc_layers = enc_layers
|
||||
self.enc_hidden_size = enc_hidden_size
|
||||
self.enc_heads = enc_heads
|
||||
self.enc_ff_size = enc_ff_size
|
||||
self.enc_dropout = enc_dropout
|
||||
|
||||
self.dec_layers = dec_layers
|
||||
self.dec_hidden_size = dec_hidden_size
|
||||
self.dec_heads = dec_heads
|
||||
self.dec_ff_size = dec_ff_size
|
||||
self.dec_dropout = dec_dropout
|
||||
@@ -0,0 +1,163 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Convert BertExtAbs's checkpoints.
|
||||
|
||||
The script looks like it is doing something trivial but it is not. The "weights"
|
||||
proposed by the authors are actually the entire model pickled. We need to load
|
||||
the model within the original codebase to be able to only save its `state_dict`.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from models.model_builder import AbsSummarizer # The authors' implementation
|
||||
from model_bertabs import BertAbsSummarizer
|
||||
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SAMPLE_TEXT = 'Hello world! cécé herlolip'
|
||||
|
||||
|
||||
BertAbsConfig = namedtuple(
|
||||
"BertAbsConfig",
|
||||
["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
|
||||
)
|
||||
|
||||
|
||||
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
||||
""" Copy/paste and tweak the pre-trained weights provided by the creators
|
||||
of BertAbs for the internal architecture.
|
||||
"""
|
||||
|
||||
# Instantiate the authors' model with the pre-trained weights
|
||||
config = BertAbsConfig(
|
||||
temp_dir=".",
|
||||
finetune_bert=False,
|
||||
large=False,
|
||||
share_emb=True,
|
||||
use_bert_emb=False,
|
||||
encoder="bert",
|
||||
max_pos=512,
|
||||
enc_layers=6,
|
||||
enc_hidden_size=512,
|
||||
enc_heads=8,
|
||||
enc_ff_size=512,
|
||||
enc_dropout=0.2,
|
||||
dec_layers=6,
|
||||
dec_hidden_size=768,
|
||||
dec_heads=8,
|
||||
dec_ff_size=2048,
|
||||
dec_dropout=0.2,
|
||||
)
|
||||
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
|
||||
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
|
||||
original.eval()
|
||||
|
||||
new_model = BertAbsSummarizer(config, torch.device("cpu"))
|
||||
new_model.eval()
|
||||
|
||||
# -------------------
|
||||
# Convert the weights
|
||||
# -------------------
|
||||
|
||||
logging.info("convert the model")
|
||||
new_model.bert.load_state_dict(original.bert.state_dict())
|
||||
new_model.decoder.load_state_dict(original.decoder.state_dict())
|
||||
new_model.generator.load_state_dict(original.generator.state_dict())
|
||||
|
||||
# ----------------------------------
|
||||
# Make sure the outpus are identical
|
||||
# ----------------------------------
|
||||
|
||||
logging.info("Make sure that the models' outputs are identical")
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# prepare the model inputs
|
||||
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
|
||||
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
|
||||
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
|
||||
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
|
||||
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
|
||||
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
|
||||
|
||||
# failsafe to make sure the weights reset does not affect the
|
||||
# loaded weights.
|
||||
assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0
|
||||
|
||||
# forward pass
|
||||
src = encoder_input_ids
|
||||
tgt = decoder_input_ids
|
||||
segs = token_type_ids = None
|
||||
clss = None
|
||||
mask_src = encoder_attention_mask = None
|
||||
mask_tgt = decoder_attention_mask = None
|
||||
mask_cls = None
|
||||
|
||||
# The original model does not apply the geneator layer immediatly but rather in
|
||||
# the beam search (where it combines softmax + linear layer). Since we already
|
||||
# apply the softmax in our generation process we only apply the linear layer here.
|
||||
# We make sure that the outputs of the full stack are identical
|
||||
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
|
||||
output_original_generator = original.generator(output_original_model)
|
||||
|
||||
output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
|
||||
output_converted_generator = new_model.generator(output_converted_model)
|
||||
|
||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
|
||||
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
|
||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item()
|
||||
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
|
||||
|
||||
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
|
||||
if are_identical:
|
||||
logging.info("all weights are equal up to 1e-3")
|
||||
else:
|
||||
raise ValueError("the weights are different. The new model is likely different from the original one.")
|
||||
|
||||
# The model has been saved with torch.save(model) and this is bound to the exact
|
||||
# directory structure. We save the state_dict instead.
|
||||
logging.info("saving the model's state dictionary")
|
||||
torch.save(new_model.state_dict(), "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bertabs_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path the official PyTorch dump.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_bertabs_checkpoints(
|
||||
args.bertabs_checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
)
|
||||
1161
examples/summarization/modeling_bertabs.py
Normal file
1161
examples/summarization/modeling_bertabs.py
Normal file
File diff suppressed because it is too large
Load Diff
9
examples/summarization/requirements.txt
Normal file
9
examples/summarization/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
# progress bars in model download and training scripts
|
||||
tqdm
|
||||
# Accessing files from S3 directly.
|
||||
boto3
|
||||
# Used for downloading models over HTTP
|
||||
requests
|
||||
# For ROUGE
|
||||
nltk
|
||||
py-rouge
|
||||
344
examples/summarization/run_summarization.py
Normal file
344
examples/summarization/run_summarization.py
Normal file
@@ -0,0 +1,344 @@
|
||||
#! /usr/bin/python3
|
||||
import argparse
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, SequentialSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BertTokenizer
|
||||
|
||||
from modeling_bertabs import BertAbs, build_predictor
|
||||
|
||||
from utils_summarization import (
|
||||
SummarizationDataset,
|
||||
encode_for_summarization,
|
||||
build_mask,
|
||||
fit_to_block_size,
|
||||
compute_token_type_ids,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
|
||||
Batch = namedtuple(
|
||||
"Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"]
|
||||
)
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
|
||||
model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
|
||||
model.to(args.device)
|
||||
model.eval()
|
||||
|
||||
symbols = {
|
||||
"BOS": tokenizer.vocab["[unused0]"],
|
||||
"EOS": tokenizer.vocab["[unused1]"],
|
||||
"PAD": tokenizer.vocab["[PAD]"],
|
||||
}
|
||||
|
||||
if args.compute_rouge:
|
||||
reference_summaries = []
|
||||
generated_summaries = []
|
||||
|
||||
import rouge
|
||||
import nltk
|
||||
nltk.download('punkt')
|
||||
rouge_evaluator = rouge.Rouge(
|
||||
metrics=['rouge-n', 'rouge-l'],
|
||||
max_n=2,
|
||||
limit_length=True,
|
||||
length_limit=args.beam_size,
|
||||
length_limit_type='words',
|
||||
apply_avg=True,
|
||||
apply_best=False,
|
||||
alpha=0.5, # Default F1_score
|
||||
weight_factor=1.2,
|
||||
stemming=True,
|
||||
)
|
||||
|
||||
# these (unused) arguments are defined to keep the compatibility
|
||||
# with the legacy code and will be deleted in a next iteration.
|
||||
args.result_path = ""
|
||||
args.temp_dir = ""
|
||||
|
||||
data_iterator = build_data_iterator(args, tokenizer)
|
||||
predictor = build_predictor(args, tokenizer, symbols, model)
|
||||
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Number examples = %d", len(data_iterator.dataset))
|
||||
logger.info(" Batch size = %d", args.batch_size)
|
||||
logger.info("")
|
||||
logger.info("***** Beam Search parameters *****")
|
||||
logger.info(" Beam size = %d", args.beam_size)
|
||||
logger.info(" Minimum length = %d", args.min_length)
|
||||
logger.info(" Maximum length = %d", args.max_length)
|
||||
logger.info(" Alpha (length penalty) = %.2f", args.alpha)
|
||||
logger.info(" Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT"))
|
||||
|
||||
for batch in tqdm(data_iterator):
|
||||
batch_data = predictor.translate_batch(batch)
|
||||
translations = predictor.from_batch(batch_data)
|
||||
summaries = [format_summary(t) for t in translations]
|
||||
save_summaries(summaries, args.summaries_output_dir, batch.document_names)
|
||||
|
||||
if args.compute_rouge:
|
||||
reference_summaries += batch.tgt_str
|
||||
generated_summaries += summaries
|
||||
|
||||
if args.compute_rouge:
|
||||
scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
|
||||
str_scores = format_rouge_scores(scores)
|
||||
save_rouge_scores(str_scores)
|
||||
print(str_scores)
|
||||
|
||||
|
||||
def save_summaries(summaries, path, original_document_name):
|
||||
""" Write the summaries in fies that are prefixed by the original
|
||||
files' name with the `_summary` appended.
|
||||
|
||||
Attributes:
|
||||
original_document_names: List[string]
|
||||
Name of the document that was summarized.
|
||||
path: string
|
||||
Path were the summaries will be written
|
||||
summaries: List[string]
|
||||
The summaries that we produced.
|
||||
"""
|
||||
for summary, document_name in zip(summaries, original_document_name):
|
||||
# Prepare the summary file's name
|
||||
if "." in document_name:
|
||||
bare_document_name = ".".join(document_name.split(".")[:-1])
|
||||
extension = document_name.split(".")[-1]
|
||||
name = bare_document_name + "_summary." + extension
|
||||
else:
|
||||
name = document_name + "_summary"
|
||||
|
||||
file_path = os.path.join(path, name)
|
||||
with open(file_path, "w") as output:
|
||||
output.write(summary)
|
||||
|
||||
|
||||
def format_summary(translation):
|
||||
""" Transforms the output of the `from_batch` function
|
||||
into nicely formatted summaries.
|
||||
"""
|
||||
raw_summary, _, _ = translation
|
||||
summary = (
|
||||
raw_summary.replace("[unused0]", "")
|
||||
.replace("[unused3]", "")
|
||||
.replace("[PAD]", "")
|
||||
.replace("[unused1]", "")
|
||||
.replace(r" +", " ")
|
||||
.replace(" [unused2] ", ". ")
|
||||
.replace("[unused2]", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def format_rouge_scores(scores):
|
||||
return """\n
|
||||
****** ROUGE SCORES ******
|
||||
|
||||
** ROUGE 1
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}
|
||||
|
||||
** ROUGE 2
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}
|
||||
|
||||
** ROUGE L
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}""".format(
|
||||
scores['rouge-1']['f'],
|
||||
scores['rouge-1']['p'],
|
||||
scores['rouge-1']['r'],
|
||||
scores['rouge-2']['f'],
|
||||
scores['rouge-2']['p'],
|
||||
scores['rouge-2']['r'],
|
||||
scores['rouge-l']['f'],
|
||||
scores['rouge-l']['p'],
|
||||
scores['rouge-l']['r'],
|
||||
)
|
||||
|
||||
|
||||
def save_rouge_scores(str_scores):
|
||||
with open("rouge_scores.txt", "w") as output:
|
||||
output.write(str_scores)
|
||||
|
||||
|
||||
#
|
||||
# LOAD the dataset
|
||||
#
|
||||
|
||||
|
||||
def build_data_iterator(args, tokenizer):
|
||||
dataset = load_and_cache_examples(args, tokenizer)
|
||||
sampler = SequentialSampler(dataset)
|
||||
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
|
||||
iterator = DataLoader(
|
||||
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
return iterator
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer):
|
||||
dataset = SummarizationDataset(args.documents_dir)
|
||||
return dataset
|
||||
|
||||
|
||||
def collate(data, tokenizer, block_size, device):
|
||||
""" Collate formats the data passed to the data loader.
|
||||
|
||||
In particular we tokenize the data batch after batch to avoid keeping them
|
||||
all in memory. We output the data as a namedtuple to fit the original BertAbs's
|
||||
API.
|
||||
"""
|
||||
data = [x for x in data if not len(x[1]) == 0] # remove empty_files
|
||||
names = [name for name, _, _ in data]
|
||||
summaries = [" ".join(summary_list) for _, _, summary_list in data]
|
||||
|
||||
encoded_text = [
|
||||
encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
|
||||
]
|
||||
encoded_stories = torch.tensor(
|
||||
[
|
||||
fit_to_block_size(story, block_size, tokenizer.pad_token_id)
|
||||
for story, _ in encoded_text
|
||||
]
|
||||
)
|
||||
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
|
||||
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
|
||||
|
||||
batch = Batch(
|
||||
document_names=names,
|
||||
batch_size=len(encoded_stories),
|
||||
src=encoded_stories.to(device),
|
||||
segs=encoder_token_type_ids.to(device),
|
||||
mask_src=encoder_mask.to(device),
|
||||
tgt_str=summaries,
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def decode_summary(summary_tokens, tokenizer):
|
||||
""" Decode the summary and return it in a format
|
||||
suitable for evaluation.
|
||||
"""
|
||||
summary_tokens = summary_tokens.to("cpu").numpy()
|
||||
summary = tokenizer.decode(summary_tokens)
|
||||
sentences = summary.split(".")
|
||||
sentences = [s + "." for s in sentences]
|
||||
return sentences
|
||||
|
||||
|
||||
def main():
|
||||
""" The main function defines the interface with the users.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--documents_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The folder where the documents to summarize are located.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--summaries_output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="The folder in wich the summaries should be written. Defaults to the folder where the documents are",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compute_rouge",
|
||||
default=False,
|
||||
type=bool,
|
||||
required=False,
|
||||
help="Compute the ROUGE metrics during evaluation. Only available for the CNN/DailyMail dataset.",
|
||||
)
|
||||
# EVALUATION options
|
||||
parser.add_argument(
|
||||
"--no_cuda",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="Whether to force the execution on CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
# BEAM SEARCH arguments
|
||||
parser.add_argument(
|
||||
"--min_length",
|
||||
default=50,
|
||||
type=int,
|
||||
help="Minimum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
default=200,
|
||||
type=int,
|
||||
help="Maixmum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beam_size",
|
||||
default=5,
|
||||
type=int,
|
||||
help="The number of beams to start with for each example.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
default=0.95,
|
||||
type=float,
|
||||
help="The value of alpha for the length penalty in the beam search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_trigram",
|
||||
default=True,
|
||||
type=bool,
|
||||
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Select device (distibuted not available)
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
|
||||
# Check the existence of directories
|
||||
if not args.summaries_output_dir:
|
||||
args.summaries_output_dir = args.documents_dir
|
||||
|
||||
if not documents_dir_is_valid(args.documents_dir):
|
||||
raise FileNotFoundError(
|
||||
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
|
||||
)
|
||||
os.makedirs(args.summaries_output_dir, exist_ok=True)
|
||||
|
||||
evaluate(args)
|
||||
|
||||
|
||||
def documents_dir_is_valid(path):
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
|
||||
file_list = os.listdir(path)
|
||||
if len(file_list) == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -10,9 +10,14 @@ from torch.utils.data import Dataset
|
||||
# ------------
|
||||
|
||||
|
||||
class CNNDailyMailDataset(Dataset):
|
||||
class SummarizationDataset(Dataset):
|
||||
""" Abstracts the dataset used to train seq2seq models.
|
||||
|
||||
The class will process the documents that are located in the specified
|
||||
folder. The preprocessing will work on any document that is reasonably
|
||||
formatted. On the CNN/DailyMail dataset it will extract both the story
|
||||
and the summary.
|
||||
|
||||
CNN/Daily News:
|
||||
|
||||
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
|
||||
@@ -25,33 +30,33 @@ class CNNDailyMailDataset(Dataset):
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, prefix="train", data_dir=""):
|
||||
assert os.path.isdir(data_dir)
|
||||
self.tokenizer = tokenizer
|
||||
def __init__(self, path="", prefix="train"):
|
||||
""" We initialize the class by listing all the documents to summarize.
|
||||
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
|
||||
"""
|
||||
assert os.path.isdir(path)
|
||||
|
||||
# We initialize the class by listing all the files that contain
|
||||
# stories and summaries. Files are not read in memory given
|
||||
# the size of the corpus.
|
||||
self.stories_path = []
|
||||
datasets = ("cnn", "dailymail")
|
||||
for dataset in datasets:
|
||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
||||
story_filenames_list = os.listdir(path_to_stories)
|
||||
for story_filename in story_filenames_list:
|
||||
path_to_story = os.path.join(path_to_stories, story_filename)
|
||||
if not os.path.isfile(path_to_story):
|
||||
continue
|
||||
self.stories_path.append(path_to_story)
|
||||
self.documents = []
|
||||
story_filenames_list = os.listdir(path)
|
||||
for story_filename in story_filenames_list:
|
||||
if "summary" in story_filename:
|
||||
continue
|
||||
path_to_story = os.path.join(path, story_filename)
|
||||
if not os.path.isfile(path_to_story):
|
||||
continue
|
||||
self.documents.append(path_to_story)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.stories_path)
|
||||
""" Returns the number of documents. """
|
||||
return len(self.documents)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
story_path = self.stories_path[idx]
|
||||
with open(story_path, encoding="utf-8") as source:
|
||||
document_path = self.documents[idx]
|
||||
document_name = document_path.split("/")[-1]
|
||||
with open(document_path, encoding="utf-8") as source:
|
||||
raw_story = source.read()
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
return story_lines, summary_lines
|
||||
return document_name, story_lines, summary_lines
|
||||
|
||||
|
||||
def process_story(raw_story):
|
||||
@@ -81,7 +86,7 @@ def process_story(raw_story):
|
||||
story_lines.append(element)
|
||||
except IndexError:
|
||||
# if "@highlight" is absent from the file we pop
|
||||
# all elements until there is None.
|
||||
# all elements until there is None, raising an exception.
|
||||
return story_lines, []
|
||||
|
||||
# gather summary lines
|
||||
@@ -104,31 +109,22 @@ def _add_missing_period(line):
|
||||
# --------------------------
|
||||
|
||||
|
||||
def fit_to_block_size(sequence, block_size, pad_token):
|
||||
def fit_to_block_size(sequence, block_size, pad_token_id):
|
||||
""" Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter than the block size we pad it with -1 ids
|
||||
which correspond to padding tokens.
|
||||
If the sequence is shorter we append padding token to the right of the sequence.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
sequence.extend([pad_token] * (block_size - len(sequence)))
|
||||
sequence.extend([pad_token_id] * (block_size - len(sequence)))
|
||||
return sequence
|
||||
|
||||
|
||||
def build_lm_labels(sequence, pad_token):
|
||||
""" Padding token, encoded as 0, are represented by the value -1 so they
|
||||
are not taken into account in the loss computation. """
|
||||
padded = sequence.clone()
|
||||
padded[padded == pad_token] = -1
|
||||
return padded
|
||||
|
||||
|
||||
def build_mask(sequence, pad_token):
|
||||
def build_mask(sequence, pad_token_id):
|
||||
""" Builds the mask. The attention mechanism will only attend to positions
|
||||
with value 1. """
|
||||
mask = torch.ones_like(sequence)
|
||||
idx_pad_tokens = sequence == pad_token
|
||||
idx_pad_tokens = sequence == pad_token_id
|
||||
mask[idx_pad_tokens] = 0
|
||||
return mask
|
||||
|
||||
@@ -138,18 +134,11 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||
sentences.
|
||||
"""
|
||||
story_lines_token_ids = [
|
||||
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
|
||||
for line in story_lines
|
||||
]
|
||||
summary_lines_token_ids = [
|
||||
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
|
||||
for line in summary_lines
|
||||
]
|
||||
|
||||
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
|
||||
story_token_ids = [
|
||||
token for sentence in story_lines_token_ids for token in sentence
|
||||
]
|
||||
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
|
||||
summary_token_ids = [
|
||||
token for sentence in summary_lines_token_ids for token in sentence
|
||||
]
|
||||
@@ -174,7 +163,7 @@ def compute_token_type_ids(batch, separator_token_id):
|
||||
"""
|
||||
batch_embeddings = []
|
||||
for sequence in batch:
|
||||
sentence_num = 0
|
||||
sentence_num = -1
|
||||
embeddings = []
|
||||
for s in sequence:
|
||||
if s == separator_token_id:
|
||||
@@ -21,7 +21,6 @@ from utils_summarization import (
|
||||
compute_token_type_ids,
|
||||
fit_to_block_size,
|
||||
build_mask,
|
||||
build_lm_labels,
|
||||
process_story,
|
||||
)
|
||||
|
||||
@@ -88,20 +87,6 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
expected_summary_lines = ["It was the best of times."]
|
||||
self.assertEqual(expected_summary_lines, summary_lines)
|
||||
|
||||
def test_build_lm_labels_no_padding(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4])
|
||||
expected = sequence
|
||||
np.testing.assert_array_equal(
|
||||
build_lm_labels(sequence, 0).numpy(), expected.numpy()
|
||||
)
|
||||
|
||||
def test_build_lm_labels(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0])
|
||||
expected = torch.tensor([1, 2, 3, 4, -1, -1, -1])
|
||||
np.testing.assert_array_equal(
|
||||
build_lm_labels(sequence, 0).numpy(), expected.numpy()
|
||||
)
|
||||
|
||||
def test_build_mask_no_padding(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4])
|
||||
expected = torch.tensor([1, 1, 1, 1])
|
||||
@@ -125,7 +110,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
|
||||
)
|
||||
expected = torch.tensor(
|
||||
[[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]]
|
||||
[[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]]
|
||||
)
|
||||
|
||||
result = compute_token_type_ids(batch, separator)
|
||||
@@ -72,8 +72,7 @@ class ExamplesTests(unittest.TestCase):
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
testargs = ["run_squad.py",
|
||||
"--train_file=./examples/tests_samples/SQUAD/dev-v2.0-small.json",
|
||||
"--predict_file=./examples/tests_samples/SQUAD/dev-v2.0-small.json",
|
||||
"--data_dir=./examples/tests_samples/SQUAD",
|
||||
"--model_name=bert-base-uncased",
|
||||
"--output_dir=./examples/tests_samples/temp_dir",
|
||||
"--max_steps=10",
|
||||
|
||||
140
examples/tests_samples/SQUAD/train-v2.0.json
Normal file
140
examples/tests_samples/SQUAD/train-v2.0.json
Normal file
@@ -0,0 +1,140 @@
|
||||
{
|
||||
"version": "v2.0",
|
||||
"data": [{
|
||||
"title": "Normans",
|
||||
"paragraphs": [{
|
||||
"qas": [{
|
||||
"question": "In what country is Normandy located?",
|
||||
"id": "56ddde6b9a695914005b9628",
|
||||
"answers": [{
|
||||
"text": "France",
|
||||
"answer_start": 159
|
||||
}],
|
||||
"is_impossible": false
|
||||
}, {
|
||||
"question": "When were the Normans in Normandy?",
|
||||
"id": "56ddde6b9a695914005b9629",
|
||||
"answers": [{
|
||||
"text": "10th and 11th centuries",
|
||||
"answer_start": 94
|
||||
}],
|
||||
"is_impossible": false
|
||||
}, {
|
||||
"question": "From which countries did the Norse originate?",
|
||||
"id": "56ddde6b9a695914005b962a",
|
||||
"answers": [{
|
||||
"text": "Denmark, Iceland and Norway",
|
||||
"answer_start": 256
|
||||
}],
|
||||
"is_impossible": false
|
||||
}, {
|
||||
"plausible_answers": [{
|
||||
"text": "Rollo",
|
||||
"answer_start": 308
|
||||
}],
|
||||
"question": "Who did King Charles III swear fealty to?",
|
||||
"id": "5ad39d53604f3c001a3fe8d3",
|
||||
"answers": [],
|
||||
"is_impossible": true
|
||||
}, {
|
||||
"plausible_answers": [{
|
||||
"text": "10th century",
|
||||
"answer_start": 671
|
||||
}],
|
||||
"question": "When did the Frankish identity emerge?",
|
||||
"id": "5ad39d53604f3c001a3fe8d4",
|
||||
"answers": [],
|
||||
"is_impossible": true
|
||||
}],
|
||||
"context": "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries."
|
||||
}, {
|
||||
"qas": [{
|
||||
"question": "Who was the duke in the battle of Hastings?",
|
||||
"id": "56dddf4066d3e219004dad5f",
|
||||
"answers": [{
|
||||
"text": "William the Conqueror",
|
||||
"answer_start": 1022
|
||||
}],
|
||||
"is_impossible": false
|
||||
}, {
|
||||
"plausible_answers": [{
|
||||
"text": "Antioch",
|
||||
"answer_start": 1295
|
||||
}],
|
||||
"question": "What principality did William the conquerer found?",
|
||||
"id": "5ad3a266604f3c001a3fea2b",
|
||||
"answers": [],
|
||||
"is_impossible": true
|
||||
}],
|
||||
"context": "The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands."
|
||||
}]
|
||||
}, {
|
||||
"title": "Computational_complexity_theory",
|
||||
"paragraphs": [{
|
||||
"qas": [{
|
||||
"question": "What branch of theoretical computer science deals with broadly classifying computational problems by difficulty and class of relationship?",
|
||||
"id": "56e16182e3433e1400422e28",
|
||||
"answers": [{
|
||||
"text": "Computational complexity theory",
|
||||
"answer_start": 0
|
||||
}],
|
||||
"is_impossible": false
|
||||
}, {
|
||||
"plausible_answers": [{
|
||||
"text": "algorithm",
|
||||
"answer_start": 472
|
||||
}],
|
||||
"question": "What is a manual application of mathematical steps?",
|
||||
"id": "5ad5316b5b96ef001a10ab76",
|
||||
"answers": [],
|
||||
"is_impossible": true
|
||||
}],
|
||||
"context": "Computational complexity theory is a branch of the theory of computation in theoretical computer science that focuses on classifying computational problems according to their inherent difficulty, and relating those classes to each other. A computational problem is understood to be a task that is in principle amenable to being solved by a computer, which is equivalent to stating that the problem may be solved by mechanical application of mathematical steps, such as an algorithm."
|
||||
}, {
|
||||
"qas": [{
|
||||
"question": "What measure of a computational problem broadly defines the inherent difficulty of the solution?",
|
||||
"id": "56e16839cd28a01900c67887",
|
||||
"answers": [{
|
||||
"text": "if its solution requires significant resources",
|
||||
"answer_start": 46
|
||||
}],
|
||||
"is_impossible": false
|
||||
}, {
|
||||
"question": "What method is used to intuitively assess or quantify the amount of resources required to solve a computational problem?",
|
||||
"id": "56e16839cd28a01900c67888",
|
||||
"answers": [{
|
||||
"text": "mathematical models of computation",
|
||||
"answer_start": 176
|
||||
}],
|
||||
"is_impossible": false
|
||||
}, {
|
||||
"question": "What are two basic primary resources used to guage complexity?",
|
||||
"id": "56e16839cd28a01900c67889",
|
||||
"answers": [{
|
||||
"text": "time and storage",
|
||||
"answer_start": 305
|
||||
}],
|
||||
"is_impossible": false
|
||||
}, {
|
||||
"plausible_answers": [{
|
||||
"text": "the number of gates in a circuit",
|
||||
"answer_start": 436
|
||||
}],
|
||||
"question": "What unit is measured to determine circuit simplicity?",
|
||||
"id": "5ad532575b96ef001a10ab7f",
|
||||
"answers": [],
|
||||
"is_impossible": true
|
||||
}, {
|
||||
"plausible_answers": [{
|
||||
"text": "the number of processors",
|
||||
"answer_start": 502
|
||||
}],
|
||||
"question": "What number is used in perpendicular computing?",
|
||||
"id": "5ad532575b96ef001a10ab80",
|
||||
"answers": [],
|
||||
"is_impossible": true
|
||||
}],
|
||||
"context": "A problem is regarded as inherently difficult if its solution requires significant resources, whatever the algorithm used. The theory formalizes this intuition, by introducing mathematical models of computation to study these problems and quantifying the amount of resources needed to solve them, such as time and storage. Other complexity measures are also used, such as the amount of communication (used in communication complexity), the number of gates in a circuit (used in circuit complexity) and the number of processors (used in parallel computing). One of the roles of computational complexity theory is to determine the practical limits on what computers can and cannot do."
|
||||
}]
|
||||
}]
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,330 +0,0 @@
|
||||
""" Official evaluation script for SQuAD version 2.0.
|
||||
Modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0
|
||||
|
||||
In addition to basic functionality, we also compute additional statistics and
|
||||
plot precision-recall curves if an additional na_prob.json file is provided.
|
||||
This file is expected to map question ID's to the model's predicted probability
|
||||
that a question is unanswerable.
|
||||
"""
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
|
||||
class EVAL_OPTS():
|
||||
def __init__(self, data_file, pred_file, out_file="",
|
||||
na_prob_file="na_prob.json", na_prob_thresh=1.0,
|
||||
out_image_dir=None, verbose=False):
|
||||
self.data_file = data_file
|
||||
self.pred_file = pred_file
|
||||
self.out_file = out_file
|
||||
self.na_prob_file = na_prob_file
|
||||
self.na_prob_thresh = na_prob_thresh
|
||||
self.out_image_dir = out_image_dir
|
||||
self.verbose = verbose
|
||||
|
||||
OPTS = None
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.')
|
||||
parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.')
|
||||
parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.')
|
||||
parser.add_argument('--out-file', '-o', metavar='eval.json',
|
||||
help='Write accuracy metrics to file (default is stdout).')
|
||||
parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json',
|
||||
help='Model estimates of probability of no answer.')
|
||||
parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0,
|
||||
help='Predict "" if no-answer probability exceeds this (default = 1.0).')
|
||||
parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None,
|
||||
help='Save precision-recall curves to directory.')
|
||||
parser.add_argument('--verbose', '-v', action='store_true')
|
||||
if len(sys.argv) == 1:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
return parser.parse_args()
|
||||
|
||||
def make_qid_to_has_ans(dataset):
|
||||
qid_to_has_ans = {}
|
||||
for article in dataset:
|
||||
for p in article['paragraphs']:
|
||||
for qa in p['qas']:
|
||||
qid_to_has_ans[qa['id']] = bool(qa['answers'])
|
||||
return qid_to_has_ans
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
def remove_articles(text):
|
||||
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
|
||||
return re.sub(regex, ' ', text)
|
||||
def white_space_fix(text):
|
||||
return ' '.join(text.split())
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return ''.join(ch for ch in text if ch not in exclude)
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
def get_tokens(s):
|
||||
if not s: return []
|
||||
return normalize_answer(s).split()
|
||||
|
||||
def compute_exact(a_gold, a_pred):
|
||||
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
|
||||
|
||||
def compute_f1(a_gold, a_pred):
|
||||
gold_toks = get_tokens(a_gold)
|
||||
pred_toks = get_tokens(a_pred)
|
||||
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
|
||||
num_same = sum(common.values())
|
||||
if len(gold_toks) == 0 or len(pred_toks) == 0:
|
||||
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
|
||||
return int(gold_toks == pred_toks)
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(pred_toks)
|
||||
recall = 1.0 * num_same / len(gold_toks)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
def get_raw_scores(dataset, preds):
|
||||
exact_scores = {}
|
||||
f1_scores = {}
|
||||
for article in dataset:
|
||||
for p in article['paragraphs']:
|
||||
for qa in p['qas']:
|
||||
qid = qa['id']
|
||||
gold_answers = [a['text'] for a in qa['answers']
|
||||
if normalize_answer(a['text'])]
|
||||
if not gold_answers:
|
||||
# For unanswerable questions, only correct answer is empty string
|
||||
gold_answers = ['']
|
||||
if qid not in preds:
|
||||
print('Missing prediction for %s' % qid)
|
||||
continue
|
||||
a_pred = preds[qid]
|
||||
# Take max over all gold answers
|
||||
exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
|
||||
f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
|
||||
return exact_scores, f1_scores
|
||||
|
||||
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
|
||||
new_scores = {}
|
||||
for qid, s in scores.items():
|
||||
pred_na = na_probs[qid] > na_prob_thresh
|
||||
if pred_na:
|
||||
new_scores[qid] = float(not qid_to_has_ans[qid])
|
||||
else:
|
||||
new_scores[qid] = s
|
||||
return new_scores
|
||||
|
||||
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
||||
if not qid_list:
|
||||
total = len(exact_scores)
|
||||
return collections.OrderedDict([
|
||||
('exact', 100.0 * sum(exact_scores.values()) / total),
|
||||
('f1', 100.0 * sum(f1_scores.values()) / total),
|
||||
('total', total),
|
||||
])
|
||||
else:
|
||||
total = len(qid_list)
|
||||
return collections.OrderedDict([
|
||||
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
||||
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
||||
('total', total),
|
||||
])
|
||||
|
||||
def merge_eval(main_eval, new_eval, prefix):
|
||||
for k in new_eval:
|
||||
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
|
||||
|
||||
def plot_pr_curve(precisions, recalls, out_image, title):
|
||||
plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
|
||||
plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
|
||||
plt.xlabel('Recall')
|
||||
plt.ylabel('Precision')
|
||||
plt.xlim([0.0, 1.05])
|
||||
plt.ylim([0.0, 1.05])
|
||||
plt.title(title)
|
||||
plt.savefig(out_image)
|
||||
plt.clf()
|
||||
|
||||
def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans,
|
||||
out_image=None, title=None):
|
||||
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
||||
true_pos = 0.0
|
||||
cur_p = 1.0
|
||||
cur_r = 0.0
|
||||
precisions = [1.0]
|
||||
recalls = [0.0]
|
||||
avg_prec = 0.0
|
||||
for i, qid in enumerate(qid_list):
|
||||
if qid_to_has_ans[qid]:
|
||||
true_pos += scores[qid]
|
||||
cur_p = true_pos / float(i+1)
|
||||
cur_r = true_pos / float(num_true_pos)
|
||||
if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
|
||||
# i.e., if we can put a threshold after this point
|
||||
avg_prec += cur_p * (cur_r - recalls[-1])
|
||||
precisions.append(cur_p)
|
||||
recalls.append(cur_r)
|
||||
if out_image:
|
||||
plot_pr_curve(precisions, recalls, out_image, title)
|
||||
return {'ap': 100.0 * avg_prec}
|
||||
|
||||
def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
|
||||
qid_to_has_ans, out_image_dir):
|
||||
if out_image_dir and not os.path.exists(out_image_dir):
|
||||
os.makedirs(out_image_dir)
|
||||
num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
|
||||
if num_true_pos == 0:
|
||||
return
|
||||
pr_exact = make_precision_recall_eval(
|
||||
exact_raw, na_probs, num_true_pos, qid_to_has_ans,
|
||||
out_image=os.path.join(out_image_dir, 'pr_exact.png'),
|
||||
title='Precision-Recall curve for Exact Match score')
|
||||
pr_f1 = make_precision_recall_eval(
|
||||
f1_raw, na_probs, num_true_pos, qid_to_has_ans,
|
||||
out_image=os.path.join(out_image_dir, 'pr_f1.png'),
|
||||
title='Precision-Recall curve for F1 score')
|
||||
oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
|
||||
pr_oracle = make_precision_recall_eval(
|
||||
oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
|
||||
out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
|
||||
title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
|
||||
merge_eval(main_eval, pr_exact, 'pr_exact')
|
||||
merge_eval(main_eval, pr_f1, 'pr_f1')
|
||||
merge_eval(main_eval, pr_oracle, 'pr_oracle')
|
||||
|
||||
def histogram_na_prob(na_probs, qid_list, image_dir, name):
|
||||
if not qid_list:
|
||||
return
|
||||
x = [na_probs[k] for k in qid_list]
|
||||
weights = np.ones_like(x) / float(len(x))
|
||||
plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
|
||||
plt.xlabel('Model probability of no-answer')
|
||||
plt.ylabel('Proportion of dataset')
|
||||
plt.title('Histogram of no-answer probability: %s' % name)
|
||||
plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
|
||||
plt.clf()
|
||||
|
||||
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
||||
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
||||
cur_score = num_no_ans
|
||||
best_score = cur_score
|
||||
best_thresh = 0.0
|
||||
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
||||
for i, qid in enumerate(qid_list):
|
||||
if qid not in scores: continue
|
||||
if qid_to_has_ans[qid]:
|
||||
diff = scores[qid]
|
||||
else:
|
||||
if preds[qid]:
|
||||
diff = -1
|
||||
else:
|
||||
diff = 0
|
||||
cur_score += diff
|
||||
if cur_score > best_score:
|
||||
best_score = cur_score
|
||||
best_thresh = na_probs[qid]
|
||||
return 100.0 * best_score / len(scores), best_thresh
|
||||
|
||||
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
||||
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
||||
cur_score = num_no_ans
|
||||
best_score = cur_score
|
||||
best_thresh = 0.0
|
||||
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
||||
for i, qid in enumerate(qid_list):
|
||||
if qid not in scores: continue
|
||||
if qid_to_has_ans[qid]:
|
||||
diff = scores[qid]
|
||||
else:
|
||||
if preds[qid]:
|
||||
diff = -1
|
||||
else:
|
||||
diff = 0
|
||||
cur_score += diff
|
||||
if cur_score > best_score:
|
||||
best_score = cur_score
|
||||
best_thresh = na_probs[qid]
|
||||
|
||||
has_ans_score, has_ans_cnt = 0, 0
|
||||
for qid in qid_list:
|
||||
if not qid_to_has_ans[qid]: continue
|
||||
has_ans_cnt += 1
|
||||
|
||||
if qid not in scores: continue
|
||||
has_ans_score += scores[qid]
|
||||
|
||||
return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
|
||||
|
||||
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||
main_eval['best_exact'] = best_exact
|
||||
main_eval['best_exact_thresh'] = exact_thresh
|
||||
main_eval['best_f1'] = best_f1
|
||||
main_eval['best_f1_thresh'] = f1_thresh
|
||||
|
||||
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||
main_eval['best_exact'] = best_exact
|
||||
main_eval['best_exact_thresh'] = exact_thresh
|
||||
main_eval['best_f1'] = best_f1
|
||||
main_eval['best_f1_thresh'] = f1_thresh
|
||||
main_eval['has_ans_exact'] = has_ans_exact
|
||||
main_eval['has_ans_f1'] = has_ans_f1
|
||||
|
||||
def main(OPTS):
|
||||
with open(OPTS.data_file) as f:
|
||||
dataset_json = json.load(f)
|
||||
dataset = dataset_json['data']
|
||||
with open(OPTS.pred_file) as f:
|
||||
preds = json.load(f)
|
||||
if OPTS.na_prob_file:
|
||||
with open(OPTS.na_prob_file) as f:
|
||||
na_probs = json.load(f)
|
||||
else:
|
||||
na_probs = {k: 0.0 for k in preds}
|
||||
qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
|
||||
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
||||
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
||||
exact_raw, f1_raw = get_raw_scores(dataset, preds)
|
||||
exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
|
||||
OPTS.na_prob_thresh)
|
||||
f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
|
||||
OPTS.na_prob_thresh)
|
||||
out_eval = make_eval_dict(exact_thresh, f1_thresh)
|
||||
if has_ans_qids:
|
||||
has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
|
||||
merge_eval(out_eval, has_ans_eval, 'HasAns')
|
||||
if no_ans_qids:
|
||||
no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
|
||||
merge_eval(out_eval, no_ans_eval, 'NoAns')
|
||||
if OPTS.na_prob_file:
|
||||
find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
|
||||
if OPTS.na_prob_file and OPTS.out_image_dir:
|
||||
run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
|
||||
qid_to_has_ans, OPTS.out_image_dir)
|
||||
histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
|
||||
histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
|
||||
if OPTS.out_file:
|
||||
with open(OPTS.out_file, 'w') as f:
|
||||
json.dump(out_eval, f)
|
||||
else:
|
||||
print(json.dumps(out_eval, indent=2))
|
||||
return out_eval
|
||||
|
||||
if __name__ == '__main__':
|
||||
OPTS = parse_args()
|
||||
if OPTS.out_image_dir:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
main(OPTS)
|
||||
Reference in New Issue
Block a user