Reorganize examples (#9010)

* Reorganize example folder

* Continue reorganization

* Change requirements for tests

* Final cleanup

* Finish regroup with tests all passing

* Copyright

* Requirements and readme

* Make a full link for the documentation

* Address review comments

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Add symlink

* Reorg again

* Apply suggestions from code review

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>

* Adapt title

* Update to new strucutre

* Remove test

* Update READMEs

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2020-12-11 10:07:02 -05:00
committed by GitHub
parent 86896de064
commit 783d7d2629
215 changed files with 4454 additions and 1193 deletions

View File

@@ -0,0 +1,193 @@
# Distil*
Author: @VictorSanh
This folder contains the original code used to train Distil* as well as examples showcasing how to use DistilBERT, DistilRoBERTa and DistilGPT2.
**January 20, 2020 - Bug fixing** We have recently discovered and fixed [a bug](https://github.com/huggingface/transformers/commit/48cbf267c988b56c71a2380f748a3e6092ccaed3) in the evaluation of our `run_*.py` scripts that caused the reported metrics to be over-estimated on average. We have updated all the metrics with the latest runs.
**December 6, 2019 - Update** We release **DistilmBERT**: 92% of `bert-base-multilingual-cased` on XNLI. The model supports 104 different languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages).
**November 19, 2019 - Update** We release German **DistilBERT**: 98.8% of `bert-base-german-dbmdz-cased` on NER tasks.
**October 23, 2019 - Update** We release **DistilRoBERTa**: 95% of `RoBERTa-base`'s performance on GLUE, twice as fast as RoBERTa while being 35% smaller.
**October 3, 2019 - Update** We release our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108) explaining our approach on **DistilBERT**. It includes updated results and further experiments. We applied the same method to GPT2 and release the weights of **DistilGPT2**. DistilGPT2 is two times faster and 33% smaller than GPT2. **The paper supersedes our [previous blogpost](https://medium.com/huggingface/distilbert-8cf3380435b5) with a different distillation loss and better performances. Please use the paper as a reference when comparing/reporting results on DistilBERT.**
**September 19, 2019 - Update:** We fixed bugs in the code and released an updated version of the weights trained with a modification of the distillation loss. DistilBERT now reaches 99% of `BERT-base`'s performance on GLUE, and 86.9 F1 score on SQuAD v1.1 dev set (compared to 88.5 for `BERT-base`). We will publish a formal write-up of our approach in the near future!
## What is Distil*
Distil* is a class of compressed models that started with DistilBERT. DistilBERT stands for Distilled-BERT. DistilBERT is a small, fast, cheap and light Transformer model based on Bert architecture. It has 40% less parameters than `bert-base-uncased`, runs 60% faster while preserving 97% of BERT's performances as measured on the GLUE language understanding benchmark. DistilBERT is trained using knowledge distillation, a technique to compress a large model called the teacher into a smaller model called the student. By distillating Bert, we obtain a smaller Transformer model that bears a lot of similarities with the original BERT model while being lighter, smaller and faster to run. DistilBERT is thus an interesting option to put large-scaled trained Transformer model into production.
We have applied the same method to other Transformer architectures and released the weights:
- GPT2: on the [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) benchmark, GPT2 reaches a perplexity on the test set of 16.3 compared to 21.1 for **DistilGPT2** (after fine-tuning on the train set).
- RoBERTa: **DistilRoBERTa** reaches 95% of `RoBERTa-base`'s performance on GLUE while being twice faster and 35% smaller.
- German BERT: **German DistilBERT** reaches 99% of `bert-base-german-dbmdz-cased`'s performance on German NER (CoNLL-2003).
- Multilingual BERT: **DistilmBERT** reaches 92% of Multilingual BERT's performance on XNLI while being twice faster and 25% smaller. The model supports 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages).
For more information on DistilBERT, please refer to our [NeurIPS workshop paper](https://arxiv.org/abs/1910.01108).
Here are the results on the dev sets of GLUE:
| Model | Macro-score | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2| STS-B| WNLI |
| :---: | :---: | :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---:| :---: |
| BERT-base-uncased | **79.5** | 56.3 | 84.7 | 88.6 | 91.8 | 89.6 | 69.3 | 92.7 | 89.0 | 53.5 |
| DistilBERT-base-uncased | **77.0** | 51.3 | 82.1 | 87.5 | 89.2 | 88.5 | 59.9 | 91.3 | 86.9 | 56.3 |
| BERT-base-cased | **78.2** | 58.2 | 83.9 | 87.8 | 91.0 | 89.2 | 66.1 | 91.7 | 89.2 | 46.5 |
| DistilBERT-base-cased | **75.9** | 47.2 | 81.5 | 85.6 | 88.2 | 87.8 | 60.6 | 90.4 | 85.5 | 56.3 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| RoBERTa-base (reported) | **83.2**/**86.4**<sup>2</sup> | 63.6 | 87.6 | 90.2 | 92.8 | 91.9 | 78.7 | 94.8 | 91.2 | 57.7<sup>3</sup> |
| DistilRoBERTa<sup>1</sup> | **79.0**/**82.3**<sup>2</sup> | 59.3 | 84.0 | 86.6 | 90.8 | 89.4 | 67.9 | 92.5 | 88.3 | 52.1 |
<sup>1</sup> We did not use the MNLI checkpoint for fine-tuning but directly perform transfer learning on the pre-trained DistilRoBERTa.
<sup>2</sup> Macro-score computed without WNLI.
<sup>3</sup> We compute this score ourselves for completeness.
Here are the results on the *test* sets for 6 of the languages available in XNLI. The results are computed in the zero shot setting (trained on the English portion and evaluated on the target language portion):
| Model | English | Spanish | Chinese | German | Arabic | Urdu |
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
| mBERT base cased (computed) | 82.1 | 74.6 | 69.1 | 72.3 | 66.4 | 58.5 |
| mBERT base uncased (reported)| 81.4 | 74.3 | 63.8 | 70.5 | 62.1 | 58.3 |
| DistilmBERT | 78.2 | 69.1 | 64.0 | 66.3 | 59.1 | 54.7 |
## Setup
This part of the library has only be tested with Python3.6+. There are few specific dependencies to install before launching a distillation, you can install them with the command `pip install -r requirements.txt`.
**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breaking changes compared to v1.1.0).
## How to use DistilBERT
Transformers includes five pre-trained Distil* models, currently only provided for English and German (we are investigating the possibility to train and release a multilingual version of DistilBERT):
- `distilbert-base-uncased`: DistilBERT English language model pretrained on the same data used to pretrain Bert (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-uncased` version of Bert. The model has 6 layers, 768 dimension and 12 heads, totalizing 66M parameters.
- `distilbert-base-uncased-distilled-squad`: A finetuned version of `distilbert-base-uncased` finetuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches a F1 score of 86.9 on the dev set (for comparison, Bert `bert-base-uncased` version reaches a 88.5 F1 score).
- `distilbert-base-cased`: DistilBERT English language model pretrained on the same data used to pretrain Bert (concatenation of the Toronto Book Corpus and full English Wikipedia) using distillation with the supervision of the `bert-base-cased` version of Bert. The model has 6 layers, 768 dimension and 12 heads, totalizing 65M parameters.
- `distilbert-base-cased-distilled-squad`: A finetuned version of `distilbert-base-cased` finetuned using (a second step of) knowledge distillation on SQuAD 1.0. This model reaches a F1 score of 87.1 on the dev set (for comparison, Bert `bert-base-cased` version reaches a 88.7 F1 score).
- `distilbert-base-german-cased`: DistilBERT German language model pretrained on 1/2 of the data used to pretrain Bert using distillation with the supervision of the `bert-base-german-dbmdz-cased` version of German DBMDZ Bert. For NER tasks the model reaches a F1 score of 83.49 on the CoNLL-2003 test set (for comparison, `bert-base-german-dbmdz-cased` reaches a 84.52 F1 score), and a F1 score of 85.23 on the GermEval 2014 test set (`bert-base-german-dbmdz-cased` reaches a 86.89 F1 score).
- `distilgpt2`: DistilGPT2 English language model pretrained with the supervision of `gpt2` (the smallest version of GPT2) on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset. The model has 6 layers, 768 dimension and 12 heads, totalizing 82M parameters (compared to 124M parameters for GPT2). On average, DistilGPT2 is two times faster than GPT2.
- `distilroberta-base`: DistilRoBERTa English language model pretrained with the supervision of `roberta-base` solely on [OpenWebTextCorpus](https://skylion007.github.io/OpenWebTextCorpus/), a reproduction of OpenAI's WebText dataset (it is ~4 times less training data than the teacher RoBERTa). The model has 6 layers, 768 dimension and 12 heads, totalizing 82M parameters (compared to 125M parameters for RoBERTa-base). On average DistilRoBERTa is twice as fast as Roberta-base.
- `distilbert-base-multilingual-cased`: DistilmBERT multilingual model pretrained with the supervision of `bert-base-multilingual-cased` on the concatenation of Wikipedia in 104 different languages. The model supports the 104 languages listed [here](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages). The model has 6 layers, 768 dimension and 12 heads, totalizing 134M parameters (compared to 177M parameters for mBERT-base). On average DistilmBERT is twice as fast as mBERT-base.
Using DistilBERT is very similar to using BERT. DistilBERT share the same tokenizer as BERT's `bert-base-uncased` even though we provide a link to this tokenizer under the `DistilBertTokenizer` name to have a consistent naming between the library models.
```python
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = DistilBertModel.from_pretrained('distilbert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
```
Similarly, using the other Distil* models simply consists in calling the base classes with a different pretrained checkpoint:
- DistilBERT uncased: `model = DistilBertModel.from_pretrained('distilbert-base-uncased')`
- DistilGPT2: `model = GPT2Model.from_pretrained('distilgpt2')`
- DistilRoBERTa: `model = RobertaModel.from_pretrained('distilroberta-base')`
- DistilmBERT: `model = DistilBertModel.from_pretrained('distilbert-base-multilingual-cased')`
## How to train Distil*
In the following, we will explain how you can train DistilBERT.
### A. Preparing the data
The weights we release are trained using a concatenation of Toronto Book Corpus and English Wikipedia (same training data as the English version of BERT).
To avoid processing the data several time, we do it once and for all before the training. From now on, will suppose that you have a text file `dump.txt` which contains one sequence per line (a sequence being composed of one of several coherent sentences).
First, we will binarize the data, i.e. tokenize the data and convert each token in an index in our model's vocabulary.
```bash
python scripts/binarized_data.py \
--file_path data/dump.txt \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file data/binarized_text
```
Our implementation of masked language modeling loss follows [XLM](https://github.com/facebookresearch/XLM)'s one and smooths the probability of masking with a factor that put more emphasis on rare words. Thus we count the occurrences of each tokens in the data:
```bash
python scripts/token_counts.py \
--data_file data/binarized_text.bert-base-uncased.pickle \
--token_counts_dump data/token_counts.bert-base-uncased.pickle \
--vocab_size 30522
```
### B. Training
Training with distillation is really simple once you have pre-processed the data:
```bash
python train.py \
--student_type distilbert \
--student_config training_configs/distilbert-base-uncased.json \
--teacher_type bert \
--teacher_name bert-base-uncased \
--alpha_ce 5.0 --alpha_mlm 2.0 --alpha_cos 1.0 --alpha_clm 0.0 --mlm \
--freeze_pos_embs \
--dump_path serialization_dir/my_first_training \
--data_file data/binarized_text.bert-base-uncased.pickle \
--token_counts data/token_counts.bert-base-uncased.pickle \
--force # overwrites the `dump_path` if it already exists.
```
By default, this will launch a training on a single GPU (even if more are available on the cluster). Other parameters are available in the command line, please look in `train.py` or run `python train.py --help` to list them.
We highly encourage you to use distributed training for training DistilBERT as the training corpus is quite large. Here's an example that runs a distributed training on a single node having 4 GPUs:
```bash
export NODE_RANK=0
export N_NODES=1
export N_GPU_NODE=4
export WORLD_SIZE=4
export MASTER_PORT=<AN_OPEN_PORT>
export MASTER_ADDR=<I.P.>
pkill -f 'python -u train.py'
python -m torch.distributed.launch \
--nproc_per_node=$N_GPU_NODE \
--nnodes=$N_NODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
train.py \
--force \
--gpus $WORLD_SIZE \
--student_type distilbert \
--student_config training_configs/distilbert-base-uncased.json \
--teacher_type bert \
--teacher_name bert-base-uncased \
--alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --alpha_clm 0.0 --mlm \
--freeze_pos_embs \
--dump_path serialization_dir/my_first_training \
--data_file data/binarized_text.bert-base-uncased.pickle \
--token_counts data/token_counts.bert-base-uncased.pickle
```
**Tips:** Starting distilled training with good initialization of the model weights is crucial to reach decent performance. In our experiments, we initialized our model from a few layers of the teacher (Bert) itself! Please refer to `scripts/extract.py` and `scripts/extract_distilbert.py` to create a valid initialization checkpoint and use `--student_pretrained_weights` argument to use this initialization for the distilled training!
Happy distillation!
## Citation
If you find the resource useful, you should cite the following paper:
```
@inproceedings{sanh2019distilbert,
title={DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter},
author={Sanh, Victor and Debut, Lysandre and Chaumond, Julien and Wolf, Thomas},
booktitle={NeurIPS EMC^2 Workshop},
year={2019}
}
```

View File

@@ -0,0 +1,603 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The distiller to distil the student.
Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import math
import os
import time
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from transformers import get_linear_schedule_with_warmup
from utils import logger
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
from tensorboardX import SummaryWriter
class Distiller:
def __init__(
self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
):
logger.info("Initializing Distiller")
self.params = params
self.dump_path = params.dump_path
self.multi_gpu = params.multi_gpu
self.fp16 = params.fp16
self.student = student
self.teacher = teacher
self.student_config = student.config
self.vocab_size = student.config.vocab_size
if params.n_gpu <= 1:
sampler = RandomSampler(dataset)
else:
sampler = DistributedSampler(dataset)
if params.group_by_size:
groups = create_lengths_groups(lengths=dataset.lengths, k=params.max_model_input_size)
sampler = GroupedBatchSampler(sampler=sampler, group_ids=groups, batch_size=params.batch_size)
else:
sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)
self.temperature = params.temperature
assert self.temperature > 0.0
self.alpha_ce = params.alpha_ce
self.alpha_mlm = params.alpha_mlm
self.alpha_clm = params.alpha_clm
self.alpha_mse = params.alpha_mse
self.alpha_cos = params.alpha_cos
self.mlm = params.mlm
if self.mlm:
logger.info("Using MLM loss for LM step.")
self.mlm_mask_prop = params.mlm_mask_prop
assert 0.0 <= self.mlm_mask_prop <= 1.0
assert params.word_mask + params.word_keep + params.word_rand == 1.0
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
if self.fp16:
self.pred_probs = self.pred_probs.half()
self.token_probs = self.token_probs.half()
else:
logger.info("Using CLM loss for LM step.")
self.epoch = 0
self.n_iter = 0
self.n_total_iter = 0
self.n_sequences_epoch = 0
self.total_loss_epoch = 0
self.last_loss = 0
self.last_loss_ce = 0
self.last_loss_mlm = 0
self.last_loss_clm = 0
if self.alpha_mse > 0.0:
self.last_loss_mse = 0
if self.alpha_cos > 0.0:
self.last_loss_cos = 0
self.last_log = 0
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
if self.alpha_mse > 0.0:
self.mse_loss_fct = nn.MSELoss(reduction="sum")
if self.alpha_cos > 0.0:
self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
logger.info("--- Initializing model optimizer")
assert params.gradient_accumulation_steps >= 1
self.num_steps_epoch = len(self.dataloader)
num_train_optimization_steps = (
int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
],
"weight_decay": params.weight_decay,
},
{
"params": [
p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
],
"weight_decay": 0.0,
},
]
logger.info(
"------ Number of trainable parameters (student): %i"
% sum([p.numel() for p in self.student.parameters() if p.requires_grad])
)
logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
self.optimizer = AdamW(
optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
)
warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
)
if self.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
self.student, self.optimizer = amp.initialize(
self.student, self.optimizer, opt_level=self.params.fp16_opt_level
)
self.teacher = self.teacher.half()
if self.multi_gpu:
if self.fp16:
from apex.parallel import DistributedDataParallel
logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
self.student = DistributedDataParallel(self.student)
else:
from torch.nn.parallel import DistributedDataParallel
logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
self.student = DistributedDataParallel(
self.student,
device_ids=[params.local_rank],
output_device=params.local_rank,
find_unused_parameters=True,
)
self.is_master = params.is_master
if self.is_master:
logger.info("--- Initializing Tensorboard")
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)
def prepare_batch_mlm(self, batch):
"""
Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.
Input:
------
batch: `Tuple`
token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
Output:
-------
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
"""
token_ids, lengths = batch
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
assert token_ids.size(0) == lengths.size(0)
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
bs, max_seq_len = token_ids.size()
mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
x_prob = self.token_probs[token_ids.flatten()]
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
pred_mask = torch.zeros(
bs * max_seq_len, dtype=torch.bool, device=token_ids.device
) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
pred_mask[tgt_ids] = 1
pred_mask = pred_mask.view(bs, max_seq_len)
pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0
# mask a number of words == 0 [8] (faster with fp16)
if self.fp16:
n1 = pred_mask.sum().item()
if n1 > 8:
pred_mask = pred_mask.view(-1)
n2 = max(n1 % 8, 8 * (n1 // 8))
if n2 != n1:
pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
pred_mask = pred_mask.view(bs, max_seq_len)
assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
_token_ids_real = token_ids[pred_mask]
_token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
_token_ids = (
_token_ids_mask * (probs == 0).long()
+ _token_ids_real * (probs == 1).long()
+ _token_ids_rand * (probs == 2).long()
)
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
# sanity checks
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
return token_ids, attn_mask, mlm_labels
def prepare_batch_clm(self, batch):
"""
Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
Input:
------
batch: `Tuple`
token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.
Output:
-------
token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
"""
token_ids, lengths = batch
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
assert token_ids.size(0) == lengths.size(0)
attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
# sanity checks
assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
return token_ids, attn_mask, clm_labels
def round_batch(self, x: torch.tensor, lengths: torch.tensor):
"""
For float16 only.
Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
Input:
------
x: `torch.tensor(bs, seq_length)` - The token ids.
lengths: `torch.tensor(bs, seq_length)` - The lengths of each of the sequence in the batch.
Output:
-------
x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
lengths: `torch.tensor(new_bs, new_seq_length)` - The updated lengths.
"""
if not self.fp16 or len(lengths) < 8:
return x, lengths
# number of sentences == 0 [8]
bs1 = len(lengths)
bs2 = 8 * (bs1 // 8)
assert bs2 > 0 and bs2 % 8 == 0
if bs1 != bs2:
idx = torch.randperm(bs1)[:bs2]
lengths = lengths[idx]
slen = lengths.max().item()
x = x[idx, :slen]
else:
idx = None
# sequence length == 0 [8]
ml1 = x.size(1)
if ml1 % 8 != 0:
pad = 8 - (ml1 % 8)
ml2 = ml1 + pad
if self.mlm:
pad_id = self.params.special_tok_ids["pad_token"]
else:
pad_id = self.params.special_tok_ids["unk_token"]
padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
x = torch.cat([x, padding_tensor], 1)
assert x.size() == (bs2, ml2)
assert x.size(0) % 8 == 0
assert x.size(1) % 8 == 0
return x, lengths
def train(self):
"""
The real training loop.
"""
if self.is_master:
logger.info("Starting training")
self.last_log = time.time()
self.student.train()
self.teacher.eval()
for _ in range(self.params.n_epoch):
if self.is_master:
logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
if self.multi_gpu:
torch.distributed.barrier()
iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
for batch in iter_bar:
if self.params.n_gpu > 0:
batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)
if self.mlm:
token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
else:
token_ids, attn_mask, lm_labels = self.prepare_batch_clm(batch=batch)
self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
iter_bar.update()
iter_bar.set_postfix(
{"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
)
iter_bar.close()
if self.is_master:
logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
self.end_epoch()
if self.is_master:
logger.info("Save very last checkpoint as `pytorch_model.bin`.")
self.save_checkpoint(checkpoint_name="pytorch_model.bin")
logger.info("Training is finished")
def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
"""
One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
and possibly a parameter update (depending on the gradient accumulation).
Input:
------
input_ids: `torch.tensor(bs, seq_length)` - The token ids.
attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
"""
if self.mlm:
s_logits, s_hidden_states = self.student(
input_ids=input_ids, attention_mask=attention_mask
) # (bs, seq_length, voc_size)
with torch.no_grad():
t_logits, t_hidden_states = self.teacher(
input_ids=input_ids, attention_mask=attention_mask
) # (bs, seq_length, voc_size)
else:
s_logits, _, s_hidden_states = self.student(
input_ids=input_ids, attention_mask=None
) # (bs, seq_length, voc_size)
with torch.no_grad():
t_logits, _, t_hidden_states = self.teacher(
input_ids=input_ids, attention_mask=None
) # (bs, seq_length, voc_size)
assert s_logits.size() == t_logits.size()
# https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
# https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
if self.params.restrict_ce_to_mask:
mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
else:
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
assert t_logits_slct.size() == s_logits_slct.size()
loss_ce = (
self.ce_loss_fct(
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
F.softmax(t_logits_slct / self.temperature, dim=-1),
)
* (self.temperature) ** 2
)
loss = self.alpha_ce * loss_ce
if self.alpha_mlm > 0.0:
loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
loss += self.alpha_mlm * loss_mlm
if self.alpha_clm > 0.0:
shift_logits = s_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()
loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss += self.alpha_clm * loss_clm
if self.alpha_mse > 0.0:
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
0
) # Reproducing batchmean reduction
loss += self.alpha_mse * loss_mse
if self.alpha_cos > 0.0:
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
assert s_hidden_states.size() == t_hidden_states.size()
dim = s_hidden_states.size(-1)
s_hidden_states_slct = torch.masked_select(s_hidden_states, mask) # (bs * seq_length * dim)
s_hidden_states_slct = s_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
t_hidden_states_slct = torch.masked_select(t_hidden_states, mask) # (bs * seq_length * dim)
t_hidden_states_slct = t_hidden_states_slct.view(-1, dim) # (bs * seq_length, dim)
target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1) # (bs * seq_length,)
loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
loss += self.alpha_cos * loss_cos
self.total_loss_epoch += loss.item()
self.last_loss = loss.item()
self.last_loss_ce = loss_ce.item()
if self.alpha_mlm > 0.0:
self.last_loss_mlm = loss_mlm.item()
if self.alpha_clm > 0.0:
self.last_loss_clm = loss_clm.item()
if self.alpha_mse > 0.0:
self.last_loss_mse = loss_mse.item()
if self.alpha_cos > 0.0:
self.last_loss_cos = loss_cos.item()
self.optimize(loss)
self.n_sequences_epoch += input_ids.size(0)
def optimize(self, loss):
"""
Normalization on the loss (gradient accumulation or distributed training), followed by
backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
Also update the metrics for tensorboard.
"""
# Check for NaN
if (loss != loss).data.any():
logger.error("NaN detected")
exit()
if self.multi_gpu:
loss = loss.mean()
if self.params.gradient_accumulation_steps > 1:
loss = loss / self.params.gradient_accumulation_steps
if self.fp16:
from apex import amp
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
self.iter()
if self.n_iter % self.params.gradient_accumulation_steps == 0:
if self.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
def iter(self):
"""
Update global counts, write to tensorboard and save checkpoint.
"""
self.n_iter += 1
self.n_total_iter += 1
if self.n_total_iter % self.params.log_interval == 0:
self.log_tensorboard()
self.last_log = time.time()
if self.n_total_iter % self.params.checkpoint_interval == 0:
self.save_checkpoint()
def log_tensorboard(self):
"""
Log into tensorboard. Only by the master process.
"""
if not self.is_master:
return
for param_name, param in self.student.named_parameters():
self.tensorboard.add_scalar(
tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
)
if param.grad is None:
continue
self.tensorboard.add_scalar(
tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="losses/cum_avg_loss_epoch",
scalar_value=self.total_loss_epoch / self.n_iter,
global_step=self.n_total_iter,
)
self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
self.tensorboard.add_scalar(
tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
)
if self.alpha_mlm > 0.0:
self.tensorboard.add_scalar(
tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
)
if self.alpha_clm > 0.0:
self.tensorboard.add_scalar(
tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
)
if self.alpha_mse > 0.0:
self.tensorboard.add_scalar(
tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
)
if self.alpha_cos > 0.0:
self.tensorboard.add_scalar(
tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="global/memory_usage",
scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
global_step=self.n_total_iter,
)
self.tensorboard.add_scalar(
tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
)
def end_epoch(self):
"""
Finally arrived at the end of epoch (full pass on dataset).
Do some tensorboard logging and checkpoint saving.
"""
logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")
if self.is_master:
self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
self.tensorboard.add_scalar(
tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
)
self.epoch += 1
self.n_sequences_epoch = 0
self.n_iter = 0
self.total_loss_epoch = 0
def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
"""
Save the current state. Only by the master process.
"""
if not self.is_master:
return
mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
mdl_to_save.config.save_pretrained(self.dump_path)
state_dict = mdl_to_save.state_dict()
torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))

View File

@@ -0,0 +1,108 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Adapted from PyTorch Vision (https://github.com/pytorch/vision/blob/master/references/detection/group_by_aspect_ratio.py)
"""
import bisect
import copy
from collections import defaultdict
import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler
from utils import logger
def _quantize(x, bins):
bins = copy.deepcopy(bins)
bins = sorted(bins)
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
return quantized
def create_lengths_groups(lengths, k=0):
bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
groups = _quantize(lengths, bins)
# count number of elements per group
counts = np.unique(groups, return_counts=True)[1]
fbins = [0] + bins + [np.inf]
logger.info("Using {} as bins for aspect lengths quantization".format(fbins))
logger.info("Count of instances per bin: {}".format(counts))
return groups
class GroupedBatchSampler(BatchSampler):
"""
Wraps another sampler to yield a mini-batch of indices.
It enforces that the batch only contain elements from the same group.
It also tries to provide mini-batches which follows an ordering which is
as close as possible to the ordering from the original sampler.
Arguments:
sampler (Sampler): Base sampler.
group_ids (list[int]): If the sampler produces indices in range [0, N),
`group_ids` must be a list of `N` ints which contains the group id of each sample.
The group ids must be a continuous set of integers starting from
0, i.e. they must be in the range [0, num_groups).
batch_size (int): Size of mini-batch.
"""
def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler):
raise ValueError(
"sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
self.sampler = sampler
self.group_ids = group_ids
self.batch_size = batch_size
def __iter__(self):
buffer_per_group = defaultdict(list)
samples_per_group = defaultdict(list)
num_batches = 0
for idx in self.sampler:
group_id = self.group_ids[idx]
buffer_per_group[group_id].append(idx)
samples_per_group[group_id].append(idx)
if len(buffer_per_group[group_id]) == self.batch_size:
yield buffer_per_group[group_id] # TODO
num_batches += 1
del buffer_per_group[group_id]
assert len(buffer_per_group[group_id]) < self.batch_size
# now we have run out of elements that satisfy
# the group criteria, let's return the remaining
# elements so that the size of the sampler is
# deterministic
expected_num_batches = len(self)
num_remaining = expected_num_batches - num_batches
if num_remaining > 0:
# for the remaining batches, group the batches by similar lengths
batch_idx = []
for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]):
batch_idx.extend(idxs)
if len(batch_idx) >= self.batch_size:
yield batch_idx[: self.batch_size]
batch_idx = batch_idx[self.batch_size :]
num_remaining -= 1
if len(batch_idx) > 0:
yield batch_idx
num_remaining -= 1
assert num_remaining == 0
def __len__(self):
"""
Return the number of mini-batches rather than the number of samples.
"""
return (len(self.sampler) + self.batch_size - 1) // self.batch_size

View File

@@ -0,0 +1,166 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Dataset to distilled models
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import numpy as np
import torch
from torch.utils.data import Dataset
from utils import logger
class LmSeqsDataset(Dataset):
"""Custom Dataset wrapping language modeling sequences.
Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.
Input:
------
params: `NameSpace` parameters
data: `List[np.array[int]]
"""
def __init__(self, params, data):
self.params = params
self.token_ids = np.array(data)
self.lengths = np.array([len(t) for t in data])
self.check()
self.remove_long_sequences()
self.remove_empty_sequences()
self.remove_unknown_sequences()
self.check()
self.print_statistics()
def __getitem__(self, index):
return (self.token_ids[index], self.lengths[index])
def __len__(self):
return len(self.lengths)
def check(self):
"""
Some sanity checks
"""
assert len(self.token_ids) == len(self.lengths)
assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths)))
def remove_long_sequences(self):
"""
Sequences that are too long are split by chunk of max_model_input_size.
"""
max_len = self.params.max_model_input_size
indices = self.lengths > max_len
logger.info(f"Splitting {sum(indices)} too long sequences.")
def divide_chunks(l, n):
return [l[i : i + n] for i in range(0, len(l), n)]
new_tok_ids = []
new_lengths = []
if self.params.mlm:
cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
else:
cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]
for seq_, len_ in zip(self.token_ids, self.lengths):
assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
if len_ <= max_len:
new_tok_ids.append(seq_)
new_lengths.append(len_)
else:
sub_seqs = []
for sub_s in divide_chunks(seq_, max_len - 2):
if sub_s[0] != cls_id:
sub_s = np.insert(sub_s, 0, cls_id)
if sub_s[-1] != sep_id:
sub_s = np.insert(sub_s, len(sub_s), sep_id)
assert len(sub_s) <= max_len
assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s
sub_seqs.append(sub_s)
new_tok_ids.extend(sub_seqs)
new_lengths.extend([len(l) for l in sub_seqs])
self.token_ids = np.array(new_tok_ids)
self.lengths = np.array(new_lengths)
def remove_empty_sequences(self):
"""
Too short sequences are simply removed. This could be tuned.
"""
init_size = len(self)
indices = self.lengths > 11
self.token_ids = self.token_ids[indices]
self.lengths = self.lengths[indices]
new_size = len(self)
logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")
def remove_unknown_sequences(self):
"""
Remove sequences with a (too) high level of unknown tokens.
"""
if "unk_token" not in self.params.special_tok_ids:
return
else:
unk_token_id = self.params.special_tok_ids["unk_token"]
init_size = len(self)
unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
indices = (unk_occs / self.lengths) < 0.5
self.token_ids = self.token_ids[indices]
self.lengths = self.lengths[indices]
new_size = len(self)
logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).")
def print_statistics(self):
"""
Print some statistics on the corpus. Only the master process.
"""
if not self.params.is_master:
return
logger.info(f"{len(self)} sequences")
# data_len = sum(self.lengths)
# nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
# logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
# unk_idx = self.params.special_tok_ids['unk_token']
# nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
# logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')
def batch_sequences(self, batch):
"""
Do the padding and transform into torch.tensor.
"""
token_ids = [t[0] for t in batch]
lengths = [t[1] for t in batch]
assert len(token_ids) == len(lengths)
# Max for paddings
max_seq_len_ = max(lengths)
# Pad token ids
if self.params.mlm:
pad_idx = self.params.special_tok_ids["pad_token"]
else:
pad_idx = self.params.special_tok_ids["unk_token"]
tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
assert len(tk_) == len(token_ids)
assert all(len(t) == max_seq_len_ for t in tk_)
tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
lg_t = torch.tensor(lengths) # (bs)
return tk_t, lg_t

View File

@@ -0,0 +1,7 @@
transformers
gitpython==3.0.2
tensorboard>=1.14.0
tensorboardX==1.8
psutil==5.6.6
scipy>=1.4.1

View File

@@ -0,0 +1,872 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This is the exact same script as `examples/question-answering/run_squad.py` (as of 2020, January 8th) with an additional and optional step of distillation."""
import argparse
import glob
import logging
import os
import random
import timeit
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
import transformers
from transformers import (
WEIGHTS_NAME,
AdamW,
BertConfig,
BertForQuestionAnswering,
BertTokenizer,
DistilBertConfig,
DistilBertForQuestionAnswering,
DistilBertTokenizer,
RobertaConfig,
RobertaForQuestionAnswering,
RobertaTokenizer,
XLMConfig,
XLMForQuestionAnswering,
XLMTokenizer,
XLNetConfig,
XLNetForQuestionAnswering,
XLNetTokenizer,
get_linear_schedule_with_warmup,
squad_convert_examples_to_features,
)
from transformers.data.metrics.squad_metrics import (
compute_predictions_log_probs,
compute_predictions_logits,
squad_evaluate,
)
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
from transformers.trainer_utils import is_main_process
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
from tensorboardX import SummaryWriter
logger = logging.getLogger(__name__)
MODEL_CLASSES = {
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
"roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
}
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def to_list(tensor):
return tensor.detach().cpu().tolist()
def train(args, train_dataset, model, tokenizer, teacher=None):
""" Train the model """
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
# Check if saved optimizer or scheduler states exist
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
os.path.join(args.model_name_or_path, "scheduler.pt")
):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size
* args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
global_step = 1
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if os.path.exists(args.model_name_or_path):
try:
# set global_step to gobal_step of last saved checkpoint from model path
checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
global_step = int(checkpoint_suffix)
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError:
logger.info(" Starting fine-tuning.")
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
)
# Added here for reproductibility
set_seed(args)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
model.train()
if teacher is not None:
teacher.eval()
batch = tuple(t.to(args.device) for t in batch)
inputs = {
"input_ids": batch[0],
"attention_mask": batch[1],
"start_positions": batch[3],
"end_positions": batch[4],
}
if args.model_type != "distilbert":
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
if args.model_type in ["xlnet", "xlm"]:
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
if args.version_2_with_negative:
inputs.update({"is_impossible": batch[7]})
outputs = model(**inputs)
loss, start_logits_stu, end_logits_stu = outputs
# Distillation loss
if teacher is not None:
if "token_type_ids" not in inputs:
inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
with torch.no_grad():
start_logits_tea, end_logits_tea = teacher(
input_ids=inputs["input_ids"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"],
)
assert start_logits_tea.size() == start_logits_stu.size()
assert end_logits_tea.size() == end_logits_stu.size()
loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_start = (
loss_fct(
F.log_softmax(start_logits_stu / args.temperature, dim=-1),
F.softmax(start_logits_tea / args.temperature, dim=-1),
)
* (args.temperature ** 2)
)
loss_end = (
loss_fct(
F.log_softmax(end_logits_stu / args.temperature, dim=-1),
F.softmax(end_logits_tea / args.temperature, dim=-1),
)
* (args.temperature ** 2)
)
loss_ce = (loss_start + loss_end) / 2.0
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()
global_step += 1
# Log metrics
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Only evaluate when single GPU otherwise metrics may not average well
if args.local_rank == -1 and args.evaluate_during_training:
results = evaluate(args, model, tokenizer)
for key, value in results.items():
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
logging_loss = tr_loss
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to %s", output_dir)
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
logger.info("Saving optimizer and scheduler states to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
if args.max_steps > 0 and global_step > args.max_steps:
train_iterator.close()
break
if args.local_rank in [-1, 0]:
tb_writer.close()
return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, prefix=""):
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu evaluate
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
model = torch.nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
all_results = []
start_time = timeit.default_timer()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
model.eval()
batch = tuple(t.to(args.device) for t in batch)
with torch.no_grad():
inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
if args.model_type != "distilbert":
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] # XLM don't use segment_ids
example_indices = batch[3]
if args.model_type in ["xlnet", "xlm"]:
inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
outputs = model(**inputs)
for i, example_index in enumerate(example_indices):
eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id)
output = [to_list(output[i]) for output in outputs]
# Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
# models only use two.
if len(output) >= 5:
start_logits = output[0]
start_top_index = output[1]
end_logits = output[2]
end_top_index = output[3]
cls_logits = output[4]
result = SquadResult(
unique_id,
start_logits,
end_logits,
start_top_index=start_top_index,
end_top_index=end_top_index,
cls_logits=cls_logits,
)
else:
start_logits, end_logits = output
result = SquadResult(unique_id, start_logits, end_logits)
all_results.append(result)
evalTime = timeit.default_timer() - start_time
logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
# Compute predictions
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
if args.version_2_with_negative:
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
else:
output_null_log_odds_file = None
if args.model_type in ["xlnet", "xlm"]:
# XLNet uses a more complex post-processing procedure
predictions = compute_predictions_log_probs(
examples,
features,
all_results,
args.n_best_size,
args.max_answer_length,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
model.config.start_n_top,
model.config.end_n_top,
args.version_2_with_negative,
tokenizer,
args.verbose_logging,
)
else:
predictions = compute_predictions_logits(
examples,
features,
all_results,
args.n_best_size,
args.max_answer_length,
args.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
args.verbose_logging,
args.version_2_with_negative,
args.null_score_diff_threshold,
tokenizer,
)
# Compute the F1 and exact scores.
results = squad_evaluate(examples, predictions)
return results
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
if args.local_rank not in [-1, 0] and not evaluate:
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
torch.distributed.barrier()
# Load data features from cache or dataset file
input_file = args.predict_file if evaluate else args.train_file
cached_features_file = os.path.join(
os.path.dirname(input_file),
"cached_distillation_{}_{}_{}".format(
"dev" if evaluate else "train",
list(filter(None, args.model_name_or_path.split("/"))).pop(),
str(args.max_seq_length),
),
)
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
features_and_dataset = torch.load(cached_features_file)
try:
features, dataset, examples = (
features_and_dataset["features"],
features_and_dataset["dataset"],
features_and_dataset["examples"],
)
except KeyError:
raise DeprecationWarning(
"You seem to be loading features from an older version of this script please delete the "
"file %s in order for it to be created again" % cached_features_file
)
else:
logger.info("Creating features from dataset file at %s", input_file)
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
if evaluate:
examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
else:
examples = processor.get_train_examples(args.data_dir, filename=args.train_file)
features, dataset = squad_convert_examples_to_features(
examples=examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=not evaluate,
return_dataset="pt",
threads=args.threads,
)
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
if args.local_rank == 0 and not evaluate:
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
torch.distributed.barrier()
if output_examples:
return dataset, examples, features
return dataset
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model checkpoints and predictions will be written.",
)
# Distillation parameters (optional)
parser.add_argument(
"--teacher_type",
default=None,
type=str,
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
)
parser.add_argument(
"--teacher_name_or_path",
default=None,
type=str,
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.",
)
parser.add_argument(
"--alpha_ce", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
)
parser.add_argument(
"--alpha_squad", default=0.5, type=float, help="True SQuAD loss linear weight. Only for distillation."
)
parser.add_argument(
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
)
# Other parameters
parser.add_argument(
"--data_dir",
default=None,
type=str,
help="The input data dir. Should contain the .json files for the task."
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
)
parser.add_argument(
"--train_file",
default=None,
type=str,
help="The input training file. If a data dir is specified, will look for the file there"
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
)
parser.add_argument(
"--predict_file",
default=None,
type=str,
help="The input evaluation file. If a data dir is specified, will look for the file there"
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
)
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)
parser.add_argument(
"--tokenizer_name",
default="",
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
)
parser.add_argument(
"--version_2_with_negative",
action="store_true",
help="If true, the SQuAD examples contain some that do not have an answer.",
)
parser.add_argument(
"--null_score_diff_threshold",
type=float,
default=0.0,
help="If null_score - best_non_null is greater than the threshold predict null.",
)
parser.add_argument(
"--max_seq_length",
default=384,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.",
)
parser.add_argument(
"--doc_stride",
default=128,
type=int,
help="When splitting up a long document into chunks, how much stride to take between chunks.",
)
parser.add_argument(
"--max_query_length",
default=64,
type=int,
help="The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length.",
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
parser.add_argument(
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
)
parser.add_argument(
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
)
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument(
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
)
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument(
"--n_best_size",
default=20,
type=int,
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
)
parser.add_argument(
"--max_answer_length",
default=30,
type=int,
help="The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another.",
)
parser.add_argument(
"--verbose_logging",
action="store_true",
help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.",
)
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
parser.add_argument(
"--eval_all_checkpoints",
action="store_true",
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
)
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
parser.add_argument(
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
)
parser.add_argument(
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
args = parser.parse_args()
if (
os.path.exists(args.output_dir)
and os.listdir(args.output_dir)
and args.do_train
and not args.overwrite_output_dir
):
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
args.output_dir
)
)
# Setup distant debugging if needed
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
# Setup CUDA, GPU & distributed training
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend="nccl")
args.n_gpu = 1
args.device = device
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank,
device,
args.n_gpu,
bool(args.local_rank != -1),
args.fp16,
)
# Set the verbosity to info of the Transformers logger (on main process only):
if is_main_process(args.local_rank):
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Set seed
set_seed(args)
# Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]:
# Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path,
cache_dir=args.cache_dir if args.cache_dir else None,
)
tokenizer = tokenizer_class.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None,
)
model = model_class.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir if args.cache_dir else None,
)
if args.teacher_type is not None:
assert args.teacher_name_or_path is not None
assert args.alpha_ce > 0.0
assert args.alpha_ce + args.alpha_squad > 0.0
assert args.teacher_type != "distilbert", "We constraint teachers not to be of type DistilBERT."
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
teacher_config = teacher_config_class.from_pretrained(
args.teacher_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None
)
teacher = teacher_model_class.from_pretrained(
args.teacher_name_or_path, config=teacher_config, cache_dir=args.cache_dir if args.cache_dir else None
)
teacher.to(args.device)
else:
teacher = None
if args.local_rank == 0:
# Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
model.to(args.device)
logger.info("Training/evaluation parameters %s", args)
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
# remove the need for this code, but it is still valid.
if args.fp16:
try:
import apex
apex.amp.register_half_function(torch, "einsum")
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
# Training
if args.do_train:
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Save the trained model and the tokenizer
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
model.to(args.device)
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
if args.do_train:
logger.info("Loading checkpoints saved during training for evaluation")
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
# Reload the model
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(checkpoint)
model.to(args.device)
# Evaluate
result = evaluate(args, model, tokenizer, prefix=global_step)
result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
results.update(result)
logger.info("Results: {}".format(results))
return results
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,96 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before distillation.
"""
import argparse
import logging
import pickle
import random
import time
import numpy as np
from transformers import BertTokenizer, GPT2Tokenizer, RobertaTokenizer
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)
def main():
parser = argparse.ArgumentParser(
description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids)."
)
parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.")
parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"])
parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.")
parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.")
args = parser.parse_args()
logger.info(f"Loading Tokenizer ({args.tokenizer_name})")
if args.tokenizer_type == "bert":
tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map["cls_token"] # `[CLS]`
sep = tokenizer.special_tokens_map["sep_token"] # `[SEP]`
elif args.tokenizer_type == "roberta":
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map["cls_token"] # `<s>`
sep = tokenizer.special_tokens_map["sep_token"] # `</s>`
elif args.tokenizer_type == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map["bos_token"] # `<|endoftext|>`
sep = tokenizer.special_tokens_map["eos_token"] # `<|endoftext|>`
logger.info(f"Loading text from {args.file_path}")
with open(args.file_path, "r", encoding="utf8") as fp:
data = fp.readlines()
logger.info("Start encoding")
logger.info(f"{len(data)} examples to process.")
rslt = []
iter = 0
interval = 10000
start = time.time()
for text in data:
text = f"{bos} {text.strip()} {sep}"
token_ids = tokenizer.encode(text, add_special_tokens=False)
rslt.append(token_ids)
iter += 1
if iter % interval == 0:
end = time.time()
logger.info(f"{iter} examples processed. - {(end-start):.2f}s/{interval}expl")
start = time.time()
logger.info("Finished binarization")
logger.info(f"{len(data)} examples processed.")
dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
vocab_size = tokenizer.vocab_size
if vocab_size < (1 << 16):
rslt_ = [np.uint16(d) for d in rslt]
else:
rslt_ = [np.int32(d) for d in rslt]
random.shuffle(rslt_)
logger.info(f"Dump to {dp_file}")
with open(dp_file, "wb") as handle:
pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,102 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training the distilled model.
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
"""
import argparse
import torch
from transformers import GPT2LMHeadModel, RobertaForMaskedLM
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
)
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
parser.add_argument("--model_name", default="roberta-large", type=str)
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
parser.add_argument("--vocab_transform", action="store_true")
args = parser.parse_args()
if args.model_type == "roberta":
model = RobertaForMaskedLM.from_pretrained(args.model_name)
prefix = "roberta"
elif args.model_type == "gpt2":
model = GPT2LMHeadModel.from_pretrained(args.model_name)
prefix = "transformer"
state_dict = model.state_dict()
compressed_sd = {}
# Embeddings #
if args.model_type == "gpt2":
for param_name in ["wte.weight", "wpe.weight"]:
compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
else:
for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
param_name = f"{prefix}.embeddings.{w}.weight"
compressed_sd[param_name] = state_dict[param_name]
for w in ["weight", "bias"]:
param_name = f"{prefix}.embeddings.LayerNorm.{w}"
compressed_sd[param_name] = state_dict[param_name]
# Transformer Blocks #
std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:
if args.model_type == "gpt2":
for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
for w in ["weight", "bias"]:
compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
f"{prefix}.h.{teacher_idx}.{layer}.{w}"
]
compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
else:
for layer in [
"attention.self.query",
"attention.self.key",
"attention.self.value",
"attention.output.dense",
"attention.output.LayerNorm",
"intermediate.dense",
"output.dense",
"output.LayerNorm",
]:
for w in ["weight", "bias"]:
compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
]
std_idx += 1
# Language Modeling Head ###s
if args.model_type == "roberta":
for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
if args.vocab_transform:
for w in ["weight", "bias"]:
compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
elif args.model_type == "gpt2":
for w in ["weight", "bias"]:
compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"]
print(f"N layers selected for distillation: {std_idx}")
print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
torch.save(compressed_sd, args.dump_checkpoint)

View File

@@ -0,0 +1,92 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training DistilBERT.
Specific to BERT -> DistilBERT.
"""
import argparse
import torch
from transformers import BertForMaskedLM
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
)
parser.add_argument("--model_type", default="bert", choices=["bert"])
parser.add_argument("--model_name", default="bert-base-uncased", type=str)
parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
parser.add_argument("--vocab_transform", action="store_true")
args = parser.parse_args()
if args.model_type == "bert":
model = BertForMaskedLM.from_pretrained(args.model_name)
prefix = "bert"
else:
raise ValueError('args.model_type should be "bert".')
state_dict = model.state_dict()
compressed_sd = {}
for w in ["word_embeddings", "position_embeddings"]:
compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
for w in ["weight", "bias"]:
compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]
std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:
for w in ["weight", "bias"]:
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
]
std_idx += 1
compressed_sd["vocab_projector.weight"] = state_dict["cls.predictions.decoder.weight"]
compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"]
if args.vocab_transform:
for w in ["weight", "bias"]:
compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
print(f"N layers selected for distillation: {std_idx}")
print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
torch.save(compressed_sd, args.dump_checkpoint)

View File

@@ -0,0 +1,56 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training the distilled model.
"""
import argparse
import logging
import pickle
from collections import Counter
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)"
)
parser.add_argument(
"--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset."
)
parser.add_argument(
"--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file."
)
parser.add_argument("--vocab_size", default=30522, type=int)
args = parser.parse_args()
logger.info(f"Loading data from {args.data_file}")
with open(args.data_file, "rb") as fp:
data = pickle.load(fp)
logger.info("Counting occurences for MLM.")
counter = Counter()
for tk_ids in data:
counter.update(tk_ids)
counts = [0] * args.vocab_size
for k, v in counter.items():
counts[k] = v
logger.info(f"Dump to {args.token_counts_dump}")
with open(args.token_counts_dump, "wb") as handle:
pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)

View File

@@ -0,0 +1,322 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training the distilled model.
Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
"""
import argparse
import json
import os
import pickle
import shutil
import numpy as np
import torch
from distiller import Distiller
from lm_seqs_dataset import LmSeqsDataset
from transformers import (
BertConfig,
BertForMaskedLM,
BertTokenizer,
DistilBertConfig,
DistilBertForMaskedLM,
DistilBertTokenizer,
GPT2Config,
GPT2LMHeadModel,
GPT2Tokenizer,
RobertaConfig,
RobertaForMaskedLM,
RobertaTokenizer,
)
from utils import git_log, init_gpu_params, logger, set_seed
MODEL_CLASSES = {
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
}
def sanity_checks(args):
"""
A bunch of args sanity checks to perform even starting...
"""
assert (args.mlm and args.alpha_mlm > 0.0) or (not args.mlm and args.alpha_mlm == 0.0)
assert (args.alpha_mlm > 0.0 and args.alpha_clm == 0.0) or (args.alpha_mlm == 0.0 and args.alpha_clm > 0.0)
if args.mlm:
assert os.path.isfile(args.token_counts)
assert (args.student_type in ["roberta", "distilbert"]) and (args.teacher_type in ["roberta", "bert"])
else:
assert (args.student_type in ["gpt2"]) and (args.teacher_type in ["gpt2"])
assert args.teacher_type == args.student_type or (
args.student_type == "distilbert" and args.teacher_type == "bert"
)
assert os.path.isfile(args.student_config)
if args.student_pretrained_weights is not None:
assert os.path.isfile(args.student_pretrained_weights)
if args.freeze_token_type_embds:
assert args.student_type in ["roberta"]
assert args.alpha_ce >= 0.0
assert args.alpha_mlm >= 0.0
assert args.alpha_clm >= 0.0
assert args.alpha_mse >= 0.0
assert args.alpha_cos >= 0.0
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.0
def freeze_pos_embeddings(student, args):
if args.student_type == "roberta":
student.roberta.embeddings.position_embeddings.weight.requires_grad = False
elif args.student_type == "gpt2":
student.transformer.wpe.weight.requires_grad = False
def freeze_token_type_embeddings(student, args):
if args.student_type == "roberta":
student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
def main():
parser = argparse.ArgumentParser(description="Training")
parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.")
parser.add_argument(
"--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)"
)
parser.add_argument(
"--data_file",
type=str,
required=True,
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.",
)
parser.add_argument(
"--student_type",
type=str,
choices=["distilbert", "roberta", "gpt2"],
required=True,
help="The student type (DistilBERT, RoBERTa).",
)
parser.add_argument("--student_config", type=str, required=True, help="Path to the student configuration.")
parser.add_argument(
"--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint."
)
parser.add_argument(
"--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa)."
)
parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.")
parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax temperature.")
parser.add_argument(
"--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0."
)
parser.add_argument(
"--alpha_mlm",
default=0.0,
type=float,
help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.",
)
parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.")
parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.")
parser.add_argument(
"--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0."
)
parser.add_argument(
"--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM."
)
parser.add_argument(
"--mlm_mask_prop",
default=0.15,
type=float,
help="Proportion of tokens for which we need to make a prediction.",
)
parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of tokens to mask out.")
parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of tokens to keep.")
parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of tokens to randomly replace.")
parser.add_argument(
"--mlm_smoothing",
default=0.7,
type=float,
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).",
)
parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.")
parser.add_argument(
"--restrict_ce_to_mask",
action="store_true",
help="If true, compute the distilation loss only the [MLM] prediction distribution.",
)
parser.add_argument(
"--freeze_pos_embs",
action="store_true",
help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
)
parser.add_argument(
"--freeze_token_type_embds",
action="store_true",
help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
)
parser.add_argument("--n_epoch", type=int, default=3, help="Number of pass on the whole dataset.")
parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).")
parser.add_argument(
"--group_by_size",
action="store_false",
help="If true, group sequences that have similar length into the same batch. Default is true.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=50,
help="Gradient accumulation for larger training batches.",
)
parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs in the node.")
parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
parser.add_argument("--seed", type=int, default=56, help="Random seed")
parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.")
parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.")
args = parser.parse_args()
sanity_checks(args)
# ARGS #
init_gpu_params(args)
set_seed(args)
if args.is_master:
if os.path.exists(args.dump_path):
if not args.force:
raise ValueError(
f"Serialization dir {args.dump_path} already exists, but you have not precised wheter to overwrite it"
"Use `--force` if you want to overwrite it"
)
else:
shutil.rmtree(args.dump_path)
if not os.path.exists(args.dump_path):
os.makedirs(args.dump_path)
logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
# SAVE PARAMS #
logger.info(f"Param: {args}")
with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
json.dump(vars(args), f, indent=4)
git_log(args.dump_path)
student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
# TOKENIZER #
tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
special_tok_ids = {}
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
idx = tokenizer.all_special_tokens.index(tok_symbol)
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
logger.info(f"Special tokens {special_tok_ids}")
args.special_tok_ids = special_tok_ids
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
# DATA LOADER #
logger.info(f"Loading data from {args.data_file}")
with open(args.data_file, "rb") as fp:
data = pickle.load(fp)
if args.mlm:
logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)")
with open(args.token_counts, "rb") as fp:
counts = pickle.load(fp)
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
for idx in special_tok_ids.values():
token_probs[idx] = 0.0 # do not predict special tokens
token_probs = torch.from_numpy(token_probs)
else:
token_probs = None
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
logger.info("Data loader created.")
# STUDENT #
logger.info(f"Loading student config from {args.student_config}")
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
stu_architecture_config.output_hidden_states = True
if args.student_pretrained_weights is not None:
logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}")
student = student_model_class.from_pretrained(args.student_pretrained_weights, config=stu_architecture_config)
else:
student = student_model_class(stu_architecture_config)
if args.n_gpu > 0:
student.to(f"cuda:{args.local_rank}")
logger.info("Student loaded.")
# TEACHER #
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
if args.n_gpu > 0:
teacher.to(f"cuda:{args.local_rank}")
logger.info(f"Teacher loaded from {args.teacher_name}.")
# FREEZING #
if args.freeze_pos_embs:
freeze_pos_embeddings(student, args)
if args.freeze_token_type_embds:
freeze_token_type_embeddings(student, args)
# SANITY CHECKS #
assert student.config.vocab_size == teacher.config.vocab_size
assert student.config.hidden_size == teacher.config.hidden_size
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
if args.mlm:
assert token_probs.size(0) == stu_architecture_config.vocab_size
# DISTILLER #
torch.cuda.empty_cache()
distiller = Distiller(
params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
)
distiller.train()
logger.info("Let's go get some drinks.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,15 @@
{
"activation": "gelu",
"attention_dropout": 0.1,
"dim": 768,
"dropout": 0.1,
"hidden_dim": 3072,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"n_heads": 12,
"n_layers": 6,
"sinusoidal_pos_embds": true,
"tie_weights_": true,
"vocab_size": 28996
}

View File

@@ -0,0 +1,15 @@
{
"activation": "gelu",
"attention_dropout": 0.1,
"dim": 768,
"dropout": 0.1,
"hidden_dim": 3072,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"n_heads": 12,
"n_layers": 6,
"sinusoidal_pos_embds": true,
"tie_weights_": true,
"vocab_size": 119547
}

View File

@@ -0,0 +1,15 @@
{
"activation": "gelu",
"attention_dropout": 0.1,
"dim": 768,
"dropout": 0.1,
"hidden_dim": 3072,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"n_heads": 12,
"n_layers": 6,
"sinusoidal_pos_embds": true,
"tie_weights_": true,
"vocab_size": 30522
}

View File

@@ -0,0 +1,10 @@
{
"initializer_range": 0.02,
"layer_norm_epsilon": 0.00001,
"n_ctx": 1024,
"n_embd": 768,
"n_head": 12,
"n_layer": 6,
"n_positions": 1024,
"vocab_size": 50257
}

View File

@@ -0,0 +1,14 @@
{
"vocab_size": 50265,
"hidden_size": 768,
"num_hidden_layers": 6,
"num_attention_heads": 12,
"intermediate_size": 3072,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 514,
"type_vocab_size": 1,
"initializer_range": 0.02,
"layer_norm_eps": 0.00001
}

View File

@@ -0,0 +1,133 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utils to train DistilBERT
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import json
import logging
import os
import socket
import git
import numpy as np
import torch
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
def git_log(folder_path: str):
"""
Log commit info.
"""
repo = git.Repo(search_parent_directories=True)
repo_infos = {
"repo_id": str(repo),
"repo_sha": str(repo.head.object.hexsha),
"repo_branch": str(repo.active_branch),
}
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
json.dump(repo_infos, f, indent=4)
def init_gpu_params(params):
"""
Handle single and multi-GPU / multi-node.
"""
if params.n_gpu <= 0:
params.local_rank = 0
params.master_port = -1
params.is_master = True
params.multi_gpu = False
return
assert torch.cuda.is_available()
logger.info("Initializing GPUs")
if params.n_gpu > 1:
assert params.local_rank != -1
params.world_size = int(os.environ["WORLD_SIZE"])
params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
params.global_rank = int(os.environ["RANK"])
# number of nodes / node ID
params.n_nodes = params.world_size // params.n_gpu_per_node
params.node_id = params.global_rank // params.n_gpu_per_node
params.multi_gpu = True
assert params.n_nodes == int(os.environ["N_NODES"])
assert params.node_id == int(os.environ["NODE_RANK"])
# local job (single GPU)
else:
assert params.local_rank == -1
params.n_nodes = 1
params.node_id = 0
params.local_rank = 0
params.global_rank = 0
params.world_size = 1
params.n_gpu_per_node = 1
params.multi_gpu = False
# sanity checks
assert params.n_nodes >= 1
assert 0 <= params.node_id < params.n_nodes
assert 0 <= params.local_rank <= params.global_rank < params.world_size
assert params.world_size == params.n_nodes * params.n_gpu_per_node
# define whether this is the master process / if we are in multi-node distributed mode
params.is_master = params.node_id == 0 and params.local_rank == 0
params.multi_node = params.n_nodes > 1
# summary
PREFIX = f"--- Global rank: {params.global_rank} - "
logger.info(PREFIX + "Number of nodes: %i" % params.n_nodes)
logger.info(PREFIX + "Node ID : %i" % params.node_id)
logger.info(PREFIX + "Local rank : %i" % params.local_rank)
logger.info(PREFIX + "World size : %i" % params.world_size)
logger.info(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node)
logger.info(PREFIX + "Master : %s" % str(params.is_master))
logger.info(PREFIX + "Multi-node : %s" % str(params.multi_node))
logger.info(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu))
logger.info(PREFIX + "Hostname : %s" % socket.gethostname())
# set GPU device
torch.cuda.set_device(params.local_rank)
# initialize multi-GPU
if params.multi_gpu:
logger.info("Initializing PyTorch distributed")
torch.distributed.init_process_group(
init_method="env://",
backend="nccl",
)
def set_seed(args):
"""
Set the random seed.
"""
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)