From e60491adc945f112d2fd410652f43d7149f0122c Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 25 Apr 2024 20:28:51 +0500 Subject: [PATCH] Fix Llava for 0-embeddings (#30473) --- .../models/llava/modeling_llava.py | 7 +++-- .../models/llava_next/modeling_llava_next.py | 7 +++-- .../models/vipllava/modeling_vipllava.py | 7 +++-- .../llava_next/test_modeling_llava_next.py | 26 +++++++++++++++++++ 4 files changed, 41 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 4cf5d98f77..5a6d49752d 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -327,8 +327,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): if labels is not None: final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling - image_to_overwrite = torch.all(final_embedding == 0, dim=-1) + # 5. Fill the embeddings corresponding to the images. 
Anything that is not `text_positions` needs filling (#29835) + image_to_overwrite = torch.full( + (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device + ) + image_to_overwrite[batch_indices, text_to_overwrite] = False image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) if image_to_overwrite.sum() != image_features.shape[:-1].numel(): diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 155d9e3e6a..085999933b 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -403,8 +403,11 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel): if labels is not None: final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling - image_to_overwrite = torch.all(final_embedding == 0, dim=-1) + # 5. Fill the embeddings corresponding to the images. 
Anything that is not `text_positions` needs filling (#29835) + image_to_overwrite = torch.full( + (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device + ) + image_to_overwrite[batch_indices, text_to_overwrite] = False image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) if image_to_overwrite.sum() != image_features.shape[:-1].numel(): diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 1b20353410..aaffc19bd5 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -331,8 +331,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): if labels is not None: final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling - image_to_overwrite = torch.all(final_embedding == 0, dim=-1) + # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) + image_to_overwrite = torch.full( + (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device + ) + image_to_overwrite[batch_indices, text_to_overwrite] = False image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) if image_to_overwrite.sum() != image_features.shape[:-1].numel(): diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index 1c7e320090..3656bb6505 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -459,3 +459,29 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase): EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? 
[/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_unk_token(self): + # related to (#29835) + model = LlavaNextForConditionalGeneration.from_pretrained( + "llava-hf/llava-v1.6-mistral-7b-hf", + load_in_4bit=True, + ) + + prompt_with_unk = "[INST] <image>\nWhat is shown in this image? <unk> [/INST]" + inputs = self.processor(prompt_with_unk, self.image, return_tensors="pt") + + # verify single forward pass + inputs = inputs.to(torch_device) + with torch.no_grad(): + output = model(**inputs) + + # verify generation + output = model.generate(**inputs, max_new_tokens=40) + EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart' # fmt: skip + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + )