Fix _merge_input_ids_with_image_features for llava model (#28333)

* fix `_merge_input_ids_with_image_features` for llava model

* Update src/transformers/models/llava/modeling_llava.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* address comments

* style and tests

* ooops

* test the backward too

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update tests/models/vipllava/test_modeling_vipllava.py

* style and quality

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Victor SANH
2024-01-10 02:33:33 -05:00
committed by GitHub
parent 976189a6df
commit 0f2f0c634f
4 changed files with 106 additions and 14 deletions

View File

@@ -26,7 +26,7 @@ from transformers import (
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -332,3 +332,41 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
# Make sure that `generate` works
_ = model.generate(**inputs, max_new_tokens=20)
@slow
@require_torch_gpu
def test_llava_merge_inputs_error_bug(self):
    # Regression check for https://github.com/huggingface/transformers/pull/28333:
    # run a forward and a backward pass on a two-row batch in which the first row
    # has its two leading positions masked out and the second row is fully attended.
    checkpoint = "llava-hf/llava-1.5-7b-hf"
    model = LlavaForConditionalGeneration.from_pretrained(
        checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    model = model.to(torch_device)

    # Random stand-in for two user images, created directly on the test device.
    pixel_values = torch.randn((2, 3, 336, 336), dtype=torch.float, device=torch_device)

    # Both integer tensors share the same dtype and device.
    long_on_device = {"dtype": torch.long, "device": torch_device}

    # NOTE(review): presumably 32000 is the image placeholder token and 32001 the
    # padding token for this checkpoint - confirm against the model config.
    masked_prefix_row = [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900]
    fully_attended_row = [1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900]
    input_ids = torch.tensor([masked_prefix_row, fully_attended_row], **long_on_device)

    attention_mask = torch.tensor(
        [
            [0, 0, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1, 1, 1, 1, 1],
        ],
        **long_on_device,
    )

    # The prompt ids double as labels so that the model returns a loss; the test
    # passes as long as computing it and back-propagating does not raise.
    outputs = model(
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=input_ids,
    )
    outputs.loss.backward()