Fixes Loss for TransfoXL when using Trainer API v2 (#16140)
* fix(transfo_xl): Fixes TransfoXL support when using Trainer. * fix(tests): Uses losses_1 and losses_2 pattern with TransfoXL test. * fix(transfo_xl): Adds requested changes to allow for backward compatibility. fix(transfo_xl): Adds requested changes to allow for backward compatibility. fix(transfo_xl): Fixes code styling. * Backward compatibility * Update src/transformers/models/transfo_xl/modeling_transfo_xl.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Gustavo de Rosa <gth.rosa@uol.com.br> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -132,30 +132,90 @@ class TransfoXLModelTester:
|
||||
outputs2 = model(input_ids_2, labels=lm_labels, mems=outputs1["mems"])
|
||||
|
||||
outputs = {
|
||||
"loss_1": outputs1["losses"],
|
||||
"loss_1": outputs1["loss"],
|
||||
"losses_1": outputs1["losses"],
|
||||
"mems_1": outputs1["mems"],
|
||||
"lm_logits_1": lm_logits_1,
|
||||
"loss_2": outputs2["losses"],
|
||||
"loss_2": outputs2["loss"],
|
||||
"losses_2": outputs2["losses"],
|
||||
"mems_2": outputs2["mems"],
|
||||
"lm_logits_2": lm_logits_2,
|
||||
}
|
||||
return outputs
|
||||
|
||||
def check_transfo_xl_lm_head_output(self, result):
|
||||
self.parent.assertEqual(result["loss_1"].shape, (self.batch_size, self.seq_length - 1))
|
||||
self.parent.assertEqual(result["loss_1"].shape, ())
|
||||
self.parent.assertEqual(result["losses_1"].shape, (self.batch_size, self.seq_length - 1))
|
||||
self.parent.assertEqual(result["lm_logits_1"].shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
[mem.shape for mem in result["mems_1"]],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
self.parent.assertEqual(result["loss_2"].shape, (self.batch_size, self.seq_length - 1))
|
||||
self.parent.assertEqual(result["loss_2"].shape, ())
|
||||
self.parent.assertEqual(result["losses_2"].shape, (self.batch_size, self.seq_length - 1))
|
||||
self.parent.assertEqual(result["lm_logits_2"].shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
[mem.shape for mem in result["mems_2"]],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_transfo_xl_lm_head_trainer_compatible_tuple(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
config.trainer_compatible = True
|
||||
model = TransfoXLLMHeadModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
lm_logits_1 = model(input_ids_1, return_dict=False)[0]
|
||||
outputs1 = model(input_ids_1, labels=lm_labels, return_dict=False)
|
||||
loss_1, _, losses_1, mems_1 = outputs1[:4]
|
||||
lm_logits_2 = model(input_ids_2, mems=mems_1, return_dict=False)[0]
|
||||
outputs2 = model(input_ids_2, labels=lm_labels, mems=mems_1, return_dict=False)
|
||||
loss_2, _, losses_2, mems_2 = outputs2[:4]
|
||||
|
||||
outputs = {
|
||||
"losses_1": losses_1,
|
||||
"mems_1": mems_1,
|
||||
"lm_logits_1": lm_logits_1,
|
||||
"loss_1": loss_1,
|
||||
"losses_2": losses_2,
|
||||
"mems_2": mems_2,
|
||||
"lm_logits_2": lm_logits_2,
|
||||
"loss_2": loss_2,
|
||||
}
|
||||
|
||||
config.trainer_compatible = None
|
||||
return outputs
|
||||
|
||||
def create_transfo_xl_lm_head_trainer_incompatible_tuple(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
config.trainer_compatible = False
|
||||
model = TransfoXLLMHeadModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
lm_logits_1 = model(input_ids_1, return_dict=False)[0]
|
||||
outputs1 = model(input_ids_1, labels=lm_labels, return_dict=False)
|
||||
losses_1, _, mems_1 = outputs1[:3]
|
||||
loss_1 = outputs1[-1]
|
||||
lm_logits_2 = model(input_ids_2, mems=mems_1, return_dict=False)[0]
|
||||
outputs2 = model(input_ids_2, labels=lm_labels, mems=mems_1)
|
||||
losses_2, _, mems_2 = outputs2[:3]
|
||||
loss_2 = outputs2[-1]
|
||||
|
||||
outputs = {
|
||||
"losses_1": losses_1,
|
||||
"mems_1": mems_1,
|
||||
"lm_logits_1": lm_logits_1,
|
||||
"loss_1": loss_1,
|
||||
"losses_2": losses_2,
|
||||
"mems_2": mems_2,
|
||||
"lm_logits_2": lm_logits_2,
|
||||
"loss_2": loss_2,
|
||||
}
|
||||
|
||||
config.trainer_compatible = None
|
||||
return outputs
|
||||
|
||||
def create_and_check_transfo_xl_for_sequence_classification(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = TransfoXLForSequenceClassification(config)
|
||||
@@ -220,9 +280,16 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
||||
def test_transfo_xl_lm_head(self):
|
||||
self.model_tester.set_seed()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
|
||||
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||
|
||||
output_result = self.model_tester.create_transfo_xl_lm_head_trainer_compatible_tuple(*config_and_inputs)
|
||||
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||
|
||||
output_result = self.model_tester.create_transfo_xl_lm_head_trainer_incompatible_tuple(*config_and_inputs)
|
||||
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||
|
||||
def test_transfo_xl_sequence_classification_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_transfo_xl_for_sequence_classification(*config_and_inputs)
|
||||
@@ -232,10 +299,8 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
||||
return
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@unittest.skip(
|
||||
reason="Transfo-XL does not work with data parallel (DP) because of a bug in PyTorch: https://github.com/pytorch/pytorch/issues/36035"
|
||||
)
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
# Opt-out of this test.
|
||||
pass
|
||||
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user