Fix Llava for 0-embeddings (#30473)
This commit is contained in:
committed by
GitHub
parent
ad697f1801
commit
e60491adc9
@@ -327,8 +327,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||||
|
|
||||||
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||||
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
image_to_overwrite = torch.full(
|
||||||
|
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||||
|
|
||||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||||
|
|||||||
@@ -403,8 +403,11 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||||
|
|
||||||
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||||
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
image_to_overwrite = torch.full(
|
||||||
|
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||||
|
|
||||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||||
|
|||||||
@@ -331,8 +331,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||||
|
|
||||||
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||||
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
image_to_overwrite = torch.full(
|
||||||
|
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||||
|
|
||||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||||
|
|||||||
@@ -459,3 +459,29 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
|
EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
|
||||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_bitsandbytes
|
||||||
|
def test_small_model_integration_test_unk_token(self):
|
||||||
|
# related to (#29835)
|
||||||
|
model = LlavaNextForConditionalGeneration.from_pretrained(
|
||||||
|
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_with_unk = "[INST] <image>\nWhat is shown in this <unk> image? [/INST]"
|
||||||
|
inputs = self.processor(prompt_with_unk, self.image, return_tensors="pt")
|
||||||
|
|
||||||
|
# verify single forward pass
|
||||||
|
inputs = inputs.to(torch_device)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs)
|
||||||
|
|
||||||
|
# verify generation
|
||||||
|
output = model.generate(**inputs, max_new_tokens=40)
|
||||||
|
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart' # fmt: skip
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
self.processor.decode(output[0], skip_special_tokens=True),
|
||||||
|
EXPECTED_DECODED_TEXT,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user