From 2b56e9889284b5432881e947aefbf7ed6780e4ec Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 28 Jun 2019 16:35:09 +0200 Subject: [PATCH] standardizing API across models - XLNetForSeqClass working --- examples/run_xlnet_classifier.py | 59 ++++--- pytorch_pretrained_bert/modeling.py | 137 +++++++-------- pytorch_pretrained_bert/modeling_xlm.py | 132 ++++++++------- pytorch_pretrained_bert/modeling_xlnet.py | 196 ++++++++++++---------- 4 files changed, 277 insertions(+), 247 deletions(-) diff --git a/examples/run_xlnet_classifier.py b/examples/run_xlnet_classifier.py index fb5501e370..e30cad773b 100644 --- a/examples/run_xlnet_classifier.py +++ b/examples/run_xlnet_classifier.py @@ -67,6 +67,8 @@ def main(): help="The initial learning rate for Adam.") parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") + parser.add_argument("--max_steps", default=-1, type=int, + help="If > 0 limit the number of training steps to perform, you should choose only one of num_train_epochs and max_steps.") parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10%% of training.") @@ -189,8 +191,7 @@ def main(): model = torch.nn.DataParallel(model) global_step = 0 - nb_tr_steps = 0 - tr_loss = 0 + curr_tr_loss, curr_steps = 0., 1 if args.do_train: if args.local_rank in [-1, 0]: @@ -229,12 +230,15 @@ def main(): train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) if args.local_rank == -1: - train_sampler = RandomSampler(train_data) + train_sampler = SequentialSampler(train_data) # RandomSampler(train_data) else: train_sampler = DistributedSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) - num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + if args.max_steps > 0: + num_train_optimization_steps = args.max_steps + else: + num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer @@ -275,22 +279,16 @@ def main(): logger.info(" Num steps = %d", num_train_optimization_steps) model.train() - for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]): - tr_loss = 0 - nb_tr_examples, nb_tr_steps = 0, 0 - for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])): + for _ in trange(int(args.num_train_epochs) if args.max_steps <= 0 else int('Inf'), + desc="Epoch", disable=args.local_rank not in [-1, 0]): + for step, batch in enumerate(tqdm(train_dataloader, + desc="Iteration", + disable=args.local_rank not in [-1, 0])): batch = tuple(t.to(device) for t in batch) input_ids, input_mask, segment_ids, label_ids = batch # define a new function to compute loss values for both output_modes - logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) - - if output_mode == "classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) - elif output_mode == "regression": - loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), label_ids.view(-1)) + loss, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids) if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. @@ -302,12 +300,10 @@ def main(): else: loss.backward() - if args.clip_gradients > 0.0: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients) + gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients) - tr_loss += loss.item() - nb_tr_examples += input_ids.size(0) - nb_tr_steps += 1 + curr_tr_loss += loss.item() + curr_steps += 1 if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: # modify learning rate with special warm up BERT uses @@ -318,10 +314,19 @@ def main(): optimizer.step() optimizer.zero_grad() global_step += 1 - if args.local_rank in [-1, 0] and (args.log_every <= 0 or (step + 1) % args.log_every == 0): - if not args.fp16: - tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) - tb_writer.add_scalar('loss', loss.item(), global_step) + if args.local_rank in [-1, 0] and (args.log_every <= 0 or (global_step + 1) % args.log_every == 0): + learning_rate = optimizer.get_lr()[0] if not args.fp16 else lr_this_step + logger.info("[{}] | gnorm {:.2f} lr {:8.6f} | loss {:.2f}".format( + global_step, gnorm, learning_rate, curr_tr_loss / curr_steps)) + tb_writer.add_scalar('lr', learning_rate, global_step) + tb_writer.add_scalar('loss', curr_tr_loss / curr_steps, global_step) + curr_tr_loss, curr_steps = 0., 1 + + if args.max_steps > 0 and global_step > args.max_steps: + break + + if args.max_steps > 0 and global_step > args.max_steps: + break ### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() ### Example: @@ -435,7 +440,7 @@ def main(): preds = np.squeeze(preds) result = compute_metrics(task_name, preds, out_label_ids) - loss = tr_loss/global_step if args.do_train else None + loss = curr_tr_loss/curr_steps if args.do_train else None result['eval_loss'] = eval_loss result['global_step'] = global_step @@ -508,7 +513,7 @@ def main(): preds = np.argmax(preds, axis=1) result = compute_metrics(task_name, preds, out_label_ids) - loss = tr_loss/global_step if args.do_train else None + loss = curr_tr_loss/curr_steps if args.do_train else None result['eval_loss'] = eval_loss result['global_step'] = global_step diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 27c747e405..eade7310f9 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -270,15 +270,13 @@ class BertEmbeddings(nn.Module): class BertSelfAttention(nn.Module): - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False): super(BertSelfAttention, self).__init__() if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) self.output_attentions = output_attentions - self.keep_multihead_output = keep_multihead_output - self.multihead_output = None self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) @@ -329,9 +327,9 @@ class BertSelfAttention(nn.Module): context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - if self.output_attentions: - return attention_probs, context_layer - return context_layer + + outputs = [context_layer, attention_probs] if self.output_attentions else [context_layer] + return outputs class BertSelfOutput(nn.Module): @@ -349,11 +347,10 @@ class BertSelfOutput(nn.Module): class BertAttention(nn.Module): - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False): super(BertAttention, self).__init__() self.output_attentions = output_attentions - self.self = BertSelfAttention(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + self.self = BertSelfAttention(config, output_attentions=output_attentions) self.output = BertSelfOutput(config) def prune_heads(self, heads): @@ -374,13 +371,10 @@ class BertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads def forward(self, input_tensor, attention_mask, head_mask=None): - self_output = self.self(input_tensor, attention_mask, head_mask) - if self.output_attentions: - attentions, self_output = self_output - attention_output = self.output(self_output, input_tensor) - if self.output_attentions: - return attentions, attention_output - return attention_output + self_outputs = self.self(input_tensor, attention_mask, head_mask) + attention_output = self.output(self_outputs[0], input_tensor) + outputs = [attention_output] + self_outputs[1:] # add attentions if we output them + return outputs class BertIntermediate(nn.Module): @@ -413,48 +407,52 @@ class BertOutput(nn.Module): class BertLayer(nn.Module): - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False): super(BertLayer, self).__init__() self.output_attentions = output_attentions - self.attention = BertAttention(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + self.attention = BertAttention(config, output_attentions=output_attentions) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) def forward(self, hidden_states, attention_mask, head_mask=None): - attention_output = self.attention(hidden_states, attention_mask, head_mask) - if self.output_attentions: - attentions, attention_output = attention_output - intermediate_output = self.intermediate(attention_output) + attention_outputs = self.attention(hidden_states, attention_mask, head_mask) + intermediate_output = self.intermediate(attention_outputs[0]) layer_output = self.output(intermediate_output, attention_output) - if self.output_attentions: - return attentions, layer_output - return layer_output + outputs = [layer_output] + attention_outputs[1:] # add attentions if we output them + return outputs class BertEncoder(nn.Module): - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False, output_hidden_states=False): super(BertEncoder, self).__init__() self.output_attentions = output_attentions - layer = BertLayer(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + self.output_hidden_states = output_hidden_states + layer = BertLayer(config, output_attentions=output_attentions) self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) - def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None): - all_encoder_layers = [] + def forward(self, hidden_states, attention_mask, head_mask=None): + all_hidden_states = [] all_attentions = [] for i, layer_module in enumerate(self.layer): - hidden_states = layer_module(hidden_states, attention_mask, head_mask[i]) + if self.output_hidden_states: + all_hidden_states.append(hidden_states) + + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i]) + hidden_states = layer_outputs[0] + if self.output_attentions: - attentions, hidden_states = hidden_states - all_attentions.append(attentions) - if output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - if not output_all_encoded_layers: - all_encoder_layers.append(hidden_states) + all_attentions.append(layer_outputs[1]) + + # Add last layer + if self.output_hidden_states: + all_hidden_states.append(hidden_states) + + outputs = [hidden_states] + if self.output_hidden_states: + outputs.append(all_hidden_states) if self.output_attentions: - return all_attentions, all_encoder_layers - return all_encoder_layers + outputs.append(all_attentions) + return outputs # outputs, (hidden states), (attentions) class BertPooler(nn.Module): @@ -617,12 +615,13 @@ class BertModel(BertPreTrainedModel): all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False, output_hidden_states=False): super(BertModel, self).__init__(config) self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + output_hidden_states=output_hidden_states) self.pooler = BertPooler(config) self.apply(self.init_weights) @@ -633,13 +632,7 @@ class BertModel(BertPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - def get_multihead_outputs(self): - """ Gather all multi-head outputs. - Return: list (layers) of multihead module outputs with gradients - """ - return [layer.attention.self.multihead_output for layer in self.encoder.layer] - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None): + def forward(self, input_ids, token_type_ids=None, attention_mask=None, head_mask=None): if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: @@ -676,19 +669,14 @@ class BertModel(BertPreTrainedModel): head_mask = [None] * self.config.num_hidden_layers embedding_output = self.embeddings(input_ids, token_type_ids) - encoded_layers = self.encoder(embedding_output, - extended_attention_mask, - output_all_encoded_layers=output_all_encoded_layers, - head_mask=head_mask) - if self.output_attentions: - all_attentions, encoded_layers = encoded_layers - sequence_output = encoded_layers[-1] + encoder_outputs = self.encoder(embedding_output, + extended_attention_mask, + head_mask=head_mask) + sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) - if not output_all_encoded_layers: - encoded_layers = encoded_layers[-1] - if self.output_attentions: - return all_attentions, encoded_layers, pooled_output - return encoded_layers, pooled_output + + outputs = [sequence_output, pooled_output] + encoder_outputs[1:] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) class BertForPreTraining(BertPreTrainedModel): @@ -746,32 +734,33 @@ class BertForPreTraining(BertPreTrainedModel): masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False, output_hidden_states=False): super(BertForPreTraining, self).__init__(config) self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.bert = BertModel(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + output_hidden_states=output_hidden_states) self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) self.apply(self.init_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None): - outputs = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False, head_mask=head_mask) - if self.output_attentions: - all_attentions, sequence_output, pooled_output = outputs - else: - sequence_output, pooled_output = outputs + def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, + next_sentence_label=None, head_mask=None): + outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask) + + sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + outputs = [prediction_scores, seq_relationship_score] + outputs[2:] # add hidden states and attention if they are here + if masked_lm_labels is not None and next_sentence_label is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) total_loss = masked_lm_loss + next_sentence_loss - return total_loss - elif self.output_attentions: - return all_attentions, prediction_scores, seq_relationship_score - return prediction_scores, seq_relationship_score + outputs = [total_loss] + outputs + + return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) class BertForMaskedLM(BertPreTrainedModel): diff --git a/pytorch_pretrained_bert/modeling_xlm.py b/pytorch_pretrained_bert/modeling_xlm.py index 8cb56de253..92e1cc124c 100644 --- a/pytorch_pretrained_bert/modeling_xlm.py +++ b/pytorch_pretrained_bert/modeling_xlm.py @@ -919,9 +919,11 @@ class XLMModel(XLMPreTrainedModel): class XLMModel(XLMPreTrainedModel): - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False, output_hidden_states=False): super(XLMModel, self).__init__(config) self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.mem_len = config.mem_len self.reuse_len = config.reuse_len self.d_model = config.d_model @@ -1038,8 +1040,7 @@ class XLMModel(XLMPreTrainedModel): return pos_emb def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, - mems=None, perm_mask=None, target_mapping=None, inp_q=None, - output_all_encoded_layers=True, head_mask=None): + mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None): """ Args: inp_k: int32 Tensor in shape [bsz, len], the input token IDs. @@ -1188,23 +1189,45 @@ class XLMModel(XLMPreTrainedModel): mems = [None] * len(self.layer) hidden_states = [] + attentions = [] for i, layer_module in enumerate(self.layer): # cache new mems new_mems.append(self.cache_mem(output_h, mems[i])) + # Save hidden_states + if output_g is None: + hidden_states.append(output_h) + else: + hidden_states.append((output_h, output_g)) output_h, output_g = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask, r=pos_emb, seg_mat=seg_mat, mems=mems[i], target_mapping=target_mapping, head_mask=head_mask) + # Save last hidden_state + if output_g is None: hidden_states.append(output_h) + else: + hidden_states.append((output_h, output_g)) + + # Select the right output and add dropout output = self.dropout(output_g if output_g is not None else output_h) # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) output = output.permute(1, 0, 2).contiguous() - hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states] + if output_g is None: + hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states] + else: + hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs] - return output, hidden_states, new_mems + # Build the list of outputs + outputs = [output, new_mems] + if self.output_attentions: + outputs.append(attentions) + if self.output_hidden_states: + outputs.append(hidden_states) + + return outputs class XLMPredLayer(nn.Module): @@ -1309,14 +1332,15 @@ class XLMLMHeadModel(XLMPreTrainedModel): all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False, output_hidden_states=False): super(XLMLMHeadModel, self).__init__(config) self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.attn_type = config.attn_type self.same_length = config.same_length - self.transformer = XLMModel(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states) self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True) # Tie weights @@ -1331,7 +1355,7 @@ class XLMLMHeadModel(XLMPreTrainedModel): def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None, - labels=None, output_all_encoded_layers=True, head_mask=None): + labels=None, head_mask=None): """ Args: inp_k: int32 Tensor in shape [bsz, len], the input token IDs. @@ -1358,33 +1382,28 @@ class XLMLMHeadModel(XLMPreTrainedModel): summary_type: str, "last", "first", "mean", or "attn". The method to pool the input to get a vector representation. """ - output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, - mems, perm_mask, target_mapping, inp_q, - output_all_encoded_layers, head_mask) + transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, + mems, perm_mask, target_mapping, inp_q, head_mask) + output = transformer_outputs[0] logits = self.lm_loss(output) + outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here + if labels is not None: # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=-1) loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) - return loss, new_mems + outputs = [loss] + outputs - # if self.output_attentions: - # all_attentions, encoded_layers = encoded_layers - # sequence_output = encoded_layers[-1] - # pooled_output = self.pooler(sequence_output) - # if not output_all_encoded_layers: - # encoded_layers = encoded_layers[-1] - # if self.output_attentions: - return logits, new_mems - # return all_attentions, encoded_layers, pooled_output + outputs = [logits] + outputs + + return outputs class XLMSequenceSummary(nn.Module): - def __init__(self, config, summary_type="last", use_proj=True, - output_attentions=False, keep_multihead_output=False): + def __init__(self, config, summary_type="last", use_proj=True): super(XLMSequenceSummary, self).__init__() self.summary_type = summary_type if use_proj: @@ -1481,26 +1500,23 @@ class XLMForSequenceClassification(XLMPreTrainedModel): ``` """ def __init__(self, config, summary_type="last", use_proj=True, num_labels=2, - output_attentions=False, keep_multihead_output=False): + output_attentions=False, output_hidden_states=False): super(XLMForSequenceClassification, self).__init__(config) self.output_attentions = output_attentions - self.attn_type = config.attn_type - self.same_length = config.same_length + self.output_hidden_states = output_hidden_states + self.summary_type = summary_type self.num_labels = num_labels - self.transformer = XLMModel(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states) - self.sequence_summary = XLMSequenceSummary(config, summary_type=summary_type, - use_proj=use_proj, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + self.sequence_summary = XLMSequenceSummary(config, summary_type=summary_type, use_proj=use_proj) self.logits_proj = nn.Linear(config.d_model, num_labels) self.apply(self.init_weights) def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None, - labels=None, output_all_encoded_layers=True, head_mask=None): + labels=None, head_mask=None): """ Args: inp_k: int32 Tensor in shape [bsz, len], the input token IDs. @@ -1528,13 +1544,15 @@ class XLMForSequenceClassification(XLMPreTrainedModel): Only used during pretraining for two-stream attention. Set to None during finetuning. """ - output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, - mems, perm_mask, target_mapping, inp_q, - output_all_encoded_layers, head_mask) + transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, + mems, perm_mask, target_mapping, inp_q, head_mask) + output = transformer_outputs[0] output = self.sequence_summary(output) logits = self.logits_proj(output) + outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here + if labels is not None: if self.num_labels == 1: # We are doing regression @@ -1543,17 +1561,11 @@ class XLMForSequenceClassification(XLMPreTrainedModel): else: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss, new_mems + outputs = [loss] + outputs - # if self.output_attentions: - # all_attentions, encoded_layers = encoded_layers - # sequence_output = encoded_layers[-1] - # pooled_output = self.pooler(sequence_output) - # if not output_all_encoded_layers: - # encoded_layers = encoded_layers[-1] - # if self.output_attentions: - return logits, new_mems - # return all_attentions, encoded_layers, pooled_output + outputs = [logits] + outputs + + return outputs class XLMForQuestionAnswering(XLMPreTrainedModel): @@ -1612,27 +1624,30 @@ class XLMForQuestionAnswering(XLMPreTrainedModel): start_logits, end_logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False, output_hidden_states=False): super(XLMForQuestionAnswering, self).__init__(config) self.output_attentions = output_attentions - self.transformer = XLMModel(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + self.output_hidden_states = output_hidden_states + + self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states) self.qa_outputs = nn.Linear(config.hidden_size, 2) self.apply(self.init_weights) def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None, - start_positions=None, end_positions=None, - output_all_encoded_layers=True, head_mask=None): - output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, - mems, perm_mask, target_mapping, inp_q, - output_all_encoded_layers, head_mask) + start_positions=None, end_positions=None, head_mask=None): + transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, + mems, perm_mask, target_mapping, inp_q, head_mask) + + output = transformer_outputs[0] logits = self.qa_outputs(output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) + outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here + if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: @@ -1648,7 +1663,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel): start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 - return total_loss - elif self.output_attentions: - return all_attentions, start_logits, end_logits - return start_logits, end_logits + outputs = [total_loss] + outputs + + outputs = [start_logits, end_logits] + outputs + + return outputs diff --git a/pytorch_pretrained_bert/modeling_xlnet.py b/pytorch_pretrained_bert/modeling_xlnet.py index c30e263181..71e9f584dd 100644 --- a/pytorch_pretrained_bert/modeling_xlnet.py +++ b/pytorch_pretrained_bert/modeling_xlnet.py @@ -323,16 +323,13 @@ except ImportError: return self.weight * x + self.bias class XLNetRelativeAttention(nn.Module): - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False): super(XLNetRelativeAttention, self).__init__() self.output_attentions = output_attentions if config.d_model % config.n_head != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.d_model, config.n_head)) - self.output_attentions = output_attentions - self.keep_multihead_output = keep_multihead_output - self.multihead_output = None self.n_head = config.n_head self.d_head = config.d_head @@ -368,7 +365,7 @@ class XLNetRelativeAttention(nn.Module): return x - def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None): + def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None, head_mask=None): """Core relative positional attention operations.""" # content based attention score @@ -395,9 +392,16 @@ class XLNetRelativeAttention(nn.Module): attn_prob = F.softmax(attn_score, dim=1) attn_prob = self.dropout(attn_prob) + # Mask heads if we want to + if head_mask is not None: + attn_prob = attn_prob * head_mask + # attention output attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h) + if self.output_attentions: + return attn_vec, attn_prob + return attn_vec def post_attention(self, h, attn_vec, residual=True): @@ -439,7 +443,10 @@ class XLNetRelativeAttention(nn.Module): # core attention ops attn_vec_h = self.rel_attn_core( - q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h) + q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask) + + if self.output_attentions: + attn_vec_h, attn_prob_h = attn_vec_h # post processing output_h = self.post_attention(h, attn_vec_h) @@ -452,14 +459,25 @@ class XLNetRelativeAttention(nn.Module): if target_mapping is not None: q_head_g = torch.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping) attn_vec_g = self.rel_attn_core( - q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g) + q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask) + + if self.output_attentions: + attn_vec_g, attn_prob_g = attn_vec_g + attn_vec_g = torch.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping) else: attn_vec_g = self.rel_attn_core( - q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g) + q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask) + + if self.output_attentions: + attn_vec_g, attn_prob_g = attn_vec_g # post processing output_g = self.post_attention(g, attn_vec_g) + + if self.output_attentions: + attn_prob = attn_prob_h, attn_prob_g + else: ###### Multi-head attention with relative positional encoding if mems is not None and mems.dim() > 1: @@ -477,30 +495,18 @@ class XLNetRelativeAttention(nn.Module): # core attention ops attn_vec = self.rel_attn_core( - q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h) + q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask) + + if self.output_attentions: + attn_vec, attn_prob = attn_vec # post processing output_h = self.post_attention(h, attn_vec) output_g = None + if self.output_attentions: + return output_h, output_g, attn_prob - # Mask heads if we want to - # if head_mask is not None: - # attention_probs = attention_probs * head_mask - - # context_layer = torch.matmul(attention_probs, value_layer) - # if self.keep_multihead_output: - # self.multihead_output = context_layer - # self.multihead_output.retain_grad() - - # context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - # new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - # context_layer = context_layer.view(*new_context_layer_shape) - - # if self.output_attentions: - # attentions, self_output = self_output - # if self.output_attentions: - # return attentions, attention_output return output_h, output_g class XLNetFeedForward(nn.Module): @@ -510,7 +516,8 @@ class XLNetFeedForward(nn.Module): self.layer_1 = nn.Linear(config.d_model, config.d_inner) self.layer_2 = nn.Linear(config.d_inner, config.d_model) self.dropout = nn.Dropout(config.dropout) - if isinstance(config.ff_activation, str) or (sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)): + if isinstance(config.ff_activation, str) or \ + (sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)): self.activation_function = ACT2FN[config.ff_activation] else: self.activation_function = config.ff_activation @@ -526,29 +533,27 @@ class XLNetFeedForward(nn.Module): return output class XLNetLayer(nn.Module): - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False, ): super(XLNetLayer, self).__init__() self.output_attentions = output_attentions - self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions) self.ff = XLNetFeedForward(config) self.dropout = nn.Dropout(config.dropout) def forward(self, output_h, output_g, attn_mask_h, attn_mask_g, - r, seg_mat, - mems=None, target_mapping=None, head_mask=None): - output_h, output_g = self.rel_attn(output_h, output_g, - attn_mask_h, attn_mask_g, - r, seg_mat, - mems=mems, target_mapping=target_mapping, head_mask=head_mask) + r, seg_mat, mems=None, target_mapping=None, head_mask=None): + outputs = self.rel_attn(output_h, output_g, attn_mask_h, attn_mask_g, + r, seg_mat, mems=mems, target_mapping=target_mapping, + head_mask=head_mask) + output_h, output_g = outputs[:2] + if output_g is not None: output_g = self.ff(output_g) output_h = self.ff(output_h) - # if self.output_attentions: - # return attentions, layer_output - return output_h, output_g + outputs = [output_h, output_g] + outputs[2:] # Add again attentions if there are there + return outputs class XLNetPreTrainedModel(PreTrainedModel): @@ -584,9 +589,11 @@ class XLNetPreTrainedModel(PreTrainedModel): class XLNetModel(XLNetPreTrainedModel): - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False, output_hidden_states=False): super(XLNetModel, self).__init__(config) self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.mem_len = config.mem_len self.reuse_len = config.reuse_len self.d_model = config.d_model @@ -597,8 +604,7 @@ class XLNetModel(XLNetPreTrainedModel): self.word_embedding = nn.Embedding(config.n_token, config.d_model) self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model)) - layer = XLNetLayer(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + layer = XLNetLayer(config, output_attentions=output_attentions) self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)]) self.dropout = nn.Dropout(config.dropout) @@ -851,28 +857,39 @@ class XLNetModel(XLNetPreTrainedModel): if mems is None: mems = [None] * len(self.layer) + attentions = [] hidden_states = [] for i, layer_module in enumerate(self.layer): # cache new mems new_mems.append(self.cache_mem(output_h, mems[i])) + if self.output_hidden_states: + hidden_states.append((output_h, output_g) if output_g is not None else output_h) + + outputs = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask, + r=pos_emb, seg_mat=seg_mat, mems=mems[i], target_mapping=target_mapping, + head_mask=head_mask) + output_h, output_g = outputs[:2] + if self.output_attentions: + attentions.append(outputs[2:]) + + # Add last hidden state + if self.output_hidden_states: hidden_states.append((output_h, output_g) if output_g is not None else output_h) - output_h, output_g = layer_module(output_h, output_g, - attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask, - r=pos_emb, seg_mat=seg_mat, - mems=mems[i], target_mapping=target_mapping, - head_mask=head_mask) - hidden_states.append((output_h, output_g) if output_g is not None else output_h) output = self.dropout(output_g if output_g is not None else output_h) - # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) - output = output.permute(1, 0, 2).contiguous() - if output_g is not None: - hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs] - else: - hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states] + # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) + outputs = [output.permute(1, 0, 2).contiguous(), new_mems] + if self.output_hidden_states: + if output_g is not None: + hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs] + else: + hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states] + outputs.append(hidden_states) + if self.output_attentions: + outputs.append(attentions) - return output, hidden_states, new_mems + return outputs # outputs, new_mems, (hidden_states), (attentions) class XLNetLMHeadModel(XLNetPreTrainedModel): @@ -936,14 +953,16 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False, output_hidden_states=False): super(XLNetLMHeadModel, self).__init__(config) self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.attn_type = config.attn_type self.same_length = config.same_length self.transformer = XLNetModel(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + output_hidden_states=output_hidden_states) self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True) # Tie weights @@ -989,27 +1008,24 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): summary_type: str, "last", "first", "mean", or "attn". The method to pool the input to get a vector representation. """ - output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, - mems, perm_mask, target_mapping, inp_q, head_mask) + transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, + mems, perm_mask, target_mapping, inp_q, head_mask) - logits = self.lm_loss(output) + logits = self.lm_loss(transformer_outputs[0]) + + outputs = [logits] + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it if labels is not None: # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=-1) loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) - return loss, new_mems + outputs = [loss] + outputs - # if self.output_attentions: - # all_attentions, encoded_layers = encoded_layers - # if self.output_attentions: - return logits, new_mems - # return all_attentions, encoded_layers, pooled_output + return outputs # return (loss), logits, (mems), (hidden states), (attentions) class XLNetSequenceSummary(nn.Module): - def __init__(self, config, summary_type="last", use_proj=True, - output_attentions=False, keep_multihead_output=False): + def __init__(self, config, summary_type="last", use_proj=True): super(XLNetSequenceSummary, self).__init__() self.summary_type = summary_type if use_proj: @@ -1106,20 +1122,20 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ``` """ def __init__(self, config, summary_type="last", use_proj=True, num_labels=2, - output_attentions=False, keep_multihead_output=False): + output_attentions=False, output_hidden_states=False): super(XLNetForSequenceClassification, self).__init__(config) self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.attn_type = config.attn_type self.same_length = config.same_length self.summary_type = summary_type self.num_labels = num_labels self.transformer = XLNetModel(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + output_hidden_states=output_hidden_states) - self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type, - use_proj=use_proj, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type, use_proj=use_proj) self.logits_proj = nn.Linear(config.d_model, num_labels) self.apply(self.init_weights) @@ -1153,12 +1169,15 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): Only used during pretraining for two-stream attention. Set to None during finetuning. """ - output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, + transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, mems, perm_mask, target_mapping, inp_q, head_mask) + output = transformer_outputs[0] output = self.sequence_summary(output) logits = self.logits_proj(output) + outputs = [logits] + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it + if labels is not None: if self.num_labels == 1: # We are doing regression @@ -1167,13 +1186,10 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): else: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss, new_mems + outputs = [loss] + outputs + + return outputs # return (loss), logits, (mems), (hidden states), (attentions) - # if self.output_attentions: - # all_attentions, encoded_layers = encoded_layers - # if self.output_attentions: - return logits, new_mems - # return all_attentions, encoded_layers, pooled_output class XLNetForQuestionAnswering(XLNetPreTrainedModel): """XLNet model for Question Answering (span extraction). @@ -1231,25 +1247,30 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): start_logits, end_logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, output_attentions=False, keep_multihead_output=False): + def __init__(self, config, output_attentions=False, output_hidden_states=False): super(XLNetForQuestionAnswering, self).__init__(config) self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.transformer = XLNetModel(config, output_attentions=output_attentions, - keep_multihead_output=keep_multihead_output) + output_hidden_states=output_hidden_states) self.qa_outputs = nn.Linear(config.hidden_size, 2) self.apply(self.init_weights) def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None, start_positions=None, end_positions=None, head_mask=None): - output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, + transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, mems, perm_mask, target_mapping, inp_q, head_mask) - logits = self.qa_outputs(output) + logits = self.qa_outputs(transformer_outputs[0]) + start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) + outputs = [start_logits, end_logits] + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it + if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: @@ -1265,7 +1286,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 - return total_loss - elif self.output_attentions: - return all_attentions, start_logits, end_logits - return start_logits, end_logits + outputs = [total_loss] + outputs + + return outputs # return (loss), logits, (mems), (hidden states), (attentions)