Applied patch to OpenAI GPT, RoBERTa, TransfoL, XLM and XLNet

This commit is contained in:
LysandreJik
2019-08-31 00:33:11 -04:00
parent bdb4409ed8
commit b6992b7b47
5 changed files with 27 additions and 41 deletions

View File

@@ -271,15 +271,16 @@ class MultiHeadAttention(nn.Module):
self.k_lin = nn.Linear(dim, dim)
self.v_lin = nn.Linear(dim, dim)
self.out_lin = nn.Linear(dim, dim)
self.pruned_heads = []
self.pruned_heads = set()
def prune_heads(self, heads):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
return
mask = torch.ones(self.n_heads, attention_head_size)
heads = set(heads) - self.pruned_heads
for head in heads:
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
@@ -291,7 +292,7 @@ class MultiHeadAttention(nn.Module):
# Update hyper params
self.n_heads = self.n_heads - len(heads)
self.dim = attention_head_size * self.n_heads
self.pruned_heads.extend(heads)
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, input, mask, kv=None, cache=None, head_mask=None):
"""
@@ -386,7 +387,7 @@ class XLMPreTrainedModel(PreTrainedModel):
def __init__(self, *inputs, **kwargs):
super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
def _init_weights(self, module):
""" Initialize the weights. """
if isinstance(module, nn.Embedding):
if self.config is not None and self.config.embed_init_std is not None:
@@ -569,7 +570,7 @@ class XLMModel(XLMPreTrainedModel):
if self.attentions[int(layer)].n_heads == config.n_heads:
self.prune_heads({int(layer): list(map(int, heads))})
self.apply(self.init_weights)
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
@@ -781,7 +782,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
self.transformer = XLMModel(config)
self.pred_layer = XLMPredLayer(config)
self.apply(self.init_weights)
self.init_weights()
self.tie_weights()
def tie_weights(self):
@@ -843,7 +844,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
self.transformer = XLMModel(config)
self.sequence_summary = SequenceSummary(config)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
attention_mask=None, cache=None, labels=None, head_mask=None):
@@ -921,7 +922,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
self.transformer = XLMModel(config)
self.qa_outputs = SQuADHead(config)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
attention_mask=None, cache=None, start_positions=None, end_positions=None,