updated pruning logic with sets - Bert and GPT-2
This commit is contained in:
@@ -337,26 +337,30 @@ class BertAttention(nn.Module):
|
||||
super(BertAttention, self).__init__()
|
||||
self.self = BertSelfAttention(config)
|
||||
self.output = BertSelfOutput(config)
|
||||
self.pruned_heads = []
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
||||
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
|
||||
for head in heads:
|
||||
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
|
||||
# Compute how many pruned heads are before the head and move the index accordingly
|
||||
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
||||
mask[head] = 0
|
||||
mask = mask.view(-1).contiguous().eq(1)
|
||||
index = torch.arange(len(mask))[mask].long()
|
||||
|
||||
# Prune linear layers
|
||||
self.self.query = prune_linear_layer(self.self.query, index)
|
||||
self.self.key = prune_linear_layer(self.self.key, index)
|
||||
self.self.value = prune_linear_layer(self.self.value, index)
|
||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||
# Update hyper params
|
||||
|
||||
# Update hyper params and store pruned heads
|
||||
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads.extend(heads)
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def forward(self, input_tensor, attention_mask, head_mask=None):
|
||||
self_outputs = self.self(input_tensor, attention_mask, head_mask)
|
||||
@@ -534,12 +538,8 @@ class BertPreTrainedModel(PreTrainedModel):
|
||||
load_tf_weights = load_tf_weights_in_bert
|
||||
base_model_prefix = "bert"
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||
|
||||
def init_weights(self, module):
|
||||
""" Initialize the weights.
|
||||
"""
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
@@ -652,14 +652,7 @@ class BertModel(BertPreTrainedModel):
|
||||
self.encoder = BertEncoder(config)
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
if hasattr(config, "pruned_heads"):
|
||||
pruned_heads = config.pruned_heads.copy().items()
|
||||
config.pruned_heads = {}
|
||||
for layer, heads in pruned_heads:
|
||||
if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
|
||||
self.prune_heads({int(layer): list(map(int, heads))})
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.embeddings.word_embeddings
|
||||
@@ -768,7 +761,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
self.bert = BertModel(config)
|
||||
self.cls = BertPreTrainingHeads(config)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
@@ -836,7 +829,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
self.bert = BertModel(config)
|
||||
self.cls = BertOnlyMLMHead(config)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
@@ -901,7 +894,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
self.bert = BertModel(config)
|
||||
self.cls = BertOnlyNSPHead(config)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
|
||||
position_ids=None, head_mask=None):
|
||||
@@ -962,7 +955,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
|
||||
position_ids=None, head_mask=None):
|
||||
@@ -1066,7 +1059,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
|
||||
position_ids=None, head_mask=None):
|
||||
@@ -1134,7 +1127,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
|
||||
position_ids=None, head_mask=None):
|
||||
@@ -1208,7 +1201,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
self.bert = BertModel(config)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
|
||||
end_positions=None, position_ids=None, head_mask=None):
|
||||
|
||||
Reference in New Issue
Block a user