From 0f2f0c634ff3d9e69212ca6581d043c945f56fad Mon Sep 17 00:00:00 2001 From: Victor SANH Date: Wed, 10 Jan 2024 02:33:33 -0500 Subject: [PATCH] Fix `_merge_input_ids_with_image_features` for llava model (#28333) * fix `_merge_input_ids_with_image_features` for llava model * Update src/transformers/models/llava/modeling_llava.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * adress comments * style and tests * ooops * test the backward too * Apply suggestions from code review Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update tests/models/vipllava/test_modeling_vipllava.py * style and quality --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- .../models/llava/modeling_llava.py | 20 +++++++--- .../models/vipllava/modeling_vipllava.py | 20 +++++++--- tests/models/llava/test_modeling_llava.py | 40 ++++++++++++++++++- .../models/vipllava/test_modeling_vipllava.py | 40 ++++++++++++++++++- 4 files changed, 106 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index e8f925938a..bd205e0fc9 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -276,9 +276,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): self.vocab_size = model_embeds.num_embeddings return model_embeds - def _merge_input_ids_with_image_features( - self, image_features, inputs_embeds, input_ids, attention_mask, position_ids - ): + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): num_images, num_image_patches, embed_dim = image_features.shape batch_size, sequence_length = input_ids.shape left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) @@ -307,6 +305,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): final_attention_mask = torch.zeros( batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device ) + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device + ) # In case the Vision model or the Language model has been offloaded to CPU, we need to manually # set the corresponding tensors into their correct target device. target_device = inputs_embeds.device @@ -321,6 +323,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + if labels is not None: + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling image_to_overwrite = torch.all(final_embedding == 0, dim=-1) @@ -335,7 +339,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) final_attention_mask |= image_to_overwrite position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - return final_embedding, final_attention_mask, position_ids + + if labels is None: + final_labels = None + + return final_embedding, final_attention_mask, final_labels, position_ids @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -420,8 +428,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): ) image_features = self.multi_modal_projector(selected_image_feature) - inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, position_ids + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels ) if labels is None: labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index c1aa948580..748c64b22e 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -284,9 +284,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): self.vocab_size = model_embeds.num_embeddings return model_embeds - def _merge_input_ids_with_image_features( - self, image_features, inputs_embeds, input_ids, attention_mask, position_ids - ): + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): num_images, num_image_patches, embed_dim = image_features.shape batch_size, sequence_length = input_ids.shape left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) @@ -315,6 +313,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): final_attention_mask = torch.zeros( batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device ) + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device + ) # In case the Vision model or the Language model has been offloaded to CPU, we need to manually # set the corresponding tensors into their correct target device. target_device = inputs_embeds.device @@ -329,6 +331,8 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + if labels is not None: + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling image_to_overwrite = torch.all(final_embedding == 0, dim=-1) @@ -343,7 +347,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) final_attention_mask |= image_to_overwrite position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - return final_embedding, final_attention_mask, position_ids + + if labels is None: + final_labels = None + + return final_embedding, final_attention_mask, final_labels, position_ids @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -419,8 +427,8 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): image_features = torch.cat(image_features, dim=-1) image_features = self.multi_modal_projector(image_features) - inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, position_ids + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels ) if labels is None: labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long) diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 3037e972a3..2ece22f12a 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -26,7 +26,7 @@ from transformers import ( is_torch_available, is_vision_available, ) -from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device +from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -332,3 +332,41 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase): # Make sure that `generate` works _ = model.generate(**inputs, max_new_tokens=20) + + @slow + @require_torch_gpu + def test_llava_merge_inputs_error_bug(self): + # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore + model_id = "llava-hf/llava-1.5-7b-hf" + model = LlavaForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ).to(torch_device) + + # Simulate some user inputs + pixel_values = torch.randn( + (2, 3, 336, 336), + dtype=torch.float, + device=torch_device, + ) + input_ids = torch.tensor( + [ + [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900], + [1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900], + ], + dtype=torch.long, + device=torch_device, + ) + attention_mask = torch.tensor( + [[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]], + dtype=torch.long, + device=torch_device, + ) + + # Make sure that the loss is properly computed + loss = model( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + labels=input_ids, + ).loss + loss.backward() diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index e09527343e..ff84f71784 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -26,7 +26,7 @@ from transformers import ( is_torch_available, is_vision_available, ) -from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device +from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -214,3 +214,41 @@ class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase): EXPECTED_OUTPUT = "USER: \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on" self.assertEqual(processor.decode(outputs[0], skip_special_tokens=True), EXPECTED_OUTPUT) + + @slow + @require_torch_gpu + def test_vipllava_merge_inputs_error_bug(self): + # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore + model_id = "llava-hf/vip-llava-7b-hf" + model = VipLlavaForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ).to(torch_device) + + # Simulate some user inputs + pixel_values = torch.randn( + (2, 3, 336, 336), + dtype=torch.float, + device=torch_device, + ) + input_ids = torch.tensor( + [ + [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900], + [1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900], + ], + dtype=torch.long, + device=torch_device, + ) + attention_mask = torch.tensor( + [[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]], + dtype=torch.long, + device=torch_device, + ) + + # Make sure that the loss is properly computed + loss = model( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + labels=input_ids, + ).loss + loss.backward()