VLMs: fix number of image tokens (#34332)

* fix

* fix tests

* add tests

* style

* style

* fix qwen after rebase

* fix video llava
This commit is contained in:
Raushan Turganbay
2024-10-30 10:21:37 +01:00
committed by GitHub
parent 0f764a5af7
commit 913330ca9f
15 changed files with 237 additions and 15 deletions

View File

@@ -123,9 +123,9 @@ class VideoLlavaVisionText2TextModelTester:
self.batch_size = 5
self.num_channels = 3
self.image_size = 224
self.encoder_seq_length = 64
self.encoder_seq_length = 246
self.num_image_tokens = 25
self.num_video_tokens = 26
self.num_video_tokens = 26 * self.num_frames
self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens
def get_config(self):
@@ -267,7 +267,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
# if we remove some images from inputs leaving only one
# image number mismatch error should raise
inputs["pixel_values_images"] = inputs["pixel_values_images"][:1]
with self.assertRaises(RuntimeError):
with self.assertRaises(ValueError):
_ = model(**inputs)
def test_video_only_input(self):
@@ -401,6 +401,35 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs through an error with explicit message saying what is wrong
when number of images don't match number of image tokens in the text.
Also we need to test multi-image cases when one prompr has multiple image tokens.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successfull forward with no modifications
# remove one image but leave the image token in text
input_dict["pixel_values_images"] = input_dict["pixel_values_images"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values_images"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
with self.assertRaises(ValueError):
_ = model(input_ids=input_ids, pixel_values_images=pixel_values)
# two images and two image tokens don't raise an error
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
_ = model(input_ids=input_ids, pixel_values_images=pixel_values)
@require_torch
class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):