Fix number of patch check for different vision feature select strategy (#32494)
* Fix number of patch check for different vision feature select strategy
* add test

---------

Co-authored-by: raushan <raushan@huggingface.co>
This commit is contained in:
@@ -645,7 +645,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||||||
|
|
||||||
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
|
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
|
||||||
|
|
||||||
def pack_image_features(self, image_features, image_sizes, image_newline=None):
|
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
|
||||||
"""
|
"""
|
||||||
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
||||||
|
|
||||||
@@ -654,6 +654,8 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||||||
List of image feature tensor, each contains all the visual feature of all patches.
|
List of image feature tensor, each contains all the visual feature of all patches.
|
||||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||||
Actual image size of each images (H, W).
|
Actual image size of each images (H, W).
|
||||||
|
vision_feature_select_strategy (`str`)
|
||||||
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
image_newline (`torch.Tensor` of shape `(embed_dim)`)
|
image_newline (`torch.Tensor` of shape `(embed_dim)`)
|
||||||
New line embedding vector.
|
New line embedding vector.
|
||||||
Returns:
|
Returns:
|
||||||
@@ -668,8 +670,14 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||||||
base_image_feature = image_feature[0]
|
base_image_feature = image_feature[0]
|
||||||
image_feature = image_feature[1:]
|
image_feature = image_feature[1:]
|
||||||
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||||
if height * width != base_image_feature.shape[0]:
|
|
||||||
|
if vision_feature_select_strategy == "default":
|
||||||
|
expected_num_patches = height * width
|
||||||
|
elif vision_feature_select_strategy == "full":
|
||||||
|
expected_num_patches = height * width + 1
|
||||||
|
if expected_num_patches != base_image_feature.shape[0]:
|
||||||
raise ValueError("The number of patches is not consistent with the image size.")
|
raise ValueError("The number of patches is not consistent with the image size.")
|
||||||
|
|
||||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||||
image_sizes[image_idx],
|
image_sizes[image_idx],
|
||||||
self.config.image_grid_pinpoints,
|
self.config.image_grid_pinpoints,
|
||||||
@@ -825,6 +833,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
|||||||
image_features, feature_lens = self.pack_image_features(
|
image_features, feature_lens = self.pack_image_features(
|
||||||
image_features,
|
image_features,
|
||||||
image_sizes,
|
image_sizes,
|
||||||
|
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||||
image_newline=self.image_newline,
|
image_newline=self.image_newline,
|
||||||
)
|
)
|
||||||
if legacy_processing:
|
if legacy_processing:
|
||||||
|
|||||||
@@ -620,3 +620,24 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# check that both inputs are handled correctly and generate the same output
|
# check that both inputs are handled correctly and generate the same output
|
||||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_bitsandbytes
|
||||||
|
def test_small_model_integration_test_full_vision_state_selection(self):
|
||||||
|
model = LlavaNextForConditionalGeneration.from_pretrained(
|
||||||
|
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
# test that changing `strategy` won't error out
|
||||||
|
model.vision_feature_select_strategy = "full"
|
||||||
|
|
||||||
|
inputs = self.processor(self.prompt, self.image, return_tensors="pt")
|
||||||
|
|
||||||
|
# verify generation
|
||||||
|
output = model.generate(**inputs, max_new_tokens=30)
|
||||||
|
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes' # fmt: skip
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
self.processor.decode(output[0], skip_special_tokens=True),
|
||||||
|
EXPECTED_DECODED_TEXT,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user