Correct llava mask & fix missing setter for vocab_size (#29389)
* correct llava mask * fix vipllava as well * mask out embedding for padding tokens * add test * fix style * add setter * fix test on suggestion
This commit is contained in:
@@ -147,6 +147,10 @@ class LlavaConfig(PretrainedConfig):
|
|||||||
)
|
)
|
||||||
return self._vocab_size
|
return self._vocab_size
|
||||||
|
|
||||||
|
@vocab_size.setter
|
||||||
|
def vocab_size(self, value):
|
||||||
|
self._vocab_size = value
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
output = super().to_dict()
|
output = super().to_dict()
|
||||||
output.pop("_vocab_size", None)
|
output.pop("_vocab_size", None)
|
||||||
|
|||||||
@@ -344,6 +344,12 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
final_attention_mask |= image_to_overwrite
|
final_attention_mask |= image_to_overwrite
|
||||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||||
|
|
||||||
|
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||||
|
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||||
|
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||||
|
|
||||||
|
final_embedding[batch_indices, indices_to_mask] = 0
|
||||||
|
|
||||||
if labels is None:
|
if labels is None:
|
||||||
final_labels = None
|
final_labels = None
|
||||||
|
|
||||||
@@ -449,10 +455,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||||
|
|
||||||
# Get the target length
|
# Get the target length
|
||||||
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
target_length = input_ids.shape[1]
|
||||||
|
past_length = first_layer_past_key_value.shape[-1]
|
||||||
|
|
||||||
extended_attention_mask = torch.ones(
|
extended_attention_mask = torch.ones(
|
||||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
(attention_mask.shape[0], past_length),
|
||||||
dtype=attention_mask.dtype,
|
dtype=attention_mask.dtype,
|
||||||
device=attention_mask.device,
|
device=attention_mask.device,
|
||||||
)
|
)
|
||||||
@@ -467,7 +474,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
# Zero-out the places where we don't need to attend
|
# Zero-out the places where we don't need to attend
|
||||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||||
|
|
||||||
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||||
|
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
|
|||||||
@@ -356,7 +356,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||||
num_images, num_image_patches, embed_dim = image_features.shape
|
num_images, num_image_patches, embed_dim = image_features.shape
|
||||||
batch_size, sequence_length = input_ids.shape
|
batch_size, sequence_length = input_ids.shape
|
||||||
|
|
||||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||||
# 1. Create a mask to know where special image tokens are
|
# 1. Create a mask to know where special image tokens are
|
||||||
special_image_token_mask = input_ids == self.config.image_token_index
|
special_image_token_mask = input_ids == self.config.image_token_index
|
||||||
@@ -418,6 +417,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||||||
final_attention_mask |= image_to_overwrite
|
final_attention_mask |= image_to_overwrite
|
||||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||||
|
|
||||||
|
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||||
|
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||||
|
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||||
|
|
||||||
|
final_embedding[batch_indices, indices_to_mask] = 0
|
||||||
|
|
||||||
if labels is None:
|
if labels is None:
|
||||||
final_labels = None
|
final_labels = None
|
||||||
|
|
||||||
|
|||||||
@@ -347,6 +347,12 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||||||
final_attention_mask |= image_to_overwrite
|
final_attention_mask |= image_to_overwrite
|
||||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||||
|
|
||||||
|
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||||
|
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||||
|
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||||
|
|
||||||
|
final_embedding[batch_indices, indices_to_mask] = 0
|
||||||
|
|
||||||
if labels is None:
|
if labels is None:
|
||||||
final_labels = None
|
final_labels = None
|
||||||
|
|
||||||
@@ -442,11 +448,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||||||
# Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
# Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0)
|
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0)
|
||||||
|
|
||||||
# Get the target length
|
target_length = input_ids.shape[1]
|
||||||
target_seqlen = first_layer_past_key_value.shape[-2] + 1
|
past_length = first_layer_past_key_value.shape[-1]
|
||||||
|
|
||||||
extended_attention_mask = torch.ones(
|
extended_attention_mask = torch.ones(
|
||||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
(attention_mask.shape[0], past_length),
|
||||||
dtype=attention_mask.dtype,
|
dtype=attention_mask.dtype,
|
||||||
device=attention_mask.device,
|
device=attention_mask.device,
|
||||||
)
|
)
|
||||||
@@ -461,7 +467,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||||||
# Zero-out the places where we don't need to attend
|
# Zero-out the places where we don't need to attend
|
||||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||||
|
|
||||||
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||||
|
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
|
|||||||
@@ -27,7 +27,14 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
|
from transformers.testing_utils import (
|
||||||
|
require_bitsandbytes,
|
||||||
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
|
require_vision,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
@@ -470,10 +477,45 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=20)
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
|
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
|
||||||
|
|
||||||
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
def test_batched_generation(self):
|
||||||
|
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf").to(torch_device)
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||||
|
|
||||||
|
prompt1 = "<image>\n<image>\nUSER: What's the the difference of two images?\nASSISTANT:"
|
||||||
|
prompt2 = "<image>\nUSER: Describe the image.\nASSISTANT:"
|
||||||
|
prompt3 = "<image>\nUSER: Describe the image.\nASSISTANT:"
|
||||||
|
url1 = "https://images.unsplash.com/photo-1552053831-71594a27632d?q=80&w=3062&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||||
|
url2 = "https://images.unsplash.com/photo-1617258683320-61900b281ced?q=80&w=3087&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||||
|
image1 = Image.open(requests.get(url1, stream=True).raw)
|
||||||
|
image2 = Image.open(requests.get(url2, stream=True).raw)
|
||||||
|
|
||||||
|
inputs = processor(
|
||||||
|
text=[prompt1, prompt2, prompt3],
|
||||||
|
images=[image1, image2, image1, image2],
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
EXPECTED_OUTPUT = [
|
||||||
|
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one",
|
||||||
|
"\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
|
||||||
|
"\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
|
||||||
|
]
|
||||||
|
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
|
self.assertEqual(outputs, EXPECTED_OUTPUT)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
def test_llava_index_error_bug(self):
|
def test_llava_index_error_bug(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user