From 11600edc6e4e6a5ce148ca1d617c9d7e58bc7a7c Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Sat, 31 Aug 2019 00:37:41 -0400 Subject: [PATCH] Rebase on master + DistilBERT head pruning patch --- pytorch_transformers/modeling_distilbert.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pytorch_transformers/modeling_distilbert.py b/pytorch_transformers/modeling_distilbert.py index 1a0bd2496c..d9a2f1a177 100644 --- a/pytorch_transformers/modeling_distilbert.py +++ b/pytorch_transformers/modeling_distilbert.py @@ -174,12 +174,16 @@ class MultiHeadSelfAttention(nn.Module): self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim) self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.pruned_heads = set() + def prune_heads(self, heads): attention_head_size = self.dim // self.n_heads if len(heads) == 0: return mask = torch.ones(self.n_heads, attention_head_size) + heads = set(heads) - self.pruned_heads for head in heads: + head -= sum(1 if h < head else 0 for h in self.pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -191,6 +195,7 @@ class MultiHeadSelfAttention(nn.Module): # Update hyper params self.n_heads = self.n_heads - len(heads) self.dim = attention_head_size * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) def forward(self, query, key, value, mask, head_mask = None): """ @@ -395,7 +400,7 @@ class DistilBertPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs) - def init_weights(self, module): + def _init_weights(self, module): """ Initialize the weights. """ if isinstance(module, nn.Embedding): @@ -480,7 +485,7 @@ class DistilBertModel(DistilBertPreTrainedModel): self.embeddings = Embeddings(config) # Embeddings self.transformer = Transformer(config) # Encoder - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): old_embeddings = self.embeddings.word_embeddings @@ -568,7 +573,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel): self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12) self.vocab_projector = nn.Linear(config.dim, config.vocab_size) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) @@ -642,7 +647,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel): self.classifier = nn.Linear(config.dim, config.num_labels) self.dropout = nn.Dropout(config.seq_classif_dropout) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None): distilbert_output = self.distilbert(input_ids=input_ids, @@ -716,7 +721,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel): assert config.num_labels == 2 self.dropout = nn.Dropout(config.qa_dropout) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None): distilbert_output = self.distilbert(input_ids=input_ids,