From b6992b7b476fe7e231c8e144e36582fbbde0b4d4 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Sat, 31 Aug 2019 00:33:11 -0400 Subject: [PATCH] Applied patch to OpenAI GPT, RoBERTa, TransfoL, XLM and XLNet --- pytorch_transformers/modeling_openai.py | 25 +++++++-------------- pytorch_transformers/modeling_roberta.py | 4 ++-- pytorch_transformers/modeling_transfo_xl.py | 9 +++----- pytorch_transformers/modeling_xlm.py | 17 +++++++------- pytorch_transformers/modeling_xlnet.py | 13 +++++------ 5 files changed, 27 insertions(+), 41 deletions(-) diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index 78e57b0c59..8bf9d86696 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -249,14 +249,15 @@ class Attention(nn.Module): self.c_proj = Conv1D(n_state, nx) self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) - self.pruned_heads = [] + self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return mask = torch.ones(self.n_head, self.split_size // self.n_head) + heads = set(heads) - self.pruned_heads for head in heads: - head -= len(list(filter(lambda h: h < head, self.pruned_heads))) + head -= sum(1 if h < head else 0 for h in self.pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -267,7 +268,7 @@ class Attention(nn.Module): # Update hyper params self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) self.n_head = self.n_head - len(heads) - self.pruned_heads.extend(heads) + self.pruned_heads = self.pruned_heads.union(heads) def _attn(self, q, k, v, head_mask=None): w = torch.matmul(q, k) @@ -366,10 +367,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_openai_gpt base_model_prefix = "transformer" - def __init__(self, *inputs, **kwargs): - super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs) - - def init_weights(self, module): + def _init_weights(self, module): """ Initialize the weights. """ if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): @@ -459,14 +457,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) - if hasattr(config, "pruned_heads"): - pruned_heads = config.pruned_heads.copy().items() - config.pruned_heads = {} - for layer, heads in pruned_heads: - if self.h[int(layer)].attn.n_head == config.n_head: - self.prune_heads({int(layer): list(map(int, heads))}) - - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens) @@ -579,7 +570,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): self.transformer = OpenAIGPTModel(config) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -686,7 +677,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.multiple_choice_head = SequenceSummary(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): diff --git a/pytorch_transformers/modeling_roberta.py b/pytorch_transformers/modeling_roberta.py index cbd88ab86e..6ae5cd1d44 100644 --- a/pytorch_transformers/modeling_roberta.py +++ b/pytorch_transformers/modeling_roberta.py @@ -168,7 +168,7 @@ class RobertaModel(BertModel): super(RobertaModel, self).__init__(config) self.embeddings = RobertaEmbeddings(config) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None): if input_ids[:, 0].sum().item() != 0: @@ -220,7 +220,7 @@ class RobertaForMaskedLM(BertPreTrainedModel): self.roberta = RobertaModel(config) self.lm_head = RobertaLMHead(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index 283fa66daf..0c5c5b7798 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -853,9 +853,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_transfo_xl base_model_prefix = "transformer" - def __init__(self, *inputs, **kwargs): - super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs) - def _init_weight(self, weight): if self.config.init == 'uniform': nn.init.uniform_(weight, -self.config.init_range, self.config.init_range) @@ -865,7 +862,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel): def _init_bias(self, bias): nn.init.constant_(bias, 0.0) - def init_weights(self, m): + def _init_weights(self, m): """ Initialize the weights. """ classname = m.__class__.__name__ @@ -1059,7 +1056,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): self.r_emb = nn.Parameter(torch.FloatTensor( self.n_layer, self.max_klen, self.n_head, self.d_head)) - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): return self.word_emb @@ -1306,7 +1303,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): else: self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 17e39528f8..9eff09b362 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -271,15 +271,16 @@ class MultiHeadAttention(nn.Module): self.k_lin = nn.Linear(dim, dim) self.v_lin = nn.Linear(dim, dim) self.out_lin = nn.Linear(dim, dim) - self.pruned_heads = [] + self.pruned_heads = set() def prune_heads(self, heads): attention_head_size = self.dim // self.n_heads if len(heads) == 0: return mask = torch.ones(self.n_heads, attention_head_size) + heads = set(heads) - self.pruned_heads for head in heads: - head -= len(list(filter(lambda h: h < head, self.pruned_heads))) + head -= sum(1 if h < head else 0 for h in self.pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index = torch.arange(len(mask))[mask].long() @@ -291,7 +292,7 @@ class MultiHeadAttention(nn.Module): # Update hyper params self.n_heads = self.n_heads - len(heads) self.dim = attention_head_size * self.n_heads - self.pruned_heads.extend(heads) + self.pruned_heads = self.pruned_heads.union(heads) def forward(self, input, mask, kv=None, cache=None, head_mask=None): """ @@ -386,7 +387,7 @@ class XLMPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs) - def init_weights(self, module): + def _init_weights(self, module): """ Initialize the weights. """ if isinstance(module, nn.Embedding): if self.config is not None and self.config.embed_init_std is not None: @@ -569,7 +570,7 @@ class XLMModel(XLMPreTrainedModel): if self.attentions[int(layer)].n_heads == config.n_heads: self.prune_heads({int(layer): list(map(int, heads))}) - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens) @@ -781,7 +782,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): self.transformer = XLMModel(config) self.pred_layer = XLMPredLayer(config) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -843,7 +844,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel): self.transformer = XLMModel(config) self.sequence_summary = SequenceSummary(config) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, labels=None, head_mask=None): @@ -921,7 +922,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel): self.transformer = XLMModel(config) self.qa_outputs = SQuADHead(config) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None, attention_mask=None, cache=None, start_positions=None, end_positions=None, diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index cc9c1379a1..516e87e99b 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -586,10 +586,7 @@ class XLNetPreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_xlnet base_model_prefix = "transformer" - def __init__(self, *inputs, **kwargs): - super(XLNetPreTrainedModel, self).__init__(*inputs, **kwargs) - - def init_weights(self, module): + def _init_weights(self, module): """ Initialize the weights. """ if isinstance(module, (nn.Linear, nn.Embedding)): @@ -736,7 +733,7 @@ class XLNetModel(XLNetPreTrainedModel): self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)]) self.dropout = nn.Dropout(config.dropout) - self.apply(self.init_weights) + self.init_weights() def _resize_token_embeddings(self, new_num_tokens): self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens) @@ -1037,7 +1034,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): self.transformer = XLNetModel(config) self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True) - self.apply(self.init_weights) + self.init_weights() self.tie_weights() def tie_weights(self): @@ -1114,7 +1111,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): self.sequence_summary = SequenceSummary(config) self.logits_proj = nn.Linear(config.d_model, config.num_labels) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, @@ -1216,7 +1213,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): self.end_logits = PoolerEndLogits(config) self.answer_class = PoolerAnswerClass(config) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,