From f380bf2b612e6030ef8bc8904b287d274f035e29 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Sat, 29 Jan 2022 16:08:35 +0100 Subject: [PATCH] Fix the inconsistency of loss calculation between PT/TF XLNetLMHeadModel (#15298) * Fix the inconsistency of loss calculation between PT/TF XLNetLMHeadModel * overwrite test_loss_computation Co-authored-by: ydshieh --- .../models/xlnet/modeling_tf_xlnet.py | 5 +- tests/test_modeling_tf_xlnet.py | 64 +++++++++++++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py index 0e3b9a8b97..d680427d7a 100644 --- a/src/transformers/models/xlnet/modeling_tf_xlnet.py +++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py @@ -1390,10 +1390,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): loss = None if inputs["labels"] is not None: - # shift labels to the left and cut last logit token - logits = logits[:, :-1] - labels = inputs["labels"][:, 1:] - loss = self.hf_compute_loss(labels, logits) + loss = self.hf_compute_loss(inputs["labels"], logits) if not inputs["return_dict"]: output = (logits,) + transformer_outputs[1:] diff --git a/tests/test_modeling_tf_xlnet.py b/tests/test_modeling_tf_xlnet.py index 51fba4575f..1455b1ee13 100644 --- a/tests/test_modeling_tf_xlnet.py +++ b/tests/test_modeling_tf_xlnet.py @@ -14,6 +14,7 @@ # limitations under the License. +import inspect import random import unittest @@ -391,6 +392,69 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase): model = TFXLNetModel.from_pretrained(model_name) self.assertIsNotNone(model) + # overwrite since `TFXLNetLMHeadModel` doesn't cut logits/labels + def test_loss_computation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config) + if getattr(model, "hf_compute_loss", None): + # The number of elements in the loss should be the same as the number of elements in the label + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + added_label = prepared_for_class[ + sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0] + ] + loss_size = tf.size(added_label) + + # `TFXLNetLMHeadModel` doesn't cut logits/labels + # if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING): + # # if loss is causal lm loss, labels are shift, so that one label per batch + # # is cut + # loss_size = loss_size - self.model_tester.batch_size + + # Test that model correctly compute the loss with kwargs + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + input_name = "input_ids" if "input_ids" in prepared_for_class else "pixel_values" + input_ids = prepared_for_class.pop(input_name) + + loss = model(input_ids, **prepared_for_class)[0] + self.assertEqual(loss.shape, [loss_size]) + + # Test that model correctly compute the loss with a dict + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + loss = model(prepared_for_class)[0] + self.assertEqual(loss.shape, [loss_size]) + + # Test that model correctly compute the loss with a tuple + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + + # Get keys that were added with the _prepare_for_class function + label_keys = prepared_for_class.keys() - inputs_dict.keys() + signature = inspect.signature(model.call).parameters + signature_names = list(signature.keys()) + + # Create a dictionary holding the location of the tensors in the tuple + tuple_index_mapping = {0: input_name} + for label_key in label_keys: + label_key_index = signature_names.index(label_key) + tuple_index_mapping[label_key_index] = label_key + sorted_tuple_index_mapping = sorted(tuple_index_mapping.items()) + # Initialize a list with their default values, update the values and convert to a tuple + list_input = [] + + for name in signature_names: + if name != "kwargs": + list_input.append(signature[name].default) + + for index, value in sorted_tuple_index_mapping: + list_input[index] = prepared_for_class[value] + + tuple_input = tuple(list_input) + + # Send to model + loss = model(tuple_input[:-1])[0] + + self.assertEqual(loss.shape, [loss_size]) + @require_tf class TFXLNetModelLanguageGenerationTest(unittest.TestCase):