diff --git a/docs/source/en/model_doc/got_ocr2.md b/docs/source/en/model_doc/got_ocr2.md index a560f78269..33607033b4 100644 --- a/docs/source/en/model_doc/got_ocr2.md +++ b/docs/source/en/model_doc/got_ocr2.md @@ -44,13 +44,14 @@ The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2 ```python >>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device) ->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf") +>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True) >>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg" ->>> inputs = processor(image, return_tensors="pt").to(device) +>>> inputs = processor(image, return_tensors="pt", device=device).to(device) >>> generate_ids = model.generate( ... **inputs, @@ -68,15 +69,16 @@ The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2 ```python >>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device) ->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf") +>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True) >>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" >>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg" ->>> inputs = processor([image1, image2], return_tensors="pt").to(device) +>>> inputs = processor([image1, image2], return_tensors="pt", device=device).to(device) >>> generate_ids = model.generate( ... **inputs, @@ -96,13 +98,14 @@ GOT-OCR2 can also generate formatted text, such as markdown or LaTeX. Here is an ```python >>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device) ->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf") +>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True) >>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/latex.png" ->>> inputs = processor(image, return_tensors="pt", format=True).to(device) +>>> inputs = processor(image, return_tensors="pt", format=True, device=device).to(device) >>> generate_ids = model.generate( ... **inputs, @@ -124,14 +127,15 @@ Here is an example of how to process multiple pages at once: ```python >>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device) ->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf") +>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True) >>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page1.png" >>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page2.png" ->>> inputs = processor([image1, image2], return_tensors="pt", multi_page=True, format=True).to(device) +>>> inputs = processor([image1, image2], return_tensors="pt", multi_page=True, format=True, device=device).to(device) >>> generate_ids = model.generate( ... **inputs, @@ -153,13 +157,14 @@ Here is an example of how to process cropped patches: ```python >>> import torch >>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", torch_dtype=torch.bfloat16, device_map=device) ->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf") +>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True) >>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png" ->>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3).to(device) +>>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3, device=device).to(device) >>> generate_ids = model.generate( ... **inputs, @@ -179,13 +184,14 @@ GOT supports interactive OCR, where the user can specify the region to be recogn ```python >>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device) ->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf") +>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True) >>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" ->>> inputs = processor(image, return_tensors="pt", color="green").to(device) # or box=[x1, y1, x2, y2] for coordinates (image pixels) +>>> inputs = processor(image, return_tensors="pt", color="green", device=device).to(device) # or box=[x1, y1, x2, y2] for coordinates (image pixels) >>> generate_ids = model.generate( ... **inputs, @@ -206,14 +212,15 @@ Here is an example of how to process sheet music: ```python >>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import torch >>> import verovio >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device) ->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf") +>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True) >>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/sheet_music.png" ->>> inputs = processor(image, return_tensors="pt", format=True).to(device) +>>> inputs = processor(image, return_tensors="pt", format=True, device=device).to(device) >>> generate_ids = model.generate( ... **inputs, @@ -258,6 +265,10 @@ alt="drawing" width="600"/> [[autodoc]] GotOcr2ImageProcessor +## GotOcr2ImageProcessorFast + +[[autodoc]] GotOcr2ImageProcessorFast + ## GotOcr2Processor [[autodoc]] GotOcr2Processor diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ed26829010..f05a7b3b2c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1330,6 +1330,7 @@ else: _import_structure["models.deit"].append("DeiTImageProcessorFast") _import_structure["models.depth_pro"].append("DepthProImageProcessorFast") _import_structure["models.detr"].append("DetrImageProcessorFast") + _import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast") _import_structure["models.llava"].append("LlavaImageProcessorFast") _import_structure["models.llava_next"].append("LlavaNextImageProcessorFast") _import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast") @@ -6526,6 +6527,7 @@ if TYPE_CHECKING: from .models.deit import DeiTImageProcessorFast from .models.depth_pro import DepthProImageProcessorFast from .models.detr import DetrImageProcessorFast + from .models.got_ocr2 import GotOcr2ImageProcessorFast from .models.llava import LlavaImageProcessorFast from .models.llava_next import LlavaNextImageProcessorFast from .models.llava_onevision import LlavaOnevisionImageProcessorFast diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 4942b8f39b..180d156359 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -88,7 +88,7 @@ else: ("fuyu", ("FuyuImageProcessor",)), ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("glpn", ("GLPNImageProcessor",)), - ("got_ocr2", ("GotOcr2ImageProcessor",)), + ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")), ("grounding-dino", ("GroundingDinoImageProcessor",)), ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("hiera", ("BitImageProcessor",)), diff --git a/src/transformers/models/got_ocr2/__init__.py b/src/transformers/models/got_ocr2/__init__.py index 071a7ea740..00b6ccc53f 100644 --- a/src/transformers/models/got_ocr2/__init__.py +++ b/src/transformers/models/got_ocr2/__init__.py @@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure if TYPE_CHECKING: from .configuration_got_ocr2 import * from .image_processing_got_ocr2 import * + from .image_processing_got_ocr2_fast import * from .modeling_got_ocr2 import * from .processing_got_ocr2 import * diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py index 7f7a0d7ae4..d052f4a543 100644 --- a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py +++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py @@ -1,9 +1,3 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_got_ocr2.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. # @@ -18,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""Image processor class for Got-OCR-2.""" from functools import lru_cache from typing import Dict, List, Optional, Tuple, Union @@ -27,11 +21,9 @@ import numpy as np from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import ( - _rescale_for_pil_conversion, convert_to_rgb, resize, to_channel_dimension_format, - to_pil_image, ) from ...image_utils import ( OPENAI_CLIP_MEAN, @@ -142,6 +134,15 @@ class GotOcr2ImageProcessor(BaseImageProcessor): size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`): Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` method. + crop_to_patches (`bool`, *optional*, defaults to `False`): + Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the + `preprocess` method. + min_patches (`int`, *optional*, defaults to 1): + The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method. + max_patches (`int`, *optional*, defaults to 12): + The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method. resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be overridden by the `resample` parameter in the `preprocess` method. @@ -172,6 +173,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor): self, do_resize: bool = True, size: Dict[str, int] = None, + crop_to_patches: bool = False, + min_patches: int = 1, + max_patches: int = 12, resample: PILImageResampling = PILImageResampling.BICUBIC, do_rescale: bool = True, rescale_factor: Union[int, float] = 1 / 255, @@ -187,6 +191,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor): self.do_resize = do_resize self.size = size + self.crop_to_patches = crop_to_patches + self.min_patches = min_patches + self.max_patches = max_patches self.resample = resample self.do_rescale = do_rescale self.rescale_factor = rescale_factor @@ -249,6 +256,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor): images: ImageInput, do_resize: Optional[bool] = None, size: Optional[Dict[str, int]] = None, + crop_to_patches: Optional[bool] = None, + min_patches: Optional[int] = None, + max_patches: Optional[int] = None, resample: PILImageResampling = None, do_rescale: Optional[bool] = None, rescale_factor: Optional[float] = None, @@ -274,6 +284,14 @@ class GotOcr2ImageProcessor(BaseImageProcessor): `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest edge equal to `int(size["shortest_edge"] * (1333 / 800))`. + crop_to_patches (`bool`, *optional*, defaults to `self.crop_to_patches`): + Whether to crop the image to patches. + min_patches (`int`, *optional*, defaults to `self.min_patches`): + The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. + max_patches (`int`, *optional*, defaults to `self.max_patches`): + The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. resample (`PILImageResampling`, *optional*, defaults to `self.resample`): Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): @@ -308,6 +326,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor): - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. """ do_resize = do_resize if do_resize is not None else self.do_resize + crop_to_patches = crop_to_patches if crop_to_patches is not None else self.crop_to_patches + min_patches = min_patches if min_patches is not None else self.min_patches + max_patches = max_patches if max_patches is not None else self.max_patches resample = resample if resample is not None else self.resample do_rescale = do_rescale if do_rescale is not None else self.do_rescale rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor @@ -353,40 +374,52 @@ class GotOcr2ImageProcessor(BaseImageProcessor): # We assume that all images have the same channel dimension format. input_data_format = infer_channel_dimension_format(images[0]) - if do_resize: + if crop_to_patches and max_patches > 1: images = [ - self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + self.crop_image_to_patches( + image, + min_patches=min_patches, + max_patches=max_patches, + patch_size=size, + data_format=input_data_format, + ) for image in images ] + num_patches = np.array([len(image) for image in images]) + images = [image for images_list in images for image in images_list] + else: + num_patches = np.array([1] * len(images)) - if do_rescale: - images = [ - self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) - for image in images - ] + for i, image in enumerate(images): + if do_resize: + images[i] = self.resize(image, size=size, resample=resample, input_data_format=input_data_format) - if do_normalize: - images = [ - self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) - for image in images - ] + if do_rescale: + images[i] = self.rescale(image=images[i], scale=rescale_factor, input_data_format=input_data_format) - images = [ - to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images - ] + if do_normalize: + images[i] = self.normalize( + image=images[i], + mean=image_mean, + std=image_std, + input_data_format=input_data_format, + ) - encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + images[i] = to_channel_dimension_format(images[i], data_format, input_channel_dim=input_data_format) + + encoded_outputs = BatchFeature( + data={"pixel_values": images, "num_patches": num_patches}, tensor_type=return_tensors + ) return encoded_outputs def crop_image_to_patches( self, - image: ImageInput, + images: np.ndarray, min_patches: int, max_patches: int, use_thumbnail: bool = True, patch_size: Union[Tuple, int, dict] = None, - return_numpy: bool = False, data_format: ChannelDimension = None, ): """ @@ -396,8 +429,8 @@ class GotOcr2ImageProcessor(BaseImageProcessor): The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio. Args: - image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`): - The image to be cropped. The image can be a PIL image, NumPy array or PyTorch tensor. + images (`np.ndarray`): + The image to be cropped. min_patches (`int`): The minimum number of patches to be extracted from the image. max_patches (`int`): @@ -406,24 +439,17 @@ class GotOcr2ImageProcessor(BaseImageProcessor): Whether to add a thumbnail image to the list of cropped patches. patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*): The size of the output patches. - return_numpy (`bool`, *optional*, defaults to `False`): - Whether to return the cropped images as NumPy arrays. data_format (`ChannelDimension`, *optional*): The format of the image data. If `None`, the format is inferred from the input image. Returns: List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images. """ - patch_size = patch_size if patch_size is not None else self.size - patch_size = get_size_dict(patch_size, default_to_square=True) - original_size = get_size_dict(image.size, height_width_order=False) - do_rescale = False - if not isinstance(image, PIL.Image.Image): - do_rescale = _rescale_for_pil_conversion(image) - image = to_pil_image(image, do_rescale=do_rescale) - + if data_format is None: + data_format = infer_channel_dimension_format(images) + images = to_channel_dimension_format(images, ChannelDimension.FIRST, data_format) patch_size_height, patch_size_width = patch_size["height"], patch_size["width"] - original_height, original_width = original_size["height"], original_size["width"] + original_height, original_width = images.shape[-2:] # find the closest aspect ratio to the target num_columns, num_rows = get_optimal_tiled_canvas( (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches @@ -435,8 +461,12 @@ class GotOcr2ImageProcessor(BaseImageProcessor): num_blocks = num_columns * num_rows # resize the image so that each patch is of patch_size - resized_image = image.resize((target_width, target_height)) - + resized_image = self.resize( + images, + {"height": target_height, "width": target_width}, + data_format=ChannelDimension.FIRST, + input_data_format=ChannelDimension.FIRST, + ) # split the image into patches processed_images = [] for i in range(num_blocks): @@ -449,33 +479,16 @@ class GotOcr2ImageProcessor(BaseImageProcessor): (row + 1) * patch_size_height, ) # split the image - patch_image = resized_image.crop(box) + patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]] + patch_image = to_channel_dimension_format(patch_image, data_format, ChannelDimension.FIRST) processed_images.append(patch_image) if use_thumbnail and len(processed_images) != 1: - thumbnail_img = image.resize((patch_size_width, patch_size_height)) + thumbnail_img = self.resize( + images, patch_size, data_format=data_format, input_data_format=ChannelDimension.FIRST + ) processed_images.append(thumbnail_img) - if return_numpy: - processed_images_numpy = [] - for processed_image in processed_images: - processed_image = np.array(processed_image) - # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image - # so we need to add it back if necessary. - processed_image = ( - np.expand_dims(processed_image, axis=-1) if processed_image.ndim == 2 else processed_image - ) - # The image is always in channels last format after converting from a PIL image - if data_format is not None: - processed_image = to_channel_dimension_format( - processed_image, data_format, input_channel_dim=ChannelDimension.LAST - ) - # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to - # rescale it back to the original range. - processed_image = self.rescale(processed_image, 1 / 255) if do_rescale else processed_image - processed_images_numpy.append(processed_image) - processed_images = processed_images_numpy - return processed_images diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py new file mode 100644 index 0000000000..5103f73b11 --- /dev/null +++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py @@ -0,0 +1,257 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Got-OCR-2.""" + +from typing import List, Optional, Tuple, Union + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + DefaultFastImageProcessorInitKwargs, + DefaultFastImageProcessorPreprocessKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ImageInput, + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) +from .image_processing_got_ocr2 import get_optimal_tiled_canvas + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class GotOcr2ImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs): + crop_to_patches: Optional[bool] + min_patches: Optional[int] + max_patches: Optional[int] + + +class GotOcr2ImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs): + crop_to_patches: Optional[bool] + min_patches: Optional[int] + max_patches: Optional[int] + + +@add_start_docstrings( + "Constructs a fast GotOcr2 image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ + crop_to_patches (`bool`, *optional*, defaults to `False`): + Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the + `preprocess` method. + min_patches (`int`, *optional*, defaults to 1): + The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method. + max_patches (`int`, *optional*, defaults to 12): + The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method. + """, +) +class GotOcr2ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + size = {"height": 384, "width": 384} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + crop_to_patches = False + min_patches = 1 + max_patches = 12 + valid_init_kwargs = GotOcr2ImageProcessorInitKwargs + valid_preprocess_kwargs = GotOcr2ImageProcessorPreprocessKwargs + + def __init__(self, **kwargs: Unpack[GotOcr2ImageProcessorInitKwargs]): + super().__init__(**kwargs) + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + crop_to_patches (`bool`, *optional*, defaults to `False`): + Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the + `preprocess` method. + min_patches (`int`, *optional*, defaults to 1): + The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method. + max_patches (`int`, *optional*, defaults to 12): + The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is + set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method. + """, + ) + def preprocess(self, images: ImageInput, **kwargs: Unpack[GotOcr2ImageProcessorPreprocessKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def crop_image_to_patches( + self, + images: "torch.Tensor", + min_patches: int, + max_patches: int, + use_thumbnail: bool = True, + patch_size: Union[Tuple, int, dict] = None, + interpolation: Optional["F.InterpolationMode"] = None, + ): + """ + Crop the images to patches and return a list of cropped images. + The number of patches and their grid arrangement are determined by the original image size, + the target patch size and the minimum and maximum number of patches. + The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio. + + Args: + images (`torch.Tensor`): + The images to be cropped. + min_patches (`int`): + The minimum number of patches to be extracted from the image. + max_patches (`int`): + The maximum number of patches to be extracted from the image. + use_thumbnail (`bool`, *optional*, defaults to `True`): + Whether to add a thumbnail image to the list of cropped patches. + patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*): + The size of the output patches. + The format of the image data. If `None`, the format is inferred from the input image. + + Returns: + List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images. + """ + patch_size_height, patch_size_width = patch_size.height, patch_size.width + original_height, original_width = images.shape[-2:] + # find the closest aspect ratio to the target + num_columns, num_rows = get_optimal_tiled_canvas( + (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches + ) + + # calculate the target width and height + target_width = patch_size_width * num_columns + target_height = patch_size_height * num_rows + num_blocks = num_columns * num_rows + + # resize the image so that each patch is of patch_size + resized_image = self.resize( + images, SizeDict(height=target_height, width=target_width), interpolation=interpolation + ) + # split the image into patches + processed_images = [] + for i in range(num_blocks): + column = i % num_columns + row = i // num_columns + box = ( + column * patch_size_width, + row * patch_size_height, + (column + 1) * patch_size_width, + (row + 1) * patch_size_height, + ) + # split the image + patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]] + processed_images.append(patch_image) + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = self.resize(images, patch_size, interpolation=interpolation) + processed_images.append(thumbnail_img) + + processed_images = torch.stack(processed_images, dim=0).transpose(0, 1).contiguous() + + return processed_images + + def _preprocess( + self, + images: List["torch.Tensor"], + do_resize: bool, + size: SizeDict, + crop_to_patches: bool, + min_patches: int, + max_patches: int, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + return_tensors: Optional[Union[str, TensorType]], + ) -> BatchFeature: + if crop_to_patches: + grouped_images, grouped_images_index = group_images_by_shape(images) + processed_images_grouped = {} + num_patches = {} + for shape, stacked_images in grouped_images.items(): + stacked_images = self.crop_image_to_patches( + stacked_images, + min_patches, + max_patches, + patch_size=size, + interpolation=interpolation, + ) + processed_images_grouped[shape] = stacked_images + num_patches[shape] = [stacked_images.shape[1]] * stacked_images.shape[0] + images = reorder_images(processed_images_grouped, grouped_images_index) + images = [image for images_list in images for image in images_list] + num_patches = reorder_images(num_patches, grouped_images_index) + else: + num_patches = [1] * len(images) + + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature( + data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors + ) + + +__all__ = ["GotOcr2ImageProcessorFast"] diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 7fbb0d39ef..918ee2bb03 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -32,11 +32,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ..auto import AutoModelForCausalLM from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index e8b0d770d3..2dc500b8d6 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -14,35 +14,20 @@ # limitations under the License. -from functools import lru_cache from typing import List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.utils.checkpoint -from transformers.models.blip.image_processing_blip import BlipImageProcessor from transformers.models.llava.modeling_llava import ( LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel, ) from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer -from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack -from transformers.tokenization_utils_base import ( - PreTokenizedInput, - TextInput, -) from ...configuration_utils import PretrainedConfig -from ...image_processing_utils import BatchFeature, get_size_dict -from ...image_transforms import ( - _rescale_for_pil_conversion, - to_channel_dimension_format, - to_pil_image, -) -from ...image_utils import ChannelDimension, ImageInput from ...utils import ( add_start_docstrings_to_model_forward, is_vision_available, @@ -53,9 +38,7 @@ from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForCausalLM if is_vision_available(): - import PIL - - from ...image_utils import load_images + pass logger = logging.get_logger(__name__) @@ -246,437 +229,6 @@ class GotOcr2Config(PretrainedConfig): __all__ = ["GotOcr2VisionConfig", "GotOcr2Config"] -class GotOcr2TextKwargs(TextKwargs, total=False): - format: Optional[bool] - - -class GotOcr2ImagesKwargs(ImagesKwargs, total=False): - box: Optional[Union[List, Tuple[float, float], Tuple[float, float, float, float]]] - color: Optional[str] - num_image_tokens: Optional[int] - multi_page: Optional[bool] - crop_to_patches: Optional[bool] - min_patches: Optional[int] - max_patches: Optional[int] - - -class GotOcr2ProcessorKwargs(ProcessingKwargs, total=False): - text_kwargs: GotOcr2TextKwargs - images_kwargs: GotOcr2ImagesKwargs - _defaults = { - "text_kwargs": { - "padding": False, - "format": False, - }, - "images_kwargs": { - "num_image_tokens": 256, - "multi_page": False, - "crop_to_patches": False, - "min_patches": 1, - "max_patches": 12, - }, - } - - -def preprocess_box_annotation(box: Union[List, Tuple], image_size: Tuple[int, int]) -> List: - """ - Convert box annotation to the format [x1, y1, x2, y2] in the range [0, 1000]. - """ - width, height = image_size - if len(box) == 4: - box[0] = int(box[0] / width * 1000) - box[1] = int(box[1] / height * 1000) - box[2] = int(box[2] / width * 1000) - box[3] = int(box[3] / height * 1000) - else: - raise ValueError("Box must be a list or tuple of lists in the form [x1, y1, x2, y2].") - - return list(box) - - -# Similar to image_processing_mllama.get_all_supported_aspect_ratios -@lru_cache(maxsize=10) -def get_all_supported_aspect_ratios(min_image_tiles: int, max_image_tiles: int) -> List[Tuple[int, int]]: - """ - Computes all allowed aspect ratios for a given minimum and maximum number of input tiles. - - This function calculates all possible arrangements of tiles that can be formed - within the constraint of the minimum and maximum number of tiles. Each arrangement is - represented by its aspect ratio (width/height) and the corresponding tile configuration. - - Args: - min_image_tiles (`int`): - The minimum number of tiles allowed. - max_image_tiles (`int`): - The maximum number of tiles allowed. - - Returns: - `List[Tuple[int, int]]`: A list of tuples, each tuple representing a valid (width, height) - configuration in terms of number of tiles. - - Example: - >>> get_all_supported_aspect_ratios(1, 4) - [(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1)] - - """ - aspect_ratios = [] - for width in range(1, max_image_tiles + 1): - for height in range(1, max_image_tiles + 1): - if width * height <= max_image_tiles and width * height >= min_image_tiles: - aspect_ratios.append((width, height)) - - aspect_ratios = sorted(aspect_ratios, key=lambda x: x[0] * x[1]) - - return aspect_ratios - - -@lru_cache(maxsize=100) -def get_optimal_tiled_canvas( - original_image_size: Tuple[int, int], - target_tile_size: Tuple[int, int], - min_image_tiles: int, - max_image_tiles: int, -) -> Tuple[int, int]: - """ - Given a minimum and maximum number of tiles, find the canvas with the closest aspect ratio to the - original image aspect ratio. - In case of tie-breaking condition when two canvases have the same aspect ratio difference, we favor the canvas with - more tiles, until the area covered by the tiles is more than twice the target area, in order to avoid unnecessarily - excessive tiling. - """ - possible_tile_arrangements = get_all_supported_aspect_ratios(min_image_tiles, max_image_tiles) - - original_height, original_width = original_image_size - target_tile_height, target_tile_width = target_tile_size - aspect_ratio = original_width / original_height - area = original_width * original_height - - # find the grid with the best aspect ratio - best_ratio_diff = float("inf") - best_grid = (1, 1) - for grid in possible_tile_arrangements: - grid_aspect_ratio = grid[0] / grid[1] - ratio_diff = abs(aspect_ratio - grid_aspect_ratio) - if ratio_diff < best_ratio_diff: - best_ratio_diff = ratio_diff - best_grid = grid - elif ratio_diff == best_ratio_diff: - # if the aspect ratio difference is the same, we favor the grid with more patches - # until the area covered by the patches is more than twice the original image area - if area > 0.5 * target_tile_height * target_tile_width * grid[0] * grid[1]: - best_grid = grid - - return best_grid - - -class GotOcr2ImageProcessor(BlipImageProcessor): - def crop_image_to_patches( - self, - image: ImageInput, - min_patches: int, - max_patches: int, - use_thumbnail: bool = True, - patch_size: Union[Tuple, int, dict] = None, - return_numpy: bool = False, - data_format: ChannelDimension = None, - ): - """ - Crop the image to patches and return a list of cropped images. - The number of patches and their grid arrangement are determined by the original image size, - the target patch size and the minimum and maximum number of patches. - The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio. - - Args: - image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`): - The image to be cropped. The image can be a PIL image, NumPy array or PyTorch tensor. - min_patches (`int`): - The minimum number of patches to be extracted from the image. - max_patches (`int`): - The maximum number of patches to be extracted from the image. - use_thumbnail (`bool`, *optional*, defaults to `True`): - Whether to add a thumbnail image to the list of cropped patches. - patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*): - The size of the output patches. - return_numpy (`bool`, *optional*, defaults to `False`): - Whether to return the cropped images as NumPy arrays. - data_format (`ChannelDimension`, *optional*): - The format of the image data. If `None`, the format is inferred from the input image. - - Returns: - List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images. - """ - patch_size = patch_size if patch_size is not None else self.size - patch_size = get_size_dict(patch_size, default_to_square=True) - original_size = get_size_dict(image.size, height_width_order=False) - do_rescale = False - if not isinstance(image, PIL.Image.Image): - do_rescale = _rescale_for_pil_conversion(image) - image = to_pil_image(image, do_rescale=do_rescale) - - patch_size_height, patch_size_width = patch_size["height"], patch_size["width"] - original_height, original_width = original_size["height"], original_size["width"] - # find the closest aspect ratio to the target - num_columns, num_rows = get_optimal_tiled_canvas( - (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches - ) - - # calculate the target width and height - target_width = patch_size_width * num_columns - target_height = patch_size_height * num_rows - num_blocks = num_columns * num_rows - - # resize the image so that each patch is of patch_size - resized_image = image.resize((target_width, target_height)) - - # split the image into patches - processed_images = [] - for i in range(num_blocks): - column = i % num_columns - row = i // num_columns - box = ( - column * patch_size_width, - row * patch_size_height, - (column + 1) * patch_size_width, - (row + 1) * patch_size_height, - ) - # split the image - patch_image = resized_image.crop(box) - processed_images.append(patch_image) - - if use_thumbnail and len(processed_images) != 1: - thumbnail_img = image.resize((patch_size_width, patch_size_height)) - processed_images.append(thumbnail_img) - - if return_numpy: - processed_images_numpy = [] - for processed_image in processed_images: - processed_image = np.array(processed_image) - # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image - # so we need to add it back if necessary. - processed_image = ( - np.expand_dims(processed_image, axis=-1) if processed_image.ndim == 2 else processed_image - ) - # The image is always in channels last format after converting from a PIL image - if data_format is not None: - processed_image = to_channel_dimension_format( - processed_image, data_format, input_channel_dim=ChannelDimension.LAST - ) - # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to - # rescale it back to the original range. - processed_image = self.rescale(processed_image, 1 / 255) if do_rescale else processed_image - processed_images_numpy.append(processed_image) - processed_images = processed_images_numpy - - return processed_images - - -class GotOcr2Processor(ProcessorMixin): - r""" - Constructs a GotOcr2 processor which wraps a [`GotOcr2ImageProcessor`] and - [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and - tokenizer functionalities. See the [`~GotOcr2Processor.__call__`] and [`~GotOcr2Processor.decode`] for more information. - Args: - image_processor ([`GotOcr2ImageProcessor`], *optional*): - The image processor is a required input. - tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*): - The tokenizer is a required input. - chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages - in a chat into a tokenizable string. - """ - - attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template"] - image_processor_class = "GotOcr2ImageProcessor" - tokenizer_class = "PreTrainedTokenizerFast" - - def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): - super().__init__(image_processor, tokenizer, chat_template=chat_template) - - self.message_start_token = "<|im_start|>" - self.message_end_token = "<|im_end|>" - self.img_start_token = "" - self.img_end_token = "" - self.img_pad_token = "" - self.system_query = "system\nYou should follow the instructions carefully and explain your answers in detail." - - def _make_list_of_inputs(self, images, text, box, color, multi_page): - if not isinstance(images, (list, tuple)): - images = [images] - if multi_page: - logger.warning("Multi-page inference is enabled but only one image is passed.") - images = [images] - elif isinstance(images[0], (list, tuple)) and not multi_page: - raise ValueError("Nested images are only supported with `multi_page` set to `True`.") - elif not isinstance(images[0], (list, tuple)) and multi_page: - images = [images] - - if isinstance(text, str): - text = [text] - - if not isinstance(box[0], (list, tuple)): - # Use the same box for all images - box = [box for _ in range(len(images))] - if not isinstance(color, (list, tuple)): - color = [color for _ in range(len(images))] - - return images, text, box, color - - def __call__( - self, - images: Optional[ImageInput] = None, - text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, - audio=None, - videos=None, - **kwargs: Unpack[GotOcr2ProcessorKwargs], - ) -> BatchFeature: - """ - Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` - and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if `text` - is not `None`, otherwise encode default OCR queries which depends on the `format`, `box`, `color`, `multi_page` and - `crop_to_patches` arguments. To prepare the vision inputs, this method forwards the `images` and `kwrags` arguments to - GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`. - - Args: - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. Both channels-first and channels-last formats are supported. - text (`str`, `List[str]`, `List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - format (`bool`, *optional*): - If set, will add the format token to the query, and the model will return the OCR result with formatting. - box (`List[float]`, `List[Tuple[float, float]]`, `List[Tuple[float, float, float, float]]`, *optional*): - The box annotation to be added to the query. If a list of floats or a tuple of floats is provided, it - will be interpreted as [x1, y1, x2, y2]. If a list of tuples is provided, each tuple should be in the - form (x1, y1, x2, y2). - color (`str`, *optional*): - The color annotation to be added to the query. The model will return the OCR result within the box with - the specified color. - multi_page (`bool`, *optional*): - If set, will enable multi-page inference. The model will return the OCR result across multiple pages. - crop_to_patches (`bool`, *optional*): - If set, will crop the image to patches. The model will return the OCR result upon the patch reference. - min_patches (`int`, *optional*): - The minimum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to - `True`. - max_patches (`int`, *optional*): - The maximum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to - `True`. - - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - - Returns: - [`BatchFeature`]: A [`BatchFeature`] with the following fields: - - - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when - `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not - `None`). - - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. - """ - - output_kwargs = self._merge_kwargs( - GotOcr2ProcessorKwargs, - tokenizer_init_kwargs=self.tokenizer.init_kwargs, - **kwargs, - ) - format_output = output_kwargs["text_kwargs"].pop("format") - num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens") - box = output_kwargs["images_kwargs"].pop("box", [None]) - color = output_kwargs["images_kwargs"].pop("color", None) - multi_page = output_kwargs["images_kwargs"].pop("multi_page") - crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches") - min_patches = output_kwargs["images_kwargs"].pop("min_patches") - max_patches = output_kwargs["images_kwargs"].pop("max_patches") - - images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page) - - # Load images as we need to know the image size - images = load_images(images) - if text is None: - text = [] - for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)): - if crop_to_patches: - image_group = self.image_processor.crop_image_to_patches( - image_group, - patch_size=output_kwargs["images_kwargs"].get("size"), - min_patches=min_patches, - max_patches=max_patches, - ) - images[index] = image_group - num_images = len(image_group) if (multi_page or crop_to_patches) else 1 - if box_single[0] is not None: - box_single = preprocess_box_annotation(box_single, image_group.size) - query = ( - f"{f'[{color_single}] ' if color_single is not None else ''}" - f"{str(box_single) if box_single[0] is not None else ''} " - "OCR" - f"{' with format' if format_output else ''}" - f"{' across multi pages' if multi_page else ''}" - f"{' upon the patch reference' if crop_to_patches else ''}" - ": " - ) - prompt = ( - self.message_start_token - + self.system_query - + self.message_end_token - + self.message_start_token - + "user\n" - + self.img_start_token - + self.img_pad_token * num_image_tokens * num_images - + self.img_end_token - + "\n" - + query - + self.message_end_token - + self.message_start_token - + "assistant\n" - ) - text.append(prompt) - elif crop_to_patches: - for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)): - image_group = self.image_processor.crop_image_to_patches( - image_group, - patch_size=output_kwargs["images_kwargs"].get("size"), - min_patches=min_patches, - max_patches=max_patches, - ) - images[index] = image_group - - text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) - if multi_page or crop_to_patches: - # flatten images - images = [image for image_group in images for image in image_group] - image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) - - return BatchFeature(data={**text_inputs, **image_inputs}) - - def batch_decode(self, *args, **kwargs): - """ - This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please - refer to the docstring of this method for more information. - """ - return self.tokenizer.batch_decode(*args, **kwargs) - - def decode(self, *args, **kwargs): - """ - This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to - the docstring of this method for more information. - """ - return self.tokenizer.decode(*args, **kwargs) - - @property - def model_input_names(self): - tokenizer_input_names = self.tokenizer.model_input_names - image_processor_input_names = self.image_processor.model_input_names - return list(tokenizer_input_names) + list(image_processor_input_names) - - class GotOcr2MLPBlock(SamMLPBlock): pass @@ -972,8 +524,6 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): __all__ = [ "GotOcr2VisionConfig", "GotOcr2Config", - "GotOcr2Processor", "GotOcr2PreTrainedModel", "GotOcr2ForConditionalGeneration", - "GotOcr2ImageProcessor", ] diff --git a/src/transformers/models/got_ocr2/processing_got_ocr2.py b/src/transformers/models/got_ocr2/processing_got_ocr2.py index 636db765f9..398ec36c9e 100644 --- a/src/transformers/models/got_ocr2/processing_got_ocr2.py +++ b/src/transformers/models/got_ocr2/processing_got_ocr2.py @@ -1,9 +1,3 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_got_ocr2.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. # @@ -22,6 +16,8 @@ from typing import List, Optional, Tuple, Union +import numpy as np + from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput @@ -100,7 +96,7 @@ class GotOcr2Processor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] valid_kwargs = ["chat_template"] - image_processor_class = "GotOcr2ImageProcessor" + image_processor_class = "AutoImageProcessor" tokenizer_class = "PreTrainedTokenizerFast" def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): @@ -205,28 +201,29 @@ class GotOcr2Processor(ProcessorMixin): box = output_kwargs["images_kwargs"].pop("box", [None]) color = output_kwargs["images_kwargs"].pop("color", None) multi_page = output_kwargs["images_kwargs"].pop("multi_page") - crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches") - min_patches = output_kwargs["images_kwargs"].pop("min_patches") - max_patches = output_kwargs["images_kwargs"].pop("max_patches") + crop_to_patches = output_kwargs["images_kwargs"].get("crop_to_patches") images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page) - + if multi_page: + # save the number of pages per batch + num_pages_per_batch = [len(image_group) for image_group in images] + # flatten the list of images + images = [image for image_group in images for image in image_group] + else: + num_pages_per_batch = [1 for _ in range(len(images))] # Load images as we need to know the image size images = load_images(images) + image_sizes = [image.size for image in images] + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + num_patches_array = image_inputs.pop("num_patches") if text is None: text = [] - for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)): - if crop_to_patches: - image_group = self.image_processor.crop_image_to_patches( - image_group, - patch_size=output_kwargs["images_kwargs"].get("size"), - min_patches=min_patches, - max_patches=max_patches, - ) - images[index] = image_group - num_images = len(image_group) if (multi_page or crop_to_patches) else 1 + patch_indices = np.cumsum(num_pages_per_batch) + for index, (num_pages, box_single, color_single) in enumerate(zip(num_pages_per_batch, box, color)): + current_patch_index = patch_indices[index - 1] if index > 0 else 0 + num_patches = sum(num_patches_array[current_patch_index : current_patch_index + num_pages]) if box_single[0] is not None: - box_single = preprocess_box_annotation(box_single, image_group.size) + box_single = preprocess_box_annotation(box_single, image_sizes[index]) query = ( f"{f'[{color_single}] ' if color_single is not None else ''}" f"{str(box_single) if box_single[0] is not None else ''} " @@ -243,7 +240,7 @@ class GotOcr2Processor(ProcessorMixin): + self.message_start_token + "user\n" + self.img_start_token - + self.img_pad_token * num_image_tokens * num_images + + self.img_pad_token * num_image_tokens * num_patches + self.img_end_token + "\n" + query @@ -252,22 +249,8 @@ class GotOcr2Processor(ProcessorMixin): + "assistant\n" ) text.append(prompt) - elif crop_to_patches: - for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)): - image_group = self.image_processor.crop_image_to_patches( - image_group, - patch_size=output_kwargs["images_kwargs"].get("size"), - min_patches=min_patches, - max_patches=max_patches, - ) - images[index] = image_group text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) - if multi_page or crop_to_patches: - # flatten images - images = [image for image_group in images for image in image_group] - image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) - return BatchFeature(data={**text_inputs, **image_inputs}) def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index f393a8f126..de62c4ae7c 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -58,6 +58,13 @@ class DetrImageProcessorFast(metaclass=DummyObject): requires_backends(self, ["torchvision"]) +class GotOcr2ImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + class LlavaImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] diff --git a/tests/models/got_ocr2/test_image_processing_got_ocr2.py b/tests/models/got_ocr2/test_image_processing_got_ocr2.py index c4e75feee6..93cd347dea 100644 --- a/tests/models/got_ocr2/test_image_processing_got_ocr2.py +++ b/tests/models/got_ocr2/test_image_processing_got_ocr2.py @@ -16,15 +16,22 @@ import unittest +from transformers.image_utils import SizeDict from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs +if is_torch_available(): + import torch + if is_vision_available(): from transformers import GotOcr2ImageProcessor + if is_torchvision_available(): + from transformers import GotOcr2ImageProcessorFast + class GotOcr2ImageProcessingTester(unittest.TestCase): def __init__( @@ -89,6 +96,7 @@ class GotOcr2ImageProcessingTester(unittest.TestCase): @require_vision class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = GotOcr2ImageProcessor if is_vision_available() else None + fast_image_processing_class = GotOcr2ImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -99,17 +107,72 @@ class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processor = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processor, "do_resize")) - self.assertTrue(hasattr(image_processor, "size")) - self.assertTrue(hasattr(image_processor, "do_normalize")) - self.assertTrue(hasattr(image_processor, "image_mean")) - self.assertTrue(hasattr(image_processor, "image_std")) - self.assertTrue(hasattr(image_processor, "do_convert_rgb")) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_resize")) + self.assertTrue(hasattr(image_processor, "size")) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "image_mean")) + self.assertTrue(hasattr(image_processor, "image_std")) + self.assertTrue(hasattr(image_processor, "do_convert_rgb")) + + def test_slow_fast_equivalence_crop_to_patches(self): + dummy_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)[0] + + image_processor_slow = self.image_processing_class(**self.image_processor_dict, crop_to_patches=True) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, crop_to_patches=True) + + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") + + torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches) + self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3 + ) + + def test_slow_fast_equivalence_batched_crop_to_patches(self): + # Prepare image inputs so that we have two groups of images with equal resolution with a group of images with + # different resolutions in between + dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + dummy_images += self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + dummy_images += self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + + image_processor_slow = self.image_processing_class(**self.image_processor_dict, crop_to_patches=True) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, crop_to_patches=True) + + encoding_slow = image_processor_slow(dummy_images, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, return_tensors="pt") + + torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches) + self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3 + ) def test_crop_to_patches(self): - image_processor = self.image_processing_class(**self.image_processor_dict) - image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)[0] - processed_images = image_processor.crop_image_to_patches(image, 1, 6, use_thumbnail=True) + # test slow image processor + image_processor = self.image_processor_list[0](**self.image_processor_dict) + image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)[0] + processed_images = image_processor.crop_image_to_patches( + image, + min_patches=1, + max_patches=6, + use_thumbnail=True, + patch_size={"height": 20, "width": 20}, + ) self.assertEqual(len(processed_images), 5) - self.assertEqual(processed_images[0].size, (20, 20)) + self.assertEqual(processed_images[0].shape[:2], (20, 20)) + + # test fast image processor (process batch) + image_processor = self.image_processor_list[1](**self.image_processor_dict) + image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)[0] + processed_images = image_processor.crop_image_to_patches( + image.unsqueeze(0), + min_patches=1, + max_patches=6, + use_thumbnail=True, + patch_size=SizeDict(height=20, width=20), + ) + self.assertEqual(len(processed_images[0]), 5) + self.assertEqual(processed_images.shape[-2:], (20, 20))