From 73f0a5d1f66b15080cc9976266eb87e7cf9ebe0d Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 17 Mar 2022 10:49:24 +0100 Subject: [PATCH] Fixes Loss for TransfoXL when using Trainer API v2 (#16140) * fix(transfo_xl): Fixes TransfoXL support when using Trainer. * fix(tests): Uses losses_1 and losses_2 pattern with TransfoXL test. * fix(transfo_xl): Adds requested changes to allow for backward compatibility. fix(transfo_xl): Adds requested changes to allow for backward compatibility. fix(transfo_xl): Fixes code styling. * Backward compatibility * Update src/transformers/models/transfo_xl/modeling_transfo_xl.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Gustavo de Rosa Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../models/transfo_xl/modeling_transfo_xl.py | 41 +++++++++- tests/transfo_xl/test_modeling_transfo_xl.py | 79 +++++++++++++++++-- 2 files changed, 109 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py index 6ba9903b9c..04fcf41b84 100644 --- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py @@ -17,6 +17,7 @@ PyTorch Transformer XL model. Adapted from https://github.com/kimiyoung/transformer-xl. In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py """ +import warnings from dataclasses import dataclass from typing import List, Optional, Tuple @@ -692,6 +693,8 @@ class TransfoXLLMHeadModelOutput(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + loss (`torch.FloatTensor` of shape `()`, *optional*, returned when `labels` is provided) + Reduced language modeling loss. 
""" losses: Optional[torch.FloatTensor] = None @@ -699,6 +702,7 @@ class TransfoXLLMHeadModelOutput(ModelOutput): mems: List[torch.FloatTensor] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None + loss: Optional[torch.FloatTensor] = None @property def logits(self): @@ -1011,6 +1015,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): super().__init__(config) self.transformer = TransfoXLModel(config) self.sample_softmax = config.sample_softmax + self.trainer_compatible = getattr(config, "trainer_compatible", False) + + if not self.trainer_compatible: + warnings.warn( + "The output of TransfoXL will be updated in v5 to support a single loss as first argument. In order" + "to use that updated output, please specify `trainer_compatible=True` as your configuration attribute.", + DeprecationWarning, + ) assert ( self.sample_softmax <= 0 @@ -1095,17 +1107,38 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): last_hidden = transformer_outputs[0] pred_hid = last_hidden[:, -tgt_len:] + if labels is not None: + # Prevents all labels being -100 and throwing an error + # when backwarding the loss + miss_valid_label = labels[0, 1:].sum() == (labels.size(1) - 1) * -100 + if miss_valid_label: + # Sets an <EOS> token, just to prevent loss from being NaN + labels[0, 1] = self.config.eos_token_id + softmax_output = self.crit(pred_hid, labels) prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else () - loss = softmax_output.view(bsz, tgt_len - 1) if labels is not None else None + + if labels is not None: + losses = softmax_output.view(bsz, tgt_len - 1) + # Avoids from incorporating padding (-100) tokens into loss value + loss = losses[losses != 0].mean() + else: + losses, loss = None, None if not return_dict: - output = (prediction_scores,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output + if self.trainer_compatible: + output = (prediction_scores, 
losses) if losses is not None else (prediction_scores,) + output += transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + else: + output = (prediction_scores, *transformer_outputs[1:]) + output = ((losses,) + output) if losses is not None else output + return (output + (loss,)) if loss is not None else output return TransfoXLLMHeadModelOutput( - losses=loss, + loss=loss, prediction_scores=prediction_scores, + losses=losses, mems=transformer_outputs.mems, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, diff --git a/tests/transfo_xl/test_modeling_transfo_xl.py b/tests/transfo_xl/test_modeling_transfo_xl.py index 51597f2338..12098c5185 100644 --- a/tests/transfo_xl/test_modeling_transfo_xl.py +++ b/tests/transfo_xl/test_modeling_transfo_xl.py @@ -132,30 +132,90 @@ class TransfoXLModelTester: outputs2 = model(input_ids_2, labels=lm_labels, mems=outputs1["mems"]) outputs = { - "loss_1": outputs1["losses"], + "loss_1": outputs1["loss"], + "losses_1": outputs1["losses"], "mems_1": outputs1["mems"], "lm_logits_1": lm_logits_1, - "loss_2": outputs2["losses"], + "loss_2": outputs2["loss"], + "losses_2": outputs2["losses"], "mems_2": outputs2["mems"], "lm_logits_2": lm_logits_2, } return outputs def check_transfo_xl_lm_head_output(self, result): - self.parent.assertEqual(result["loss_1"].shape, (self.batch_size, self.seq_length - 1)) + self.parent.assertEqual(result["loss_1"].shape, ()) + self.parent.assertEqual(result["losses_1"].shape, (self.batch_size, self.seq_length - 1)) self.parent.assertEqual(result["lm_logits_1"].shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertListEqual( [mem.shape for mem in result["mems_1"]], [(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers, ) - self.parent.assertEqual(result["loss_2"].shape, (self.batch_size, self.seq_length - 1)) + self.parent.assertEqual(result["loss_2"].shape, ()) + 
self.parent.assertEqual(result["losses_2"].shape, (self.batch_size, self.seq_length - 1)) self.parent.assertEqual(result["lm_logits_2"].shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertListEqual( [mem.shape for mem in result["mems_2"]], [(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers, ) + def create_transfo_xl_lm_head_trainer_compatible_tuple(self, config, input_ids_1, input_ids_2, lm_labels): + config.trainer_compatible = True + model = TransfoXLLMHeadModel(config) + model.to(torch_device) + model.eval() + + lm_logits_1 = model(input_ids_1, return_dict=False)[0] + outputs1 = model(input_ids_1, labels=lm_labels, return_dict=False) + loss_1, _, losses_1, mems_1 = outputs1[:4] + lm_logits_2 = model(input_ids_2, mems=mems_1, return_dict=False)[0] + outputs2 = model(input_ids_2, labels=lm_labels, mems=mems_1, return_dict=False) + loss_2, _, losses_2, mems_2 = outputs2[:4] + + outputs = { + "losses_1": losses_1, + "mems_1": mems_1, + "lm_logits_1": lm_logits_1, + "loss_1": loss_1, + "losses_2": losses_2, + "mems_2": mems_2, + "lm_logits_2": lm_logits_2, + "loss_2": loss_2, + } + + config.trainer_compatible = None + return outputs + + def create_transfo_xl_lm_head_trainer_incompatible_tuple(self, config, input_ids_1, input_ids_2, lm_labels): + config.trainer_compatible = False + model = TransfoXLLMHeadModel(config) + model.to(torch_device) + model.eval() + + lm_logits_1 = model(input_ids_1, return_dict=False)[0] + outputs1 = model(input_ids_1, labels=lm_labels, return_dict=False) + losses_1, _, mems_1 = outputs1[:3] + loss_1 = outputs1[-1] + lm_logits_2 = model(input_ids_2, mems=mems_1, return_dict=False)[0] + outputs2 = model(input_ids_2, labels=lm_labels, mems=mems_1) + losses_2, _, mems_2 = outputs2[:3] + loss_2 = outputs2[-1] + + outputs = { + "losses_1": losses_1, + "mems_1": mems_1, + "lm_logits_1": lm_logits_1, + "loss_1": loss_1, + "losses_2": losses_2, + "mems_2": mems_2, + "lm_logits_2": lm_logits_2, + 
"loss_2": loss_2, + } + + config.trainer_compatible = None + return outputs + def create_and_check_transfo_xl_for_sequence_classification(self, config, input_ids_1, input_ids_2, lm_labels): config.num_labels = self.num_labels model = TransfoXLForSequenceClassification(config) @@ -220,9 +280,16 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC def test_transfo_xl_lm_head(self): self.model_tester.set_seed() config_and_inputs = self.model_tester.prepare_config_and_inputs() + output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs) self.model_tester.check_transfo_xl_lm_head_output(output_result) + output_result = self.model_tester.create_transfo_xl_lm_head_trainer_compatible_tuple(*config_and_inputs) + self.model_tester.check_transfo_xl_lm_head_output(output_result) + + output_result = self.model_tester.create_transfo_xl_lm_head_trainer_incompatible_tuple(*config_and_inputs) + self.model_tester.check_transfo_xl_lm_head_output(output_result) + def test_transfo_xl_sequence_classification_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_transfo_xl_for_sequence_classification(*config_and_inputs) @@ -232,10 +299,8 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC return @require_torch_multi_gpu - @unittest.skip( - reason="Transfo-XL does not work with data parallel (DP) because of a bug in PyTorch: https://github.com/pytorch/pytorch/issues/36035" - ) def test_multi_gpu_data_parallel_forward(self): + # Opt-out of this test. pass @slow