From 90e2df5d5544e3624ce711a28717204b7779c2d7 Mon Sep 17 00:00:00 2001 From: Sukriti Sharma Date: Fri, 21 Mar 2025 07:47:59 -0600 Subject: [PATCH] fix: loss computation after embeddings resize - mllama (#36840) * move loss to generation class Signed-off-by: Sukriti-Sharma4 * code cleanup Signed-off-by: Sukriti-Sharma4 * test for resize and loss computation Signed-off-by: Sukriti-Sharma4 * fix tests Signed-off-by: Sukriti-Sharma4 * fix:test for resize and loss Signed-off-by: Sukriti-Sharma4 * fix resize embedding mllama test Signed-off-by: Sukriti-Sharma4 * review changes Signed-off-by: Sukriti-Sharma4 --------- Signed-off-by: Sukriti-Sharma4 --- .../models/mllama/modeling_mllama.py | 21 +++++++++++++++++-- tests/models/mllama/test_modeling_mllama.py | 18 ++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 1981f4287b..0818d90a57 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -2056,6 +2056,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -2157,15 +2158,31 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): past_key_values=past_key_values, use_cache=use_cache, inputs_embeds=inputs_embeds, - labels=labels, output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, + **loss_kwargs, ) - return outputs + # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized + loss = None + logits = outputs[0] + + if labels is not None: + loss = self.loss_function(logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs) + + if not return_dict: + return (loss,) + outputs if loss is not None else outputs + + return CausalLMOutputWithPast( + loss=loss, + logits=outputs.logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) def prepare_inputs_for_generation( self, diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py index 023dd8ea2b..ae28bd6697 100644 --- a/tests/models/mllama/test_modeling_mllama.py +++ b/tests/models/mllama/test_modeling_mllama.py @@ -321,6 +321,24 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] torch.testing.assert_close(out_embeds, out_ids) + def test_resize_embeddings_results_in_successful_loss(self): + # resizing embeddings should result in successful loss computation + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model_vocab_size = config.get_text_config().vocab_size + inputs = self._prepare_for_class(inputs, model_class, return_labels=True) + # Resize embeddings and call forward + model.resize_token_embeddings(model_vocab_size + 10) + output = model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + labels=inputs["labels"], + return_dict=True, + ) + self.assertTrue("loss" in output) + def _check_attentions_for_generate( self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values ):