Fix _merge_input_ids_with_image_features for llava model (#28333)
* fix `_merge_input_ids_with_image_features` for llava model * Update src/transformers/models/llava/modeling_llava.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * adress comments * style and tests * ooops * test the backward too * Apply suggestions from code review Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update tests/models/vipllava/test_modeling_vipllava.py * style and quality --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -276,9 +276,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
self.vocab_size = model_embeds.num_embeddings
|
self.vocab_size = model_embeds.num_embeddings
|
||||||
return model_embeds
|
return model_embeds
|
||||||
|
|
||||||
def _merge_input_ids_with_image_features(
|
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||||
self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
|
|
||||||
):
|
|
||||||
num_images, num_image_patches, embed_dim = image_features.shape
|
num_images, num_image_patches, embed_dim = image_features.shape
|
||||||
batch_size, sequence_length = input_ids.shape
|
batch_size, sequence_length = input_ids.shape
|
||||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||||
@@ -307,6 +305,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
final_attention_mask = torch.zeros(
|
final_attention_mask = torch.zeros(
|
||||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
|
if labels is not None:
|
||||||
|
final_labels = torch.full(
|
||||||
|
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||||
|
)
|
||||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||||
# set the corresponding tensors into their correct target device.
|
# set the corresponding tensors into their correct target device.
|
||||||
target_device = inputs_embeds.device
|
target_device = inputs_embeds.device
|
||||||
@@ -321,6 +323,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||||
|
if labels is not None:
|
||||||
|
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||||
|
|
||||||
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
||||||
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
||||||
@@ -335,7 +339,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||||
final_attention_mask |= image_to_overwrite
|
final_attention_mask |= image_to_overwrite
|
||||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||||
return final_embedding, final_attention_mask, position_ids
|
|
||||||
|
if labels is None:
|
||||||
|
final_labels = None
|
||||||
|
|
||||||
|
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
@@ -420,8 +428,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
image_features = self.multi_modal_projector(selected_image_feature)
|
image_features = self.multi_modal_projector(selected_image_feature)
|
||||||
inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
|
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||||
image_features, inputs_embeds, input_ids, attention_mask, position_ids
|
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||||
)
|
)
|
||||||
if labels is None:
|
if labels is None:
|
||||||
labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
||||||
|
|||||||
@@ -284,9 +284,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||||||
self.vocab_size = model_embeds.num_embeddings
|
self.vocab_size = model_embeds.num_embeddings
|
||||||
return model_embeds
|
return model_embeds
|
||||||
|
|
||||||
def _merge_input_ids_with_image_features(
|
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||||
self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
|
|
||||||
):
|
|
||||||
num_images, num_image_patches, embed_dim = image_features.shape
|
num_images, num_image_patches, embed_dim = image_features.shape
|
||||||
batch_size, sequence_length = input_ids.shape
|
batch_size, sequence_length = input_ids.shape
|
||||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||||
@@ -315,6 +313,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||||||
final_attention_mask = torch.zeros(
|
final_attention_mask = torch.zeros(
|
||||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
|
if labels is not None:
|
||||||
|
final_labels = torch.full(
|
||||||
|
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||||
|
)
|
||||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||||
# set the corresponding tensors into their correct target device.
|
# set the corresponding tensors into their correct target device.
|
||||||
target_device = inputs_embeds.device
|
target_device = inputs_embeds.device
|
||||||
@@ -329,6 +331,8 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||||
|
if labels is not None:
|
||||||
|
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||||
|
|
||||||
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
||||||
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
||||||
@@ -343,7 +347,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||||||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||||
final_attention_mask |= image_to_overwrite
|
final_attention_mask |= image_to_overwrite
|
||||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||||
return final_embedding, final_attention_mask, position_ids
|
|
||||||
|
if labels is None:
|
||||||
|
final_labels = None
|
||||||
|
|
||||||
|
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
@@ -419,8 +427,8 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||||||
image_features = torch.cat(image_features, dim=-1)
|
image_features = torch.cat(image_features, dim=-1)
|
||||||
|
|
||||||
image_features = self.multi_modal_projector(image_features)
|
image_features = self.multi_modal_projector(image_features)
|
||||||
inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
|
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||||
image_features, inputs_embeds, input_ids, attention_mask, position_ids
|
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||||
)
|
)
|
||||||
if labels is None:
|
if labels is None:
|
||||||
labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
|
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
@@ -332,3 +332,41 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Make sure that `generate` works
|
# Make sure that `generate` works
|
||||||
_ = model.generate(**inputs, max_new_tokens=20)
|
_ = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_llava_merge_inputs_error_bug(self):
|
||||||
|
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
|
||||||
|
model_id = "llava-hf/llava-1.5-7b-hf"
|
||||||
|
model = LlavaForConditionalGeneration.from_pretrained(
|
||||||
|
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
# Simulate some user inputs
|
||||||
|
pixel_values = torch.randn(
|
||||||
|
(2, 3, 336, 336),
|
||||||
|
dtype=torch.float,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
input_ids = torch.tensor(
|
||||||
|
[
|
||||||
|
[32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
|
||||||
|
[1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
|
||||||
|
],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
attention_mask = torch.tensor(
|
||||||
|
[[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure that the loss is properly computed
|
||||||
|
loss = model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
labels=input_ids,
|
||||||
|
).loss
|
||||||
|
loss.backward()
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
|
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
@@ -214,3 +214,41 @@ class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
EXPECTED_OUTPUT = "USER: <image> \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on"
|
EXPECTED_OUTPUT = "USER: <image> \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on"
|
||||||
self.assertEqual(processor.decode(outputs[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
self.assertEqual(processor.decode(outputs[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_vipllava_merge_inputs_error_bug(self):
|
||||||
|
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
|
||||||
|
model_id = "llava-hf/vip-llava-7b-hf"
|
||||||
|
model = VipLlavaForConditionalGeneration.from_pretrained(
|
||||||
|
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
# Simulate some user inputs
|
||||||
|
pixel_values = torch.randn(
|
||||||
|
(2, 3, 336, 336),
|
||||||
|
dtype=torch.float,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
input_ids = torch.tensor(
|
||||||
|
[
|
||||||
|
[32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
|
||||||
|
[1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
|
||||||
|
],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
attention_mask = torch.tensor(
|
||||||
|
[[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure that the loss is properly computed
|
||||||
|
loss = model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
labels=input_ids,
|
||||||
|
).loss
|
||||||
|
loss.backward()
|
||||||
|
|||||||
Reference in New Issue
Block a user