Add support for nested images to LLava and VipLLava (#35558)

* move make_flat_list_of_images and make_batched_videos to image_utils

* remove unnecessary is_vision_available

* move make_nested_list_of_images to image_utils

* fix fast pixtral image processor

* fix import mllama

* fix make_nested_list_of_images

* add tests

* convert 4d arrays/tensors to list

* add test_make_batched_videos

* add support nested batch of videos

* fix image processing qwen2vl
This commit is contained in:
Yoni Gozlan
2025-01-30 16:49:20 -05:00
committed by GitHub
parent e4227eb4d4
commit d7188ba600
27 changed files with 506 additions and 485 deletions

View File

@@ -31,7 +31,7 @@ from ...image_utils import (
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_valid_image,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
@@ -39,29 +39,6 @@ from ...image_utils import (
from ...utils import TensorType
def make_batched_images(images) -> List[List[ImageInput]]:
"""
Accepts images in list or nested list format, and makes a list of images for preprocessing.
Args:
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
The input image.
Returns:
list: A list of images.
"""
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
return [img for img_list in images for img in img_list]
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
return images
elif is_valid_image(images):
return [images]
raise ValueError(f"Could not make batched video from {images}")
def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
"""
Divides an image into patches of a specified size.
@@ -244,7 +221,7 @@ class AriaImageProcessor(BaseImageProcessor):
if max_image_size not in [490, 980]:
raise ValueError("max_image_size must be either 490 or 980")
images = make_batched_images(images)
images = make_flat_list_of_images(images)
if not valid_images(images):
raise ValueError(

View File

@@ -28,6 +28,7 @@ from ...image_utils import (
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
@@ -58,7 +59,7 @@ from ..llama.modeling_llama import (
LlamaRMSNorm,
)
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast
from ..llava_next.image_processing_llava_next import divide_to_patches, make_batched_images
from ..llava_next.image_processing_llava_next import divide_to_patches
logger = logging.get_logger(__name__)
@@ -609,7 +610,7 @@ class AriaImageProcessor(BaseImageProcessor):
if max_image_size not in [490, 980]:
raise ValueError("max_image_size must be either 490 or 980")
images = make_batched_images(images)
images = make_flat_list_of_images(images)
if not valid_images(images):
raise ValueError(