Fixes Loss for TransfoXL when using Trainer API v2 (#16140)
* fix(transfo_xl): Fixes TransfoXL support when using Trainer. * fix(tests): Uses losses_1 and losses_2 pattern with TransfoXL test. * fix(transfo_xl): Adds requested changes to allow for backward compatibility. fix(transfo_xl): Adds requested changes to allow for backward compatibility. fix(transfo_xl): Fixes code styling. * Backward compatibility * Update src/transformers/models/transfo_xl/modeling_transfo_xl.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Gustavo de Rosa <gth.rosa@uol.com.br> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,7 @@
|
|||||||
PyTorch Transformer XL model. Adapted from https://github.com/kimiyoung/transformer-xl. In particular
|
PyTorch Transformer XL model. Adapted from https://github.com/kimiyoung/transformer-xl. In particular
|
||||||
https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
|
https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
|
||||||
"""
|
"""
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
@@ -692,6 +693,8 @@ class TransfoXLLMHeadModelOutput(ModelOutput):
|
|||||||
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
|
loss (`torch.FloatTensor` of shape `()`, *optional*, returned when `labels` is provided)
|
||||||
|
Reduced language modeling loss.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
losses: Optional[torch.FloatTensor] = None
|
losses: Optional[torch.FloatTensor] = None
|
||||||
@@ -699,6 +702,7 @@ class TransfoXLLMHeadModelOutput(ModelOutput):
|
|||||||
mems: List[torch.FloatTensor] = None
|
mems: List[torch.FloatTensor] = None
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
loss: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def logits(self):
|
def logits(self):
|
||||||
@@ -1011,6 +1015,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.transformer = TransfoXLModel(config)
|
self.transformer = TransfoXLModel(config)
|
||||||
self.sample_softmax = config.sample_softmax
|
self.sample_softmax = config.sample_softmax
|
||||||
|
self.trainer_compatible = getattr(config, "trainer_compatible", False)
|
||||||
|
|
||||||
|
if not self.trainer_compatible:
|
||||||
|
warnings.warn(
|
||||||
|
"The output of TransfoXL will be updated in v5 to support a single loss as first argument. In order"
|
||||||
|
"to use that updated output, please specify `trainer_compatible=True` as your configuration attribute.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
self.sample_softmax <= 0
|
self.sample_softmax <= 0
|
||||||
@@ -1095,17 +1107,38 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
last_hidden = transformer_outputs[0]
|
last_hidden = transformer_outputs[0]
|
||||||
pred_hid = last_hidden[:, -tgt_len:]
|
pred_hid = last_hidden[:, -tgt_len:]
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
# Prevents all labels being -100 and throwing an error
|
||||||
|
# when backwarding the loss
|
||||||
|
miss_valid_label = labels[0, 1:].sum() == (labels.size(1) - 1) * -100
|
||||||
|
if miss_valid_label:
|
||||||
|
# Sets an <EOS> token, just to prevent loss from being NaN
|
||||||
|
labels[0, 1] = self.config.eos_token_id
|
||||||
|
|
||||||
softmax_output = self.crit(pred_hid, labels)
|
softmax_output = self.crit(pred_hid, labels)
|
||||||
prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else ()
|
prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else ()
|
||||||
loss = softmax_output.view(bsz, tgt_len - 1) if labels is not None else None
|
|
||||||
|
if labels is not None:
|
||||||
|
losses = softmax_output.view(bsz, tgt_len - 1)
|
||||||
|
# Avoids from incorporating padding (-100) tokens into loss value
|
||||||
|
loss = losses[losses != 0].mean()
|
||||||
|
else:
|
||||||
|
losses, loss = None, None
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (prediction_scores,) + transformer_outputs[1:]
|
if self.trainer_compatible:
|
||||||
|
output = (prediction_scores, losses) if losses is not None else (prediction_scores,)
|
||||||
|
output += transformer_outputs[1:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
else:
|
||||||
|
output = (prediction_scores, *transformer_outputs[1:])
|
||||||
|
output = ((losses,) + output) if losses is not None else output
|
||||||
|
return (output + (loss,)) if loss is not None else output
|
||||||
|
|
||||||
return TransfoXLLMHeadModelOutput(
|
return TransfoXLLMHeadModelOutput(
|
||||||
losses=loss,
|
loss=loss,
|
||||||
prediction_scores=prediction_scores,
|
prediction_scores=prediction_scores,
|
||||||
|
losses=losses,
|
||||||
mems=transformer_outputs.mems,
|
mems=transformer_outputs.mems,
|
||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
|
|||||||
@@ -132,30 +132,90 @@ class TransfoXLModelTester:
|
|||||||
outputs2 = model(input_ids_2, labels=lm_labels, mems=outputs1["mems"])
|
outputs2 = model(input_ids_2, labels=lm_labels, mems=outputs1["mems"])
|
||||||
|
|
||||||
outputs = {
|
outputs = {
|
||||||
"loss_1": outputs1["losses"],
|
"loss_1": outputs1["loss"],
|
||||||
|
"losses_1": outputs1["losses"],
|
||||||
"mems_1": outputs1["mems"],
|
"mems_1": outputs1["mems"],
|
||||||
"lm_logits_1": lm_logits_1,
|
"lm_logits_1": lm_logits_1,
|
||||||
"loss_2": outputs2["losses"],
|
"loss_2": outputs2["loss"],
|
||||||
|
"losses_2": outputs2["losses"],
|
||||||
"mems_2": outputs2["mems"],
|
"mems_2": outputs2["mems"],
|
||||||
"lm_logits_2": lm_logits_2,
|
"lm_logits_2": lm_logits_2,
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def check_transfo_xl_lm_head_output(self, result):
|
def check_transfo_xl_lm_head_output(self, result):
|
||||||
self.parent.assertEqual(result["loss_1"].shape, (self.batch_size, self.seq_length - 1))
|
self.parent.assertEqual(result["loss_1"].shape, ())
|
||||||
|
self.parent.assertEqual(result["losses_1"].shape, (self.batch_size, self.seq_length - 1))
|
||||||
self.parent.assertEqual(result["lm_logits_1"].shape, (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(result["lm_logits_1"].shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
[mem.shape for mem in result["mems_1"]],
|
[mem.shape for mem in result["mems_1"]],
|
||||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.parent.assertEqual(result["loss_2"].shape, (self.batch_size, self.seq_length - 1))
|
self.parent.assertEqual(result["loss_2"].shape, ())
|
||||||
|
self.parent.assertEqual(result["losses_2"].shape, (self.batch_size, self.seq_length - 1))
|
||||||
self.parent.assertEqual(result["lm_logits_2"].shape, (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(result["lm_logits_2"].shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
[mem.shape for mem in result["mems_2"]],
|
[mem.shape for mem in result["mems_2"]],
|
||||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_transfo_xl_lm_head_trainer_compatible_tuple(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||||
|
config.trainer_compatible = True
|
||||||
|
model = TransfoXLLMHeadModel(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
lm_logits_1 = model(input_ids_1, return_dict=False)[0]
|
||||||
|
outputs1 = model(input_ids_1, labels=lm_labels, return_dict=False)
|
||||||
|
loss_1, _, losses_1, mems_1 = outputs1[:4]
|
||||||
|
lm_logits_2 = model(input_ids_2, mems=mems_1, return_dict=False)[0]
|
||||||
|
outputs2 = model(input_ids_2, labels=lm_labels, mems=mems_1, return_dict=False)
|
||||||
|
loss_2, _, losses_2, mems_2 = outputs2[:4]
|
||||||
|
|
||||||
|
outputs = {
|
||||||
|
"losses_1": losses_1,
|
||||||
|
"mems_1": mems_1,
|
||||||
|
"lm_logits_1": lm_logits_1,
|
||||||
|
"loss_1": loss_1,
|
||||||
|
"losses_2": losses_2,
|
||||||
|
"mems_2": mems_2,
|
||||||
|
"lm_logits_2": lm_logits_2,
|
||||||
|
"loss_2": loss_2,
|
||||||
|
}
|
||||||
|
|
||||||
|
config.trainer_compatible = None
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def create_transfo_xl_lm_head_trainer_incompatible_tuple(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||||
|
config.trainer_compatible = False
|
||||||
|
model = TransfoXLLMHeadModel(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
lm_logits_1 = model(input_ids_1, return_dict=False)[0]
|
||||||
|
outputs1 = model(input_ids_1, labels=lm_labels, return_dict=False)
|
||||||
|
losses_1, _, mems_1 = outputs1[:3]
|
||||||
|
loss_1 = outputs1[-1]
|
||||||
|
lm_logits_2 = model(input_ids_2, mems=mems_1, return_dict=False)[0]
|
||||||
|
outputs2 = model(input_ids_2, labels=lm_labels, mems=mems_1)
|
||||||
|
losses_2, _, mems_2 = outputs2[:3]
|
||||||
|
loss_2 = outputs2[-1]
|
||||||
|
|
||||||
|
outputs = {
|
||||||
|
"losses_1": losses_1,
|
||||||
|
"mems_1": mems_1,
|
||||||
|
"lm_logits_1": lm_logits_1,
|
||||||
|
"loss_1": loss_1,
|
||||||
|
"losses_2": losses_2,
|
||||||
|
"mems_2": mems_2,
|
||||||
|
"lm_logits_2": lm_logits_2,
|
||||||
|
"loss_2": loss_2,
|
||||||
|
}
|
||||||
|
|
||||||
|
config.trainer_compatible = None
|
||||||
|
return outputs
|
||||||
|
|
||||||
def create_and_check_transfo_xl_for_sequence_classification(self, config, input_ids_1, input_ids_2, lm_labels):
|
def create_and_check_transfo_xl_for_sequence_classification(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
model = TransfoXLForSequenceClassification(config)
|
model = TransfoXLForSequenceClassification(config)
|
||||||
@@ -220,9 +280,16 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
|||||||
def test_transfo_xl_lm_head(self):
|
def test_transfo_xl_lm_head(self):
|
||||||
self.model_tester.set_seed()
|
self.model_tester.set_seed()
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
|
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
|
||||||
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||||
|
|
||||||
|
output_result = self.model_tester.create_transfo_xl_lm_head_trainer_compatible_tuple(*config_and_inputs)
|
||||||
|
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||||
|
|
||||||
|
output_result = self.model_tester.create_transfo_xl_lm_head_trainer_incompatible_tuple(*config_and_inputs)
|
||||||
|
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||||
|
|
||||||
def test_transfo_xl_sequence_classification_model(self):
|
def test_transfo_xl_sequence_classification_model(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_transfo_xl_for_sequence_classification(*config_and_inputs)
|
self.model_tester.create_and_check_transfo_xl_for_sequence_classification(*config_and_inputs)
|
||||||
@@ -232,10 +299,8 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
|||||||
return
|
return
|
||||||
|
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
@unittest.skip(
|
|
||||||
reason="Transfo-XL does not work with data parallel (DP) because of a bug in PyTorch: https://github.com/pytorch/pytorch/issues/36035"
|
|
||||||
)
|
|
||||||
def test_multi_gpu_data_parallel_forward(self):
|
def test_multi_gpu_data_parallel_forward(self):
|
||||||
|
# Opt-out of this test.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
Reference in New Issue
Block a user