Merge pull request #1077 from huggingface/pruning-save-and-load

Pruning changes so that deleted heads are kept on save/load
This commit is contained in:
Thomas Wolf
2019-09-01 09:42:15 +02:00
committed by GitHub
10 changed files with 213 additions and 58 deletions

View File

@@ -271,13 +271,16 @@ class MultiHeadAttention(nn.Module):
self.k_lin = nn.Linear(dim, dim)
self.v_lin = nn.Linear(dim, dim)
self.out_lin = nn.Linear(dim, dim)
self.pruned_heads = set()
def prune_heads(self, heads):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
return
mask = torch.ones(self.n_heads, attention_head_size)
heads = set(heads) - self.pruned_heads
for head in heads:
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
@@ -289,6 +292,7 @@ class MultiHeadAttention(nn.Module):
# Update hyper params
self.n_heads = self.n_heads - len(heads)
self.dim = attention_head_size * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, input, mask, kv=None, cache=None, head_mask=None):
"""
@@ -383,7 +387,7 @@ class XLMPreTrainedModel(PreTrainedModel):
def __init__(self, *inputs, **kwargs):
super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
def _init_weights(self, module):
""" Initialize the weights. """
if isinstance(module, nn.Embedding):
if self.config is not None and self.config.embed_init_std is not None:
@@ -559,7 +563,14 @@ class XLMModel(XLMPreTrainedModel):
self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
self.apply(self.init_weights)
if hasattr(config, "pruned_heads"):
pruned_heads = config.pruned_heads.copy().items()
config.pruned_heads = {}
for layer, heads in pruned_heads:
if self.attentions[int(layer)].n_heads == config.n_heads:
self.prune_heads({int(layer): list(map(int, heads))})
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
@@ -771,7 +782,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
self.transformer = XLMModel(config)
self.pred_layer = XLMPredLayer(config)
self.apply(self.init_weights)
self.init_weights()
self.tie_weights()
def tie_weights(self):
@@ -833,7 +844,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
self.transformer = XLMModel(config)
self.sequence_summary = SequenceSummary(config)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
attention_mask=None, cache=None, labels=None, head_mask=None):
@@ -911,7 +922,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
self.transformer = XLMModel(config)
self.qa_outputs = SQuADHead(config)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
attention_mask=None, cache=None, start_positions=None, end_positions=None,