From bdb4409ed8de4d199907c75832398f2c49a564e1 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 31 Aug 2019 01:59:07 +0200 Subject: [PATCH] updated pruning logic with sets - Bert and GPT-2 --- pytorch_transformers/modeling_bert.py | 43 +++++++++++--------------- pytorch_transformers/modeling_gpt2.py | 25 +++++++-------- pytorch_transformers/modeling_utils.py | 38 ++++++++++------------- 3 files changed, 45 insertions(+), 61 deletions(-) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index 9aa25edbe3..e2d8346071 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -337,26 +337,30 @@ class BertAttention(nn.Module): super(BertAttention, self).__init__() self.self = BertSelfAttention(config) self.output = BertSelfOutput(config) - self.pruned_heads = [] + self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads for head in heads: - head -= len(list(filter(lambda h: h < head, self.pruned_heads))) + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() + # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - # Update hyper params + + # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads.extend(heads) + self.pruned_heads = self.pruned_heads.union(heads) def forward(self, input_tensor, attention_mask, head_mask=None): self_outputs = self.self(input_tensor, attention_mask, head_mask) @@ -534,12 +538,8 @@ class BertPreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_bert base_model_prefix = "bert" - def __init__(self, *inputs, **kwargs): - super(BertPreTrainedModel, self).__init__(*inputs, **kwargs) - - def init_weights(self, module): - """ Initialize the weights. - """ + def _init_weights(self, module): + """ Initialize the weights """ if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 @@ -652,14 +652,7 @@ class BertModel(BertPreTrainedModel): self.encoder = BertEncoder(config) self.pooler = BertPooler(config) - if hasattr(config, "pruned_heads"): - pruned_heads = config.pruned_heads.copy().items() - config.pruned_heads = {} - for layer, heads in pruned_heads: - if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads: - self.prune_heads({int(layer): list(map(int, heads))}) - - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): old_embeddings = self.embeddings.word_embeddings @@ -768,7 +761,7 @@ class BertForPreTraining(BertPreTrainedModel): self.bert = BertModel(config) self.cls = BertPreTrainingHeads(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -836,7 +829,7 @@ class BertForMaskedLM(BertPreTrainedModel): self.bert = BertModel(config) self.cls = BertOnlyMLMHead(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -901,7 +894,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel): self.bert = BertModel(config) self.cls = BertOnlyNSPHead(config) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, position_ids=None, head_mask=None): @@ -962,7 +955,7 @@ class BertForSequenceClassification(BertPreTrainedModel): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, position_ids=None, head_mask=None): @@ -1066,7 +1059,7 @@ class BertForMultipleChoice(BertPreTrainedModel): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, 1) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, position_ids=None, head_mask=None): @@ -1134,7 +1127,7 @@ class BertForTokenClassification(BertPreTrainedModel): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, position_ids=None, head_mask=None): @@ -1208,7 +1201,7 @@ class BertForQuestionAnswering(BertPreTrainedModel): self.bert = BertModel(config) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, position_ids=None, head_mask=None): diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 8b39ad372e..017ad4f7b4 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -233,25 +233,29 @@ class Attention(nn.Module): self.c_proj = Conv1D(n_state, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) - self.pruned_heads = [] + self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.n_head, self.split_size // self.n_head) + heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads for head in heads: - head -= len(list(filter(lambda h: h < head, self.pruned_heads))) + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)]) + # Prune conv1d layers self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + # Update hyper params self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) self.n_head = self.n_head - len(heads) - self.pruned_heads.extend(heads) + self.pruned_heads = self.pruned_heads.union(heads) def _attn(self, q, k, v, head_mask=None): w = torch.matmul(q, k) @@ -357,7 +361,7 @@ class GPT2PreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs) - def init_weights(self, module): + def _init_weights(self, module): """ Initialize the weights. """ if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): @@ -456,14 +460,7 @@ class GPT2Model(GPT2PreTrainedModel): self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - if hasattr(config, "pruned_heads"): - pruned_heads = config.pruned_heads.copy().items() - config.pruned_heads = {} - for layer, heads in pruned_heads: - if self.h[int(layer)].attn.n_head == config.n_head: - self.prune_heads({int(layer): list(map(int, heads))}) - - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): self.wte = self._get_resized_embeddings(self.wte, new_num_tokens) @@ -594,7 +591,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): self.transformer = GPT2Model(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -718,7 +715,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.multiple_choice_head = SequenceSummary(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index c69cba49e3..33bcb968b5 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -202,8 +202,7 @@ class PretrainedConfig(object): config = cls.from_json_file(resolved_config_file) if hasattr(config, 'pruned_heads'): - config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()} - + config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items()) # Update config with kwargs if needed to_remove = [] @@ -316,7 +315,7 @@ class PreTrainedModel(nn.Module): new_embeddings.to(old_embeddings.weight.device) # initialize all new embeddings (in particular added tokens) - self.init_weights(new_embeddings) + self._init_weights(new_embeddings) # Copy word embeddings from the previous weights num_tokens_to_copy = min(old_num_tokens, new_num_tokens) @@ -360,36 +359,31 @@ class PreTrainedModel(nn.Module): return model_embeds + def init_weights(self): + """ Initialize and prunes weights if needed. """ + # Initialize weights + self.apply(self._init_weights) + + # Prune heads if needed + if self.config.pruned_heads: + self.prune_heads(self.config.pruned_heads) + def prune_heads(self, heads_to_prune): """ Prunes heads of the base model. Arguments: heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`). + E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. """ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed - to_be_pruned = {} - + # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads for layer, heads in heads_to_prune.items(): - if int(layer) not in self.config.pruned_heads: - self.config.pruned_heads[int(layer)] = heads - to_be_pruned[int(layer)] = heads - else: - for head in heads: - if head not in self.config.pruned_heads[int(layer)]: - self.config.pruned_heads[int(layer)].append(head) + union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) + self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON - if int(layer) in to_be_pruned: - to_be_pruned[int(layer)].append(head) - else: - to_be_pruned[int(layer)] = [head] - else: - logger.warning("Tried to remove head " + str(head) + - " of layer " + str(layer) + - " but it was already removed. The current removed heads are " + str(heads_to_prune)) - - base_model._prune_heads(to_be_pruned) + base_model._prune_heads(heads_to_prune) def save_pretrained(self, save_directory): """ Save a model and its configuration file to a directory, so that it