Rebase on master + DistilBERT head pruning patch
This commit is contained in:
@@ -174,12 +174,16 @@ class MultiHeadSelfAttention(nn.Module):
|
|||||||
self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
|
self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
|
||||||
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
|
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
|
||||||
|
|
||||||
|
self.pruned_heads = set()
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
attention_head_size = self.dim // self.n_heads
|
attention_head_size = self.dim // self.n_heads
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.n_heads, attention_head_size)
|
mask = torch.ones(self.n_heads, attention_head_size)
|
||||||
|
heads = set(heads) - self.pruned_heads
|
||||||
for head in heads:
|
for head in heads:
|
||||||
|
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
|
||||||
mask[head] = 0
|
mask[head] = 0
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
mask = mask.view(-1).contiguous().eq(1)
|
||||||
index = torch.arange(len(mask))[mask].long()
|
index = torch.arange(len(mask))[mask].long()
|
||||||
@@ -191,6 +195,7 @@ class MultiHeadSelfAttention(nn.Module):
|
|||||||
# Update hyper params
|
# Update hyper params
|
||||||
self.n_heads = self.n_heads - len(heads)
|
self.n_heads = self.n_heads - len(heads)
|
||||||
self.dim = attention_head_size * self.n_heads
|
self.dim = attention_head_size * self.n_heads
|
||||||
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
def forward(self, query, key, value, mask, head_mask = None):
|
def forward(self, query, key, value, mask, head_mask = None):
|
||||||
"""
|
"""
|
||||||
@@ -395,7 +400,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
|
|||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
|
super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||||
|
|
||||||
def init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights.
|
""" Initialize the weights.
|
||||||
"""
|
"""
|
||||||
if isinstance(module, nn.Embedding):
|
if isinstance(module, nn.Embedding):
|
||||||
@@ -480,7 +485,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
|||||||
self.embeddings = Embeddings(config) # Embeddings
|
self.embeddings = Embeddings(config) # Embeddings
|
||||||
self.transformer = Transformer(config) # Encoder
|
self.transformer = Transformer(config) # Encoder
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
def _resize_token_embeddings(self, new_num_tokens):
|
||||||
old_embeddings = self.embeddings.word_embeddings
|
old_embeddings = self.embeddings.word_embeddings
|
||||||
@@ -568,7 +573,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
|||||||
self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
|
self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
|
||||||
self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
|
self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
|
|
||||||
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
|
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
|
||||||
@@ -642,7 +647,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
|||||||
self.classifier = nn.Linear(config.dim, config.num_labels)
|
self.classifier = nn.Linear(config.dim, config.num_labels)
|
||||||
self.dropout = nn.Dropout(config.seq_classif_dropout)
|
self.dropout = nn.Dropout(config.seq_classif_dropout)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
|
def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
|
||||||
distilbert_output = self.distilbert(input_ids=input_ids,
|
distilbert_output = self.distilbert(input_ids=input_ids,
|
||||||
@@ -716,7 +721,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
|
|||||||
assert config.num_labels == 2
|
assert config.num_labels == 2
|
||||||
self.dropout = nn.Dropout(config.qa_dropout)
|
self.dropout = nn.Dropout(config.qa_dropout)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
|
def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
|
||||||
distilbert_output = self.distilbert(input_ids=input_ids,
|
distilbert_output = self.distilbert(input_ids=input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user