Add support for nested images to LLava and VipLLava (#35558)
* move make_flat_list_of_images and make_batched_videos to image_utils
* remove unnecessary is_vision_available
* move make_nested_list_of_images to image_utils
* fix fast pixtral image processor
* fix import mllama
* fix make_nested_list_of_images
* add tests
* convert 4d arrays/tensors to list
* add test_make_batched_videos
* add support nested batch of videos
* fix image processing qwen2vl
This commit is contained in:
@@ -31,7 +31,7 @@ from ...image_utils import (
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_valid_image,
|
||||
make_flat_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
@@ -39,29 +39,6 @@ from ...image_utils import (
|
||||
from ...utils import TensorType
|
||||
|
||||
|
||||
def make_batched_images(images) -> List[List[ImageInput]]:
|
||||
"""
|
||||
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
||||
|
||||
Args:
|
||||
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
||||
The input image.
|
||||
|
||||
Returns:
|
||||
list: A list of images.
|
||||
"""
|
||||
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
|
||||
return [img for img_list in images for img in img_list]
|
||||
|
||||
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
||||
return images
|
||||
|
||||
elif is_valid_image(images):
|
||||
return [images]
|
||||
|
||||
raise ValueError(f"Could not make batched video from {images}")
|
||||
|
||||
|
||||
def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
|
||||
"""
|
||||
Divides an image into patches of a specified size.
|
||||
@@ -244,7 +221,7 @@ class AriaImageProcessor(BaseImageProcessor):
|
||||
if max_image_size not in [490, 980]:
|
||||
raise ValueError("max_image_size must be either 490 or 980")
|
||||
|
||||
images = make_batched_images(images)
|
||||
images = make_flat_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
|
||||
@@ -28,6 +28,7 @@ from ...image_utils import (
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
make_flat_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
@@ -58,7 +59,7 @@ from ..llama.modeling_llama import (
|
||||
LlamaRMSNorm,
|
||||
)
|
||||
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast
|
||||
from ..llava_next.image_processing_llava_next import divide_to_patches, make_batched_images
|
||||
from ..llava_next.image_processing_llava_next import divide_to_patches
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -609,7 +610,7 @@ class AriaImageProcessor(BaseImageProcessor):
|
||||
if max_image_size not in [490, 980]:
|
||||
raise ValueError("max_image_size must be either 490 or 980")
|
||||
|
||||
images = make_batched_images(images)
|
||||
images = make_flat_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user