Applied patch to OpenAI GPT, RoBERTa, TransfoL, XLM and XLNet
This commit is contained in:
@@ -249,14 +249,15 @@ class Attention(nn.Module):
|
|||||||
self.c_proj = Conv1D(n_state, nx)
|
self.c_proj = Conv1D(n_state, nx)
|
||||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||||
self.pruned_heads = []
|
self.pruned_heads = set()
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
||||||
|
heads = set(heads) - self.pruned_heads
|
||||||
for head in heads:
|
for head in heads:
|
||||||
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
|
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
|
||||||
mask[head] = 0
|
mask[head] = 0
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
mask = mask.view(-1).contiguous().eq(1)
|
||||||
index = torch.arange(len(mask))[mask].long()
|
index = torch.arange(len(mask))[mask].long()
|
||||||
@@ -267,7 +268,7 @@ class Attention(nn.Module):
|
|||||||
# Update hyper params
|
# Update hyper params
|
||||||
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
|
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
|
||||||
self.n_head = self.n_head - len(heads)
|
self.n_head = self.n_head - len(heads)
|
||||||
self.pruned_heads.extend(heads)
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
def _attn(self, q, k, v, head_mask=None):
|
def _attn(self, q, k, v, head_mask=None):
|
||||||
w = torch.matmul(q, k)
|
w = torch.matmul(q, k)
|
||||||
@@ -366,10 +367,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
|
|||||||
load_tf_weights = load_tf_weights_in_openai_gpt
|
load_tf_weights = load_tf_weights_in_openai_gpt
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def _init_weights(self, module):
|
||||||
super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
|
|
||||||
|
|
||||||
def init_weights(self, module):
|
|
||||||
""" Initialize the weights.
|
""" Initialize the weights.
|
||||||
"""
|
"""
|
||||||
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
|
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
|
||||||
@@ -459,14 +457,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
self.drop = nn.Dropout(config.embd_pdrop)
|
self.drop = nn.Dropout(config.embd_pdrop)
|
||||||
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
|
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
|
||||||
|
|
||||||
if hasattr(config, "pruned_heads"):
|
self.init_weights()
|
||||||
pruned_heads = config.pruned_heads.copy().items()
|
|
||||||
config.pruned_heads = {}
|
|
||||||
for layer, heads in pruned_heads:
|
|
||||||
if self.h[int(layer)].attn.n_head == config.n_head:
|
|
||||||
self.prune_heads({int(layer): list(map(int, heads))})
|
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
def _resize_token_embeddings(self, new_num_tokens):
|
||||||
self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
|
self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
|
||||||
@@ -579,7 +570,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
self.transformer = OpenAIGPTModel(config)
|
self.transformer = OpenAIGPTModel(config)
|
||||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
@@ -686,7 +677,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||||
self.multiple_choice_head = SequenceSummary(config)
|
self.multiple_choice_head = SequenceSummary(config)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ class RobertaModel(BertModel):
|
|||||||
super(RobertaModel, self).__init__(config)
|
super(RobertaModel, self).__init__(config)
|
||||||
|
|
||||||
self.embeddings = RobertaEmbeddings(config)
|
self.embeddings = RobertaEmbeddings(config)
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
|
||||||
if input_ids[:, 0].sum().item() != 0:
|
if input_ids[:, 0].sum().item() != 0:
|
||||||
@@ -220,7 +220,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
|
|||||||
self.roberta = RobertaModel(config)
|
self.roberta = RobertaModel(config)
|
||||||
self.lm_head = RobertaLMHead(config)
|
self.lm_head = RobertaLMHead(config)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
|
|||||||
@@ -853,9 +853,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
|
|||||||
load_tf_weights = load_tf_weights_in_transfo_xl
|
load_tf_weights = load_tf_weights_in_transfo_xl
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
|
||||||
super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
|
|
||||||
|
|
||||||
def _init_weight(self, weight):
|
def _init_weight(self, weight):
|
||||||
if self.config.init == 'uniform':
|
if self.config.init == 'uniform':
|
||||||
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
|
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
|
||||||
@@ -865,7 +862,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
|
|||||||
def _init_bias(self, bias):
|
def _init_bias(self, bias):
|
||||||
nn.init.constant_(bias, 0.0)
|
nn.init.constant_(bias, 0.0)
|
||||||
|
|
||||||
def init_weights(self, m):
|
def _init_weights(self, m):
|
||||||
""" Initialize the weights.
|
""" Initialize the weights.
|
||||||
"""
|
"""
|
||||||
classname = m.__class__.__name__
|
classname = m.__class__.__name__
|
||||||
@@ -1059,7 +1056,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
self.r_emb = nn.Parameter(torch.FloatTensor(
|
self.r_emb = nn.Parameter(torch.FloatTensor(
|
||||||
self.n_layer, self.max_klen, self.n_head, self.d_head))
|
self.n_layer, self.max_klen, self.n_head, self.d_head))
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
def _resize_token_embeddings(self, new_num_tokens):
|
||||||
return self.word_emb
|
return self.word_emb
|
||||||
@@ -1306,7 +1303,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
|
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
|
||||||
config.cutoffs, div_val=config.div_val)
|
config.cutoffs, div_val=config.div_val)
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
|
|||||||
@@ -271,15 +271,16 @@ class MultiHeadAttention(nn.Module):
|
|||||||
self.k_lin = nn.Linear(dim, dim)
|
self.k_lin = nn.Linear(dim, dim)
|
||||||
self.v_lin = nn.Linear(dim, dim)
|
self.v_lin = nn.Linear(dim, dim)
|
||||||
self.out_lin = nn.Linear(dim, dim)
|
self.out_lin = nn.Linear(dim, dim)
|
||||||
self.pruned_heads = []
|
self.pruned_heads = set()
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
attention_head_size = self.dim // self.n_heads
|
attention_head_size = self.dim // self.n_heads
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.n_heads, attention_head_size)
|
mask = torch.ones(self.n_heads, attention_head_size)
|
||||||
|
heads = set(heads) - self.pruned_heads
|
||||||
for head in heads:
|
for head in heads:
|
||||||
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
|
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
|
||||||
mask[head] = 0
|
mask[head] = 0
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
mask = mask.view(-1).contiguous().eq(1)
|
||||||
index = torch.arange(len(mask))[mask].long()
|
index = torch.arange(len(mask))[mask].long()
|
||||||
@@ -291,7 +292,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
# Update hyper params
|
# Update hyper params
|
||||||
self.n_heads = self.n_heads - len(heads)
|
self.n_heads = self.n_heads - len(heads)
|
||||||
self.dim = attention_head_size * self.n_heads
|
self.dim = attention_head_size * self.n_heads
|
||||||
self.pruned_heads.extend(heads)
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
def forward(self, input, mask, kv=None, cache=None, head_mask=None):
|
def forward(self, input, mask, kv=None, cache=None, head_mask=None):
|
||||||
"""
|
"""
|
||||||
@@ -386,7 +387,7 @@ class XLMPreTrainedModel(PreTrainedModel):
|
|||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)
|
super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||||
|
|
||||||
def init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights. """
|
""" Initialize the weights. """
|
||||||
if isinstance(module, nn.Embedding):
|
if isinstance(module, nn.Embedding):
|
||||||
if self.config is not None and self.config.embed_init_std is not None:
|
if self.config is not None and self.config.embed_init_std is not None:
|
||||||
@@ -569,7 +570,7 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
if self.attentions[int(layer)].n_heads == config.n_heads:
|
if self.attentions[int(layer)].n_heads == config.n_heads:
|
||||||
self.prune_heads({int(layer): list(map(int, heads))})
|
self.prune_heads({int(layer): list(map(int, heads))})
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
def _resize_token_embeddings(self, new_num_tokens):
|
||||||
self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
|
self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
|
||||||
@@ -781,7 +782,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
|||||||
self.transformer = XLMModel(config)
|
self.transformer = XLMModel(config)
|
||||||
self.pred_layer = XLMPredLayer(config)
|
self.pred_layer = XLMPredLayer(config)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
@@ -843,7 +844,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
|||||||
self.transformer = XLMModel(config)
|
self.transformer = XLMModel(config)
|
||||||
self.sequence_summary = SequenceSummary(config)
|
self.sequence_summary = SequenceSummary(config)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
|
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
|
||||||
attention_mask=None, cache=None, labels=None, head_mask=None):
|
attention_mask=None, cache=None, labels=None, head_mask=None):
|
||||||
@@ -921,7 +922,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
|
|||||||
self.transformer = XLMModel(config)
|
self.transformer = XLMModel(config)
|
||||||
self.qa_outputs = SQuADHead(config)
|
self.qa_outputs = SQuADHead(config)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
|
def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
|
||||||
attention_mask=None, cache=None, start_positions=None, end_positions=None,
|
attention_mask=None, cache=None, start_positions=None, end_positions=None,
|
||||||
|
|||||||
@@ -586,10 +586,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
|
|||||||
load_tf_weights = load_tf_weights_in_xlnet
|
load_tf_weights = load_tf_weights_in_xlnet
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def _init_weights(self, module):
|
||||||
super(XLNetPreTrainedModel, self).__init__(*inputs, **kwargs)
|
|
||||||
|
|
||||||
def init_weights(self, module):
|
|
||||||
""" Initialize the weights.
|
""" Initialize the weights.
|
||||||
"""
|
"""
|
||||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
@@ -736,7 +733,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
|
self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
|
||||||
self.dropout = nn.Dropout(config.dropout)
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
def _resize_token_embeddings(self, new_num_tokens):
|
||||||
self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
|
self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
|
||||||
@@ -1037,7 +1034,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
self.transformer = XLNetModel(config)
|
self.transformer = XLNetModel(config)
|
||||||
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
|
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
@@ -1114,7 +1111,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
self.sequence_summary = SequenceSummary(config)
|
self.sequence_summary = SequenceSummary(config)
|
||||||
self.logits_proj = nn.Linear(config.d_model, config.num_labels)
|
self.logits_proj = nn.Linear(config.d_model, config.num_labels)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None,
|
mems=None, perm_mask=None, target_mapping=None,
|
||||||
@@ -1216,7 +1213,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
self.end_logits = PoolerEndLogits(config)
|
self.end_logits = PoolerEndLogits(config)
|
||||||
self.answer_class = PoolerAnswerClass(config)
|
self.answer_class = PoolerAnswerClass(config)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None,
|
mems=None, perm_mask=None, target_mapping=None,
|
||||||
|
|||||||
Reference in New Issue
Block a user