From 29e7a1e1834f331a4916853ecd58549ed78235d6 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 22 Dec 2023 17:47:38 +0100 Subject: [PATCH] [`Llava`] Fix llava index errors (#28032) * fix llava index errors * forward contrib credits from original implementation and fix * better fix * final fixes and fix all tests * fix * fix nit * fix tests * add regression tests --------- Co-authored-by: gullalc --- .../models/llava/modeling_llava.py | 7 +- .../models/vipllava/modeling_vipllava.py | 9 ++- tests/models/llava/test_modeling_llava.py | 79 +++++++++++++++---- 3 files changed, 75 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 2e5062bb17..54f700bbd7 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -431,8 +431,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0] - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0) + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + # Get the target length target_seqlen = first_layer_past_key_value.shape[-1] + 1 diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 2037d0527c..8aa74857af 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -430,10 +430,13 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0] - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0) + first_layer_past_key_value = past_key_values[0][0][:, 0, :, :] + + # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0) + # Get the target length - target_seqlen = first_layer_past_key_value.shape[-1] + 1 + target_seqlen = first_layer_past_key_value.shape[-2] + 1 extended_attention_mask = torch.ones( (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 23daaecc4e..3037e972a3 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -247,6 +247,53 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase): model_id = "llava-hf/llava-1.5-7b-hf" model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + prompts = [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: \nWhat is this?\nASSISTANT:", + ] + image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip + + self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_batch(self): + # Let' s make sure we test the preprocessing to replace what is used + model = LlavaForConditionalGeneration.from_pretrained("llava-hf/bakLlava-v1-hf", load_in_4bit=True) + # The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!. + prompts = [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: \nWhat is this?\nASSISTANT:", + ] + image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = self.processor(prompts, images=[image1, image2], return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip + self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_llama_batched_regression(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "llava-hf/llava-1.5-7b-hf" + + # Multi-image & multi-prompt (e.g. 3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before) + model = LlavaForConditionalGeneration.from_pretrained( + "llava-hf/llava-1.5-7b-hf", load_in_4bit=True, attn_implementation="eager" + ) processor = AutoProcessor.from_pretrained(model_id, pad_token="") prompts = [ @@ -260,26 +307,28 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase): output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: the water is calm and clear\n\nThe image shows a wooden pier on a lake, with a', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip + EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) @slow @require_bitsandbytes - def test_small_model_integration_test_batch(self): - # Let' s make sure we test the preprocessing to replace what is used - model = LlavaForConditionalGeneration.from_pretrained("llava-hf/bakLlava-v1-hf", load_in_4bit=True) - # The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!. - prompts = [ - "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", - "USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT:", - ] - image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) - image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + def test_llava_index_error_bug(self): + # This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore + # Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for + # more details + model_id = "llava-hf/llava-1.5-7b-hf" + model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) - inputs = self.processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True) + processor = AutoProcessor.from_pretrained(model_id) - output = model.generate(**inputs, max_new_tokens=20) + # Simulate a super long prompt + user_prompt = "Describe the image:?\n" * 200 + prompt = f"USER: \n{user_prompt}ASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" - EXPECTED_DECODED_TEXT = ['\nUSER: What are the things I should be cautious about when I visit this place? What should I bring with me\nASSISTANT: When visiting this place, bring a camera, Park rules may apply, Per person, Sunrise over', '\nUSER: What is this?\nASSISTANT: Two cats lying on a bed!\nUSER: And this? a dock on a lake with two cats on it (Photo credit: Jocelyn R.'] # fmt: skip - self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + + # Make sure that `generate` works + _ = model.generate(**inputs, max_new_tokens=20)