LLaVa-Next: Update docs with batched inference (#30857)
* update docs with batch ex * Update docs/source/en/model_doc/llava_next.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * accept nested list of img --------- Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
cd6bd0af34
commit
5d0bf59b4d
@@ -68,6 +68,8 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
|
|||||||
|
|
||||||
## Usage example
|
## Usage example
|
||||||
|
|
||||||
|
### Single image inference
|
||||||
|
|
||||||
Here's how to load the model and perform inference in half-precision (`torch.float16`):
|
Here's how to load the model and perform inference in half-precision (`torch.float16`):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -94,6 +96,45 @@ output = model.generate(**inputs, max_new_tokens=100)
|
|||||||
print(processor.decode(output[0], skip_special_tokens=True))
|
print(processor.decode(output[0], skip_special_tokens=True))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Multi image inference
|
||||||
|
|
||||||
|
LLaVa-Next can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). Here is how you can do it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from transformers import AutoProcessor, LlavaNextForConditionalGeneration
|
||||||
|
|
||||||
|
# Load the model in half-precision
|
||||||
|
model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, device_map="auto")
|
||||||
|
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
||||||
|
|
||||||
|
# Get three different images
|
||||||
|
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
|
image_stop = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
image_cats = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
|
||||||
|
image_snowman = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
# Prepare a batched prompt, where the first one is a multi-turn conversation and the second is not
|
||||||
|
prompt = [
|
||||||
|
"[INST] <image>\nWhat is shown in this image? [/INST] There is a red stop sign in the image. [INST] <image>\nWhat about this image? How many cats do you see [/INST]",
|
||||||
|
"[INST] <image>\nWhat is shown in this image? [/INST]"
|
||||||
|
]
|
||||||
|
|
||||||
|
# We can simply feed images in the order they have to be used in the text prompt
|
||||||
|
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
|
||||||
|
inputs = processor(text=prompt, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=30)
|
||||||
|
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
|
```
|
||||||
|
|
||||||
## Model optimization
|
## Model optimization
|
||||||
|
|
||||||
### Quantization using Bitsandbytes
|
### Quantization using Bitsandbytes
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from ...image_utils import (
|
|||||||
get_image_size,
|
get_image_size,
|
||||||
infer_channel_dimension_format,
|
infer_channel_dimension_format,
|
||||||
is_scaled_image,
|
is_scaled_image,
|
||||||
|
is_valid_image,
|
||||||
make_list_of_images,
|
make_list_of_images,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
@@ -52,6 +53,29 @@ if is_vision_available():
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def make_batched_images(images) -> List[List[ImageInput]]:
|
||||||
|
"""
|
||||||
|
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
||||||
|
The input image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of images.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
|
||||||
|
return [img for img_list in images for img in img_list]
|
||||||
|
|
||||||
|
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
||||||
|
return images
|
||||||
|
|
||||||
|
elif is_valid_image(images):
|
||||||
|
return [images]
|
||||||
|
|
||||||
|
raise ValueError(f"Could not make batched video from {images}")
|
||||||
|
|
||||||
|
|
||||||
def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
|
def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
|
||||||
"""
|
"""
|
||||||
Divides an image into patches of a specified size.
|
Divides an image into patches of a specified size.
|
||||||
@@ -651,7 +675,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
|||||||
do_pad = do_pad if do_pad is not None else self.do_pad
|
do_pad = do_pad if do_pad is not None else self.do_pad
|
||||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||||
|
|
||||||
images = make_list_of_images(images)
|
images = make_batched_images(images)
|
||||||
|
|
||||||
if not valid_images(images):
|
if not valid_images(images):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -199,3 +199,21 @@ class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
@unittest.skip("LlavaNextImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
@unittest.skip("LlavaNextImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
||||||
def test_call_numpy_4_channels(self):
|
def test_call_numpy_4_channels(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_nested_input(self):
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
||||||
|
|
||||||
|
# Test batched as a list of images
|
||||||
|
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = (7, 1445, 3, 18, 18)
|
||||||
|
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||||
|
|
||||||
|
# Test batched as a nested list of images, where each sublist is one batch
|
||||||
|
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
|
||||||
|
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = (7, 1445, 3, 18, 18)
|
||||||
|
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
|
||||||
|
|
||||||
|
# Image processor should return same pixel values, independently of ipnut format
|
||||||
|
self.assertTrue((encoded_images_nested == encoded_images).all())
|
||||||
|
|||||||
Reference in New Issue
Block a user