Update DistilBERT training code
This commit is contained in:
@@ -9,6 +9,12 @@ DistilBERT stands for Distillated-BERT. DistilBERT is a small, fast, cheap and l
|
|||||||
For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
|
For more information on DistilBERT, please refer to our [detailed blog post](https://medium.com/huggingface/smaller-faster-cheaper-lighter-introducing-distilbert-a-distilled-version-of-bert-8cf3380435b5
|
||||||
).
|
).
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
This part of the library has only been tested with Python 3.6+. There are a few specific dependencies to install before launching a distillation; you can install them with the command `pip install -r requirements.txt`.
|
||||||
|
|
||||||
|
**Important note:** The training scripts have been updated to support PyTorch v1.2.0 (there are breaking changes compared to v1.1.0). Note that there is a small internal bug in the current version of PyTorch available on pip that causes a memory leak in our training/distillation. It has recently been fixed and will likely be integrated into the next release. For the moment, we recommend [compiling PyTorch from source](https://github.com/pytorch/pytorch#from-source). Please refer to [issue 1179](https://github.com/huggingface/pytorch-transformers/issues/1179) for more details.
|
||||||
|
|
||||||
## How to use DistilBERT
|
## How to use DistilBERT
|
||||||
|
|
||||||
PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility of training and releasing a multilingual version of DistilBERT):
|
PyTorch-Transformers includes two pre-trained DistilBERT models, currently only provided for English (we are investigating the possibility of training and releasing a multilingual version of DistilBERT):
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
|
import psutil
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
from tqdm import trange, tqdm
|
from tqdm import trange, tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -192,7 +193,7 @@ class Distiller:
|
|||||||
x_prob = self.token_probs[token_ids.flatten()]
|
x_prob = self.token_probs[token_ids.flatten()]
|
||||||
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
|
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
|
||||||
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
|
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
|
||||||
pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.uint8, device=token_ids.device)
|
pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
|
||||||
pred_mask[tgt_ids] = 1
|
pred_mask[tgt_ids] = 1
|
||||||
pred_mask = pred_mask.view(bs, max_seq_len)
|
pred_mask = pred_mask.view(bs, max_seq_len)
|
||||||
|
|
||||||
@@ -216,7 +217,7 @@ class Distiller:
|
|||||||
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
|
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
|
||||||
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
|
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
|
||||||
|
|
||||||
mlm_labels[1-pred_mask] = -1
|
mlm_labels[~pred_mask] = -1 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
|
||||||
|
|
||||||
return token_ids, attn_mask, mlm_labels
|
return token_ids, attn_mask, mlm_labels
|
||||||
|
|
||||||
@@ -379,9 +380,9 @@ class Distiller:
|
|||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
|
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
|
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
|
||||||
self.scheduler.step()
|
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
def iter(self):
|
def iter(self):
|
||||||
"""
|
"""
|
||||||
@@ -419,6 +420,8 @@ class Distiller:
|
|||||||
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
|
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter)
|
||||||
self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
|
self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter)
|
||||||
|
|
||||||
|
self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter)
|
||||||
|
|
||||||
def end_epoch(self):
|
def end_epoch(self):
|
||||||
"""
|
"""
|
||||||
Finally arrived at the end of epoch (full pass on dataset).
|
Finally arrived at the end of epoch (full pass on dataset).
|
||||||
|
|||||||
@@ -1 +1,4 @@
|
|||||||
gitpython==3.0.2
|
gitpython==3.0.2
|
||||||
|
tensorboard>=1.14.0
|
||||||
|
tensorboardX==1.8
|
||||||
|
psutil==5.6.3
|
||||||
|
|||||||
Reference in New Issue
Block a user