Fix number-of-patches check for different vision feature select strategies (#32494)

* Fix number-of-patches check for different vision feature select strategies

* add test

---------

Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
Insu Jang
2024-09-17 03:33:07 -04:00
committed by GitHub
parent 18e1a9c719
commit bcf8946f0a
2 changed files with 32 additions and 2 deletions

View File

@@ -645,7 +645,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
def pack_image_features(self, image_features, image_sizes, image_newline=None): def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
""" """
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -654,6 +654,8 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
List of image feature tensor, each contains all the visual feature of all patches. List of image feature tensor, each contains all the visual feature of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`) image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W). Actual image size of each images (H, W).
vision_feature_select_strategy (`str`)
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`) image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector. New line embedding vector.
Returns: Returns:
@@ -668,8 +670,14 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
base_image_feature = image_feature[0] base_image_feature = image_feature[0]
image_feature = image_feature[1:] image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
if height * width != base_image_feature.shape[0]:
if vision_feature_select_strategy == "default":
expected_num_patches = height * width
elif vision_feature_select_strategy == "full":
expected_num_patches = height * width + 1
if expected_num_patches != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.") raise ValueError("The number of patches is not consistent with the image size.")
num_patch_height, num_patch_width = get_anyres_image_grid_shape( num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx], image_sizes[image_idx],
self.config.image_grid_pinpoints, self.config.image_grid_pinpoints,
@@ -825,6 +833,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
image_features, feature_lens = self.pack_image_features( image_features, feature_lens = self.pack_image_features(
image_features, image_features,
image_sizes, image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline, image_newline=self.image_newline,
) )
if legacy_processing: if legacy_processing:

View File

@@ -620,3 +620,24 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# check that both inputs are handled correctly and generate the same output # check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
@slow
@require_bitsandbytes
def test_small_model_integration_test_full_vision_state_selection(self):
    """Generate with the `"full"` vision feature select strategy and check the decoded text.

    Loads the checkpoint with `load_in_4bit=True`, sets `vision_feature_select_strategy` to `"full"`
    on the model object, runs `generate` on `self.prompt` / `self.image` and compares the decoded
    sequence against a pinned string.
    """
    quantized_model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf", load_in_4bit=True
    )

    # Changing the strategy must not make generation error out.
    # NOTE(review): this sets an attribute on the model object itself; confirm the forward pass
    # actually reads it (as opposed to `config.vision_feature_select_strategy`), otherwise the
    # "full" code path is not exercised by this test.
    quantized_model.vision_feature_select_strategy = "full"

    model_inputs = self.processor(self.prompt, self.image, return_tensors="pt")

    # Run generation and compare the full decoded sequence (prompt included) to the pinned text.
    generated_ids = quantized_model.generate(**model_inputs, max_new_tokens=30)
    expected_text = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes'  # fmt: skip
    decoded_text = self.processor.decode(generated_ids[0], skip_special_tokens=True)
    self.assertEqual(decoded_text, expected_text)