Update DistilBERT training code

This commit is contained in:
VictorSanh
2019-09-05 18:26:14 +00:00
parent f9453d15e5
commit dddd6b9927
3 changed files with 15 additions and 3 deletions

View File

@@ -9,6 +9,12 @@ DistilBERT stands for Distillated-BERT. DistilBERT is a small, fast, cheap and l
For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5 For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
). ).
## Setup
This part of the library has only be tested with Python3.6+. There are few specific dependencies to install before launching a distillation, you can install them with the command `pip install -r requirements.txt`.
**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breakings changes compared to v1.1.0). It is important to note that there is a small internal bug in the current version of PyTorch available on pip that causes a memory leak in our training/distillation. It has been recently fixed and will likely be integrated into the next release. For the moment, we recommend to [compile PyTorch from source](https://github.com/pytorch/pytorch#from-source). Please refer to [issue 1179](https://github.com/huggingface/pytorch-transformers/issues/1179) for more details.
## How to use DistilBERT ## How to use DistilBERT
PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility to train and release a multilingual version of DistilBERT): PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility to train and release a multilingual version of DistilBERT):

View File

@@ -17,6 +17,7 @@
""" """
import os import os
import math import math
import psutil
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from tqdm import trange, tqdm from tqdm import trange, tqdm
import numpy as np import numpy as np
@@ -192,7 +193,7 @@ class Distiller:
x_prob = self.token_probs[token_ids.flatten()] x_prob = self.token_probs[token_ids.flatten()]
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item()) n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False) tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.uint8, device=token_ids.device) pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
pred_mask[tgt_ids] = 1 pred_mask[tgt_ids] = 1
pred_mask = pred_mask.view(bs, max_seq_len) pred_mask = pred_mask.view(bs, max_seq_len)
@@ -216,7 +217,7 @@ class Distiller:
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long() _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
token_ids = token_ids.masked_scatter(pred_mask, _token_ids) token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
mlm_labels[1-pred_mask] = -1 mlm_labels[~pred_mask] = -1 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
return token_ids, attn_mask, mlm_labels return token_ids, attn_mask, mlm_labels
@@ -379,9 +380,9 @@ class Distiller:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm) torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
else: else:
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm) torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
self.scheduler.step()
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.scheduler.step()
def iter(self): def iter(self):
""" """
@@ -419,6 +420,8 @@ class Distiller:
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
def end_epoch(self): def end_epoch(self):
""" """
Finally arrived at the end of epoch (full pass on dataset). Finally arrived at the end of epoch (full pass on dataset).

View File

@@ -1 +1,4 @@
gitpython==3.0.2 gitpython==3.0.2
tensorboard>=1.14.0
tensorboardX==1.8
psutil==5.6.3