diff --git a/docs/source/en/model_doc/got_ocr2.md b/docs/source/en/model_doc/got_ocr2.md
index a560f78269..33607033b4 100644
--- a/docs/source/en/model_doc/got_ocr2.md
+++ b/docs/source/en/model_doc/got_ocr2.md
@@ -44,13 +44,14 @@ The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
->>> inputs = processor(image, return_tensors="pt").to(device)
+>>> inputs = processor(image, return_tensors="pt", device=device).to(device)
>>> generate_ids = model.generate(
... **inputs,
@@ -68,15 +69,16 @@ The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
->>> inputs = processor([image1, image2], return_tensors="pt").to(device)
+>>> inputs = processor([image1, image2], return_tensors="pt", device=device).to(device)
>>> generate_ids = model.generate(
... **inputs,
@@ -96,13 +98,14 @@ GOT-OCR2 can also generate formatted text, such as markdown or LaTeX. Here is an
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/latex.png"
->>> inputs = processor(image, return_tensors="pt", format=True).to(device)
+>>> inputs = processor(image, return_tensors="pt", format=True, device=device).to(device)
>>> generate_ids = model.generate(
... **inputs,
@@ -124,14 +127,15 @@ Here is an example of how to process multiple pages at once:
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page1.png"
>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page2.png"
->>> inputs = processor([image1, image2], return_tensors="pt", multi_page=True, format=True).to(device)
+>>> inputs = processor([image1, image2], return_tensors="pt", multi_page=True, format=True, device=device).to(device)
>>> generate_ids = model.generate(
... **inputs,
@@ -153,13 +157,14 @@ Here is an example of how to process cropped patches:
```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", torch_dtype=torch.bfloat16, device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
->>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3).to(device)
+>>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3, device=device).to(device)
>>> generate_ids = model.generate(
... **inputs,
@@ -179,13 +184,14 @@ GOT supports interactive OCR, where the user can specify the region to be recogn
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
->>> inputs = processor(image, return_tensors="pt", color="green").to(device) # or box=[x1, y1, x2, y2] for coordinates (image pixels)
+>>> inputs = processor(image, return_tensors="pt", color="green", device=device).to(device) # or box=[x1, y1, x2, y2] for coordinates (image pixels)
>>> generate_ids = model.generate(
... **inputs,
@@ -206,14 +212,15 @@ Here is an example of how to process sheet music:
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
+>>> import torch
>>> import verovio
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
->>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")
+>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", use_fast=True)
>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/sheet_music.png"
->>> inputs = processor(image, return_tensors="pt", format=True).to(device)
+>>> inputs = processor(image, return_tensors="pt", format=True, device=device).to(device)
>>> generate_ids = model.generate(
... **inputs,
@@ -258,6 +265,10 @@ alt="drawing" width="600"/>
[[autodoc]] GotOcr2ImageProcessor
+## GotOcr2ImageProcessorFast
+
+[[autodoc]] GotOcr2ImageProcessorFast
+
## GotOcr2Processor
[[autodoc]] GotOcr2Processor
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index ed26829010..f05a7b3b2c 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1330,6 +1330,7 @@ else:
_import_structure["models.deit"].append("DeiTImageProcessorFast")
_import_structure["models.depth_pro"].append("DepthProImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
+ _import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
@@ -6526,6 +6527,7 @@ if TYPE_CHECKING:
from .models.deit import DeiTImageProcessorFast
from .models.depth_pro import DepthProImageProcessorFast
from .models.detr import DetrImageProcessorFast
+ from .models.got_ocr2 import GotOcr2ImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 4942b8f39b..180d156359 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -88,7 +88,7 @@ else:
("fuyu", ("FuyuImageProcessor",)),
("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("glpn", ("GLPNImageProcessor",)),
- ("got_ocr2", ("GotOcr2ImageProcessor",)),
+ ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
("grounding-dino", ("GroundingDinoImageProcessor",)),
("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("hiera", ("BitImageProcessor",)),
diff --git a/src/transformers/models/got_ocr2/__init__.py b/src/transformers/models/got_ocr2/__init__.py
index 071a7ea740..00b6ccc53f 100644
--- a/src/transformers/models/got_ocr2/__init__.py
+++ b/src/transformers/models/got_ocr2/__init__.py
@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_got_ocr2 import *
from .image_processing_got_ocr2 import *
+ from .image_processing_got_ocr2_fast import *
from .modeling_got_ocr2 import *
from .processing_got_ocr2 import *
diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py
index 7f7a0d7ae4..d052f4a543 100644
--- a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py
@@ -1,9 +1,3 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_got_ocr2.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
@@ -18,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+"""Image processor class for Got-OCR-2."""
from functools import lru_cache
from typing import Dict, List, Optional, Tuple, Union
@@ -27,11 +21,9 @@ import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
- _rescale_for_pil_conversion,
convert_to_rgb,
resize,
to_channel_dimension_format,
- to_pil_image,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
@@ -142,6 +134,15 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
+ crop_to_patches (`bool`, *optional*, defaults to `False`):
+ Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
+ `preprocess` method.
+ min_patches (`int`, *optional*, defaults to 1):
+ The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
+ max_patches (`int`, *optional*, defaults to 12):
+ The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
@@ -172,6 +173,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
self,
do_resize: bool = True,
size: Dict[str, int] = None,
+ crop_to_patches: bool = False,
+ min_patches: int = 1,
+ max_patches: int = 12,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
@@ -187,6 +191,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
self.do_resize = do_resize
self.size = size
+ self.crop_to_patches = crop_to_patches
+ self.min_patches = min_patches
+ self.max_patches = max_patches
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
@@ -249,6 +256,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
+ crop_to_patches: Optional[bool] = None,
+ min_patches: Optional[int] = None,
+ max_patches: Optional[int] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
@@ -274,6 +284,14 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
`size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
+ crop_to_patches (`bool`, *optional*, defaults to `self.crop_to_patches`):
+ Whether to crop the image to patches.
+ min_patches (`int`, *optional*, defaults to `self.min_patches`):
+ The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`.
+ max_patches (`int`, *optional*, defaults to `self.max_patches`):
+ The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
@@ -308,6 +326,9 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
+ crop_to_patches = crop_to_patches if crop_to_patches is not None else self.crop_to_patches
+ min_patches = min_patches if min_patches is not None else self.min_patches
+ max_patches = max_patches if max_patches is not None else self.max_patches
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
@@ -353,40 +374,52 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
- if do_resize:
+ if crop_to_patches and max_patches > 1:
images = [
- self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+ self.crop_image_to_patches(
+ image,
+ min_patches=min_patches,
+ max_patches=max_patches,
+ patch_size=size,
+ data_format=input_data_format,
+ )
for image in images
]
+ num_patches = np.array([len(image) for image in images])
+ images = [image for images_list in images for image in images_list]
+ else:
+ num_patches = np.array([1] * len(images))
- if do_rescale:
- images = [
- self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
- for image in images
- ]
+ for i, image in enumerate(images):
+ if do_resize:
+ images[i] = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
- if do_normalize:
- images = [
- self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
- for image in images
- ]
+ if do_rescale:
+ images[i] = self.rescale(image=images[i], scale=rescale_factor, input_data_format=input_data_format)
- images = [
- to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
- ]
+ if do_normalize:
+ images[i] = self.normalize(
+ image=images[i],
+ mean=image_mean,
+ std=image_std,
+ input_data_format=input_data_format,
+ )
- encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+ images[i] = to_channel_dimension_format(images[i], data_format, input_channel_dim=input_data_format)
+
+ encoded_outputs = BatchFeature(
+ data={"pixel_values": images, "num_patches": num_patches}, tensor_type=return_tensors
+ )
return encoded_outputs
def crop_image_to_patches(
self,
- image: ImageInput,
+ images: np.ndarray,
min_patches: int,
max_patches: int,
use_thumbnail: bool = True,
patch_size: Union[Tuple, int, dict] = None,
- return_numpy: bool = False,
data_format: ChannelDimension = None,
):
"""
@@ -396,8 +429,8 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
Args:
- image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`):
- The image to be cropped. The image can be a PIL image, NumPy array or PyTorch tensor.
+ images (`np.ndarray`):
+ The image to be cropped.
min_patches (`int`):
The minimum number of patches to be extracted from the image.
max_patches (`int`):
@@ -406,24 +439,17 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
Whether to add a thumbnail image to the list of cropped patches.
patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
The size of the output patches.
- return_numpy (`bool`, *optional*, defaults to `False`):
- Whether to return the cropped images as NumPy arrays.
data_format (`ChannelDimension`, *optional*):
The format of the image data. If `None`, the format is inferred from the input image.
Returns:
List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images.
"""
- patch_size = patch_size if patch_size is not None else self.size
- patch_size = get_size_dict(patch_size, default_to_square=True)
- original_size = get_size_dict(image.size, height_width_order=False)
- do_rescale = False
- if not isinstance(image, PIL.Image.Image):
- do_rescale = _rescale_for_pil_conversion(image)
- image = to_pil_image(image, do_rescale=do_rescale)
-
+ if data_format is None:
+ data_format = infer_channel_dimension_format(images)
+ images = to_channel_dimension_format(images, ChannelDimension.FIRST, data_format)
patch_size_height, patch_size_width = patch_size["height"], patch_size["width"]
- original_height, original_width = original_size["height"], original_size["width"]
+ original_height, original_width = images.shape[-2:]
# find the closest aspect ratio to the target
num_columns, num_rows = get_optimal_tiled_canvas(
(original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
@@ -435,8 +461,12 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
num_blocks = num_columns * num_rows
# resize the image so that each patch is of patch_size
- resized_image = image.resize((target_width, target_height))
-
+ resized_image = self.resize(
+ images,
+ {"height": target_height, "width": target_width},
+ data_format=ChannelDimension.FIRST,
+ input_data_format=ChannelDimension.FIRST,
+ )
# split the image into patches
processed_images = []
for i in range(num_blocks):
@@ -449,33 +479,16 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
(row + 1) * patch_size_height,
)
# split the image
- patch_image = resized_image.crop(box)
+ patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
+ patch_image = to_channel_dimension_format(patch_image, data_format, ChannelDimension.FIRST)
processed_images.append(patch_image)
if use_thumbnail and len(processed_images) != 1:
- thumbnail_img = image.resize((patch_size_width, patch_size_height))
+ thumbnail_img = self.resize(
+ images, patch_size, data_format=data_format, input_data_format=ChannelDimension.FIRST
+ )
processed_images.append(thumbnail_img)
- if return_numpy:
- processed_images_numpy = []
- for processed_image in processed_images:
- processed_image = np.array(processed_image)
- # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
- # so we need to add it back if necessary.
- processed_image = (
- np.expand_dims(processed_image, axis=-1) if processed_image.ndim == 2 else processed_image
- )
- # The image is always in channels last format after converting from a PIL image
- if data_format is not None:
- processed_image = to_channel_dimension_format(
- processed_image, data_format, input_channel_dim=ChannelDimension.LAST
- )
- # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
- # rescale it back to the original range.
- processed_image = self.rescale(processed_image, 1 / 255) if do_rescale else processed_image
- processed_images_numpy.append(processed_image)
- processed_images = processed_images_numpy
-
return processed_images
diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py
new file mode 100644
index 0000000000..5103f73b11
--- /dev/null
+++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py
@@ -0,0 +1,257 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for Got-OCR-2."""
+
+from typing import List, Optional, Tuple, Union
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorInitKwargs,
+ DefaultFastImageProcessorPreprocessKwargs,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ add_start_docstrings,
+ is_torch_available,
+ is_torchvision_available,
+ is_torchvision_v2_available,
+)
+from .image_processing_got_ocr2 import get_optimal_tiled_canvas
+
+
+if is_torch_available():
+ import torch
+
+if is_torchvision_available():
+ if is_torchvision_v2_available():
+ from torchvision.transforms.v2 import functional as F
+ else:
+ from torchvision.transforms import functional as F
+
+
+class GotOcr2ImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
+ crop_to_patches: Optional[bool]
+ min_patches: Optional[int]
+ max_patches: Optional[int]
+
+
+class GotOcr2ImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
+ crop_to_patches: Optional[bool]
+ min_patches: Optional[int]
+ max_patches: Optional[int]
+
+
+@add_start_docstrings(
+ "Constructs a fast GotOcr2 image processor.",
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+ """
+ crop_to_patches (`bool`, *optional*, defaults to `False`):
+ Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
+ `preprocess` method.
+ min_patches (`int`, *optional*, defaults to 1):
+ The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
+ max_patches (`int`, *optional*, defaults to 12):
+ The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
+ """,
+)
+class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ size = {"height": 384, "width": 384}
+ do_resize = True
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ crop_to_patches = False
+ min_patches = 1
+ max_patches = 12
+ valid_init_kwargs = GotOcr2ImageProcessorInitKwargs
+ valid_preprocess_kwargs = GotOcr2ImageProcessorPreprocessKwargs
+
+ def __init__(self, **kwargs: Unpack[GotOcr2ImageProcessorInitKwargs]):
+ super().__init__(**kwargs)
+
+ @add_start_docstrings(
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+ """
+ crop_to_patches (`bool`, *optional*, defaults to `False`):
+ Whether to crop the image to patches. Can be overridden by the `crop_to_patches` parameter in the
+ `preprocess` method.
+ min_patches (`int`, *optional*, defaults to 1):
+ The minimum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `min_patches` parameter in the `preprocess` method.
+ max_patches (`int`, *optional*, defaults to 12):
+ The maximum number of patches to be extracted from the image. Only has an effect if `crop_to_patches` is
+ set to `True`. Can be overridden by the `max_patches` parameter in the `preprocess` method.
+ """,
+ )
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[GotOcr2ImageProcessorPreprocessKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def crop_image_to_patches(
+ self,
+ images: "torch.Tensor",
+ min_patches: int,
+ max_patches: int,
+ use_thumbnail: bool = True,
+ patch_size: Union[Tuple, int, dict] = None,
+ interpolation: Optional["F.InterpolationMode"] = None,
+ ):
+ """
+ Crop the images to patches and return a list of cropped images.
+ The number of patches and their grid arrangement are determined by the original image size,
+ the target patch size and the minimum and maximum number of patches.
+ The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
+
+ Args:
+ images (`torch.Tensor`):
+ The images to be cropped.
+ min_patches (`int`):
+ The minimum number of patches to be extracted from the image.
+ max_patches (`int`):
+ The maximum number of patches to be extracted from the image.
+ use_thumbnail (`bool`, *optional*, defaults to `True`):
+ Whether to add a thumbnail image to the list of cropped patches.
+ patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
+ The size of the output patches.
+ The format of the image data. If `None`, the format is inferred from the input image.
+
+ Returns:
+ List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images.
+ """
+ patch_size_height, patch_size_width = patch_size.height, patch_size.width
+ original_height, original_width = images.shape[-2:]
+ # find the closest aspect ratio to the target
+ num_columns, num_rows = get_optimal_tiled_canvas(
+ (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
+ )
+
+ # calculate the target width and height
+ target_width = patch_size_width * num_columns
+ target_height = patch_size_height * num_rows
+ num_blocks = num_columns * num_rows
+
+ # resize the image so that each patch is of patch_size
+ resized_image = self.resize(
+ images, SizeDict(height=target_height, width=target_width), interpolation=interpolation
+ )
+ # split the image into patches
+ processed_images = []
+ for i in range(num_blocks):
+ column = i % num_columns
+ row = i // num_columns
+ box = (
+ column * patch_size_width,
+ row * patch_size_height,
+ (column + 1) * patch_size_width,
+ (row + 1) * patch_size_height,
+ )
+ # split the image
+ patch_image = resized_image[..., box[1] : box[3], box[0] : box[2]]
+ processed_images.append(patch_image)
+
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = self.resize(images, patch_size, interpolation=interpolation)
+ processed_images.append(thumbnail_img)
+
+ processed_images = torch.stack(processed_images, dim=0).transpose(0, 1).contiguous()
+
+ return processed_images
+
+ def _preprocess(
+ self,
+ images: List["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ crop_to_patches: bool,
+ min_patches: int,
+ max_patches: int,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Optional[Union[float, List[float]]],
+ image_std: Optional[Union[float, List[float]]],
+ return_tensors: Optional[Union[str, TensorType]],
+ ) -> BatchFeature:
+ if crop_to_patches:
+ grouped_images, grouped_images_index = group_images_by_shape(images)
+ processed_images_grouped = {}
+ num_patches = {}
+ for shape, stacked_images in grouped_images.items():
+ stacked_images = self.crop_image_to_patches(
+ stacked_images,
+ min_patches,
+ max_patches,
+ patch_size=size,
+ interpolation=interpolation,
+ )
+ processed_images_grouped[shape] = stacked_images
+ num_patches[shape] = [stacked_images.shape[1]] * stacked_images.shape[0]
+ images = reorder_images(processed_images_grouped, grouped_images_index)
+ images = [image for images_list in images for image in images_list]
+ num_patches = reorder_images(num_patches, grouped_images_index)
+ else:
+ num_patches = [1] * len(images)
+
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(
+ data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
+ )
+
+
+__all__ = ["GotOcr2ImageProcessorFast"]
diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
index 7fbb0d39ef..918ee2bb03 100644
--- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -32,11 +32,7 @@ from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
-from ...utils import (
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- replace_return_docstrings,
-)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ..auto import AutoModelForCausalLM
from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig
diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py
index e8b0d770d3..2dc500b8d6 100644
--- a/src/transformers/models/got_ocr2/modular_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py
@@ -14,35 +14,20 @@
# limitations under the License.
-from functools import lru_cache
from typing import List, Optional, Tuple, Union
-import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
-from transformers.models.blip.image_processing_blip import BlipImageProcessor
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaPreTrainedModel,
)
from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer
-from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
-from transformers.tokenization_utils_base import (
- PreTokenizedInput,
- TextInput,
-)
from ...configuration_utils import PretrainedConfig
-from ...image_processing_utils import BatchFeature, get_size_dict
-from ...image_transforms import (
- _rescale_for_pil_conversion,
- to_channel_dimension_format,
- to_pil_image,
-)
-from ...image_utils import ChannelDimension, ImageInput
from ...utils import (
add_start_docstrings_to_model_forward,
is_vision_available,
@@ -53,9 +38,7 @@ from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForCausalLM
if is_vision_available():
- import PIL
-
- from ...image_utils import load_images
+ pass
logger = logging.get_logger(__name__)
@@ -246,437 +229,6 @@ class GotOcr2Config(PretrainedConfig):
__all__ = ["GotOcr2VisionConfig", "GotOcr2Config"]
-class GotOcr2TextKwargs(TextKwargs, total=False):
- format: Optional[bool]
-
-
-class GotOcr2ImagesKwargs(ImagesKwargs, total=False):
- box: Optional[Union[List, Tuple[float, float], Tuple[float, float, float, float]]]
- color: Optional[str]
- num_image_tokens: Optional[int]
- multi_page: Optional[bool]
- crop_to_patches: Optional[bool]
- min_patches: Optional[int]
- max_patches: Optional[int]
-
-
-class GotOcr2ProcessorKwargs(ProcessingKwargs, total=False):
- text_kwargs: GotOcr2TextKwargs
- images_kwargs: GotOcr2ImagesKwargs
- _defaults = {
- "text_kwargs": {
- "padding": False,
- "format": False,
- },
- "images_kwargs": {
- "num_image_tokens": 256,
- "multi_page": False,
- "crop_to_patches": False,
- "min_patches": 1,
- "max_patches": 12,
- },
- }
-
-
-def preprocess_box_annotation(box: Union[List, Tuple], image_size: Tuple[int, int]) -> List:
- """
- Convert box annotation to the format [x1, y1, x2, y2] in the range [0, 1000].
- """
- width, height = image_size
- if len(box) == 4:
- box[0] = int(box[0] / width * 1000)
- box[1] = int(box[1] / height * 1000)
- box[2] = int(box[2] / width * 1000)
- box[3] = int(box[3] / height * 1000)
- else:
- raise ValueError("Box must be a list or tuple of lists in the form [x1, y1, x2, y2].")
-
- return list(box)
-
-
-# Similar to image_processing_mllama.get_all_supported_aspect_ratios
-@lru_cache(maxsize=10)
-def get_all_supported_aspect_ratios(min_image_tiles: int, max_image_tiles: int) -> List[Tuple[int, int]]:
- """
- Computes all allowed aspect ratios for a given minimum and maximum number of input tiles.
-
- This function calculates all possible arrangements of tiles that can be formed
- within the constraint of the minimum and maximum number of tiles. Each arrangement is
- represented by its aspect ratio (width/height) and the corresponding tile configuration.
-
- Args:
- min_image_tiles (`int`):
- The minimum number of tiles allowed.
- max_image_tiles (`int`):
- The maximum number of tiles allowed.
-
- Returns:
- `List[Tuple[int, int]]`: A list of tuples, each tuple representing a valid (width, height)
- configuration in terms of number of tiles.
-
- Example:
- >>> get_all_supported_aspect_ratios(1, 4)
- [(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1)]
-
- """
- aspect_ratios = []
- for width in range(1, max_image_tiles + 1):
- for height in range(1, max_image_tiles + 1):
- if width * height <= max_image_tiles and width * height >= min_image_tiles:
- aspect_ratios.append((width, height))
-
- aspect_ratios = sorted(aspect_ratios, key=lambda x: x[0] * x[1])
-
- return aspect_ratios
-
-
-@lru_cache(maxsize=100)
-def get_optimal_tiled_canvas(
- original_image_size: Tuple[int, int],
- target_tile_size: Tuple[int, int],
- min_image_tiles: int,
- max_image_tiles: int,
-) -> Tuple[int, int]:
- """
- Given a minimum and maximum number of tiles, find the canvas with the closest aspect ratio to the
- original image aspect ratio.
- In case of tie-breaking condition when two canvases have the same aspect ratio difference, we favor the canvas with
- more tiles, until the area covered by the tiles is more than twice the target area, in order to avoid unnecessarily
- excessive tiling.
- """
- possible_tile_arrangements = get_all_supported_aspect_ratios(min_image_tiles, max_image_tiles)
-
- original_height, original_width = original_image_size
- target_tile_height, target_tile_width = target_tile_size
- aspect_ratio = original_width / original_height
- area = original_width * original_height
-
- # find the grid with the best aspect ratio
- best_ratio_diff = float("inf")
- best_grid = (1, 1)
- for grid in possible_tile_arrangements:
- grid_aspect_ratio = grid[0] / grid[1]
- ratio_diff = abs(aspect_ratio - grid_aspect_ratio)
- if ratio_diff < best_ratio_diff:
- best_ratio_diff = ratio_diff
- best_grid = grid
- elif ratio_diff == best_ratio_diff:
- # if the aspect ratio difference is the same, we favor the grid with more patches
- # until the area covered by the patches is more than twice the original image area
- if area > 0.5 * target_tile_height * target_tile_width * grid[0] * grid[1]:
- best_grid = grid
-
- return best_grid
-
-
-class GotOcr2ImageProcessor(BlipImageProcessor):
- def crop_image_to_patches(
- self,
- image: ImageInput,
- min_patches: int,
- max_patches: int,
- use_thumbnail: bool = True,
- patch_size: Union[Tuple, int, dict] = None,
- return_numpy: bool = False,
- data_format: ChannelDimension = None,
- ):
- """
- Crop the image to patches and return a list of cropped images.
- The number of patches and their grid arrangement are determined by the original image size,
- the target patch size and the minimum and maximum number of patches.
- The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio.
-
- Args:
- image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`):
- The image to be cropped. The image can be a PIL image, NumPy array or PyTorch tensor.
- min_patches (`int`):
- The minimum number of patches to be extracted from the image.
- max_patches (`int`):
- The maximum number of patches to be extracted from the image.
- use_thumbnail (`bool`, *optional*, defaults to `True`):
- Whether to add a thumbnail image to the list of cropped patches.
- patch_size (`int`, `Tuple[int, int]`, `dict`, *optional*):
- The size of the output patches.
- return_numpy (`bool`, *optional*, defaults to `False`):
- Whether to return the cropped images as NumPy arrays.
- data_format (`ChannelDimension`, *optional*):
- The format of the image data. If `None`, the format is inferred from the input image.
-
- Returns:
- List[`PIL.Image.Image`] or List[np.ndarray]: The list of cropped images.
- """
- patch_size = patch_size if patch_size is not None else self.size
- patch_size = get_size_dict(patch_size, default_to_square=True)
- original_size = get_size_dict(image.size, height_width_order=False)
- do_rescale = False
- if not isinstance(image, PIL.Image.Image):
- do_rescale = _rescale_for_pil_conversion(image)
- image = to_pil_image(image, do_rescale=do_rescale)
-
- patch_size_height, patch_size_width = patch_size["height"], patch_size["width"]
- original_height, original_width = original_size["height"], original_size["width"]
- # find the closest aspect ratio to the target
- num_columns, num_rows = get_optimal_tiled_canvas(
- (original_height, original_width), (patch_size_height, patch_size_width), min_patches, max_patches
- )
-
- # calculate the target width and height
- target_width = patch_size_width * num_columns
- target_height = patch_size_height * num_rows
- num_blocks = num_columns * num_rows
-
- # resize the image so that each patch is of patch_size
- resized_image = image.resize((target_width, target_height))
-
- # split the image into patches
- processed_images = []
- for i in range(num_blocks):
- column = i % num_columns
- row = i // num_columns
- box = (
- column * patch_size_width,
- row * patch_size_height,
- (column + 1) * patch_size_width,
- (row + 1) * patch_size_height,
- )
- # split the image
- patch_image = resized_image.crop(box)
- processed_images.append(patch_image)
-
- if use_thumbnail and len(processed_images) != 1:
- thumbnail_img = image.resize((patch_size_width, patch_size_height))
- processed_images.append(thumbnail_img)
-
- if return_numpy:
- processed_images_numpy = []
- for processed_image in processed_images:
- processed_image = np.array(processed_image)
- # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
- # so we need to add it back if necessary.
- processed_image = (
- np.expand_dims(processed_image, axis=-1) if processed_image.ndim == 2 else processed_image
- )
- # The image is always in channels last format after converting from a PIL image
- if data_format is not None:
- processed_image = to_channel_dimension_format(
- processed_image, data_format, input_channel_dim=ChannelDimension.LAST
- )
- # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
- # rescale it back to the original range.
- processed_image = self.rescale(processed_image, 1 / 255) if do_rescale else processed_image
- processed_images_numpy.append(processed_image)
- processed_images = processed_images_numpy
-
- return processed_images
-
-
-class GotOcr2Processor(ProcessorMixin):
- r"""
- Constructs a GotOcr2 processor which wraps a [`GotOcr2ImageProcessor`] and
- [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
- tokenizer functionalities. See the [`~GotOcr2Processor.__call__`] and [`~GotOcr2Processor.decode`] for more information.
- Args:
- image_processor ([`GotOcr2ImageProcessor`], *optional*):
- The image processor is a required input.
- tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
- The tokenizer is a required input.
- chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
- in a chat into a tokenizable string.
- """
-
- attributes = ["image_processor", "tokenizer"]
- valid_kwargs = ["chat_template"]
- image_processor_class = "GotOcr2ImageProcessor"
- tokenizer_class = "PreTrainedTokenizerFast"
-
- def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
- super().__init__(image_processor, tokenizer, chat_template=chat_template)
-
- self.message_start_token = "<|im_start|>"
- self.message_end_token = "<|im_end|>"
- self.img_start_token = "
"
- self.img_end_token = ""
- self.img_pad_token = ""
- self.system_query = "system\nYou should follow the instructions carefully and explain your answers in detail."
-
- def _make_list_of_inputs(self, images, text, box, color, multi_page):
- if not isinstance(images, (list, tuple)):
- images = [images]
- if multi_page:
- logger.warning("Multi-page inference is enabled but only one image is passed.")
- images = [images]
- elif isinstance(images[0], (list, tuple)) and not multi_page:
- raise ValueError("Nested images are only supported with `multi_page` set to `True`.")
- elif not isinstance(images[0], (list, tuple)) and multi_page:
- images = [images]
-
- if isinstance(text, str):
- text = [text]
-
- if not isinstance(box[0], (list, tuple)):
- # Use the same box for all images
- box = [box for _ in range(len(images))]
- if not isinstance(color, (list, tuple)):
- color = [color for _ in range(len(images))]
-
- return images, text, box, color
-
- def __call__(
- self,
- images: Optional[ImageInput] = None,
- text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
- audio=None,
- videos=None,
- **kwargs: Unpack[GotOcr2ProcessorKwargs],
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
- and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if `text`
- is not `None`, otherwise encode default OCR queries which depends on the `format`, `box`, `color`, `multi_page` and
- `crop_to_patches` arguments. To prepare the vision inputs, this method forwards the `images` and `kwrags` arguments to
- GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`.
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. Both channels-first and channels-last formats are supported.
- text (`str`, `List[str]`, `List[List[str]]`):
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
- format (`bool`, *optional*):
- If set, will add the format token to the query, and the model will return the OCR result with formatting.
- box (`List[float]`, `List[Tuple[float, float]]`, `List[Tuple[float, float, float, float]]`, *optional*):
- The box annotation to be added to the query. If a list of floats or a tuple of floats is provided, it
- will be interpreted as [x1, y1, x2, y2]. If a list of tuples is provided, each tuple should be in the
- form (x1, y1, x2, y2).
- color (`str`, *optional*):
- The color annotation to be added to the query. The model will return the OCR result within the box with
- the specified color.
- multi_page (`bool`, *optional*):
- If set, will enable multi-page inference. The model will return the OCR result across multiple pages.
- crop_to_patches (`bool`, *optional*):
- If set, will crop the image to patches. The model will return the OCR result upon the patch reference.
- min_patches (`int`, *optional*):
- The minimum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
- `True`.
- max_patches (`int`, *optional*):
- The maximum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
- `True`.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors of a particular framework. Acceptable values are:
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
- `None`).
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- """
-
- output_kwargs = self._merge_kwargs(
- GotOcr2ProcessorKwargs,
- tokenizer_init_kwargs=self.tokenizer.init_kwargs,
- **kwargs,
- )
- format_output = output_kwargs["text_kwargs"].pop("format")
- num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens")
- box = output_kwargs["images_kwargs"].pop("box", [None])
- color = output_kwargs["images_kwargs"].pop("color", None)
- multi_page = output_kwargs["images_kwargs"].pop("multi_page")
- crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches")
- min_patches = output_kwargs["images_kwargs"].pop("min_patches")
- max_patches = output_kwargs["images_kwargs"].pop("max_patches")
-
- images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page)
-
- # Load images as we need to know the image size
- images = load_images(images)
- if text is None:
- text = []
- for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
- if crop_to_patches:
- image_group = self.image_processor.crop_image_to_patches(
- image_group,
- patch_size=output_kwargs["images_kwargs"].get("size"),
- min_patches=min_patches,
- max_patches=max_patches,
- )
- images[index] = image_group
- num_images = len(image_group) if (multi_page or crop_to_patches) else 1
- if box_single[0] is not None:
- box_single = preprocess_box_annotation(box_single, image_group.size)
- query = (
- f"{f'[{color_single}] ' if color_single is not None else ''}"
- f"{str(box_single) if box_single[0] is not None else ''} "
- "OCR"
- f"{' with format' if format_output else ''}"
- f"{' across multi pages' if multi_page else ''}"
- f"{' upon the patch reference' if crop_to_patches else ''}"
- ": "
- )
- prompt = (
- self.message_start_token
- + self.system_query
- + self.message_end_token
- + self.message_start_token
- + "user\n"
- + self.img_start_token
- + self.img_pad_token * num_image_tokens * num_images
- + self.img_end_token
- + "\n"
- + query
- + self.message_end_token
- + self.message_start_token
- + "assistant\n"
- )
- text.append(prompt)
- elif crop_to_patches:
- for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
- image_group = self.image_processor.crop_image_to_patches(
- image_group,
- patch_size=output_kwargs["images_kwargs"].get("size"),
- min_patches=min_patches,
- max_patches=max_patches,
- )
- images[index] = image_group
-
- text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
- if multi_page or crop_to_patches:
- # flatten images
- images = [image for image_group in images for image in image_group]
- image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
-
- return BatchFeature(data={**text_inputs, **image_inputs})
-
- def batch_decode(self, *args, **kwargs):
- """
- This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
- refer to the docstring of this method for more information.
- """
- return self.tokenizer.batch_decode(*args, **kwargs)
-
- def decode(self, *args, **kwargs):
- """
- This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
- the docstring of this method for more information.
- """
- return self.tokenizer.decode(*args, **kwargs)
-
- @property
- def model_input_names(self):
- tokenizer_input_names = self.tokenizer.model_input_names
- image_processor_input_names = self.image_processor.model_input_names
- return list(tokenizer_input_names) + list(image_processor_input_names)
-
-
class GotOcr2MLPBlock(SamMLPBlock):
pass
@@ -972,8 +524,6 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
__all__ = [
"GotOcr2VisionConfig",
"GotOcr2Config",
- "GotOcr2Processor",
"GotOcr2PreTrainedModel",
"GotOcr2ForConditionalGeneration",
- "GotOcr2ImageProcessor",
]
diff --git a/src/transformers/models/got_ocr2/processing_got_ocr2.py b/src/transformers/models/got_ocr2/processing_got_ocr2.py
index 636db765f9..398ec36c9e 100644
--- a/src/transformers/models/got_ocr2/processing_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/processing_got_ocr2.py
@@ -1,9 +1,3 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_got_ocr2.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
@@ -22,6 +16,8 @@
from typing import List, Optional, Tuple, Union
+import numpy as np
+
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
@@ -100,7 +96,7 @@ class GotOcr2Processor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
- image_processor_class = "GotOcr2ImageProcessor"
+ image_processor_class = "AutoImageProcessor"
tokenizer_class = "PreTrainedTokenizerFast"
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
@@ -205,28 +201,29 @@ class GotOcr2Processor(ProcessorMixin):
box = output_kwargs["images_kwargs"].pop("box", [None])
color = output_kwargs["images_kwargs"].pop("color", None)
multi_page = output_kwargs["images_kwargs"].pop("multi_page")
- crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches")
- min_patches = output_kwargs["images_kwargs"].pop("min_patches")
- max_patches = output_kwargs["images_kwargs"].pop("max_patches")
+ crop_to_patches = output_kwargs["images_kwargs"].get("crop_to_patches")
images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page)
-
+ if multi_page:
+ # save the number of pages per batch
+ num_pages_per_batch = [len(image_group) for image_group in images]
+ # flatten the list of images
+ images = [image for image_group in images for image in image_group]
+ else:
+ num_pages_per_batch = [1 for _ in range(len(images))]
# Load images as we need to know the image size
images = load_images(images)
+ image_sizes = [image.size for image in images]
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ num_patches_array = image_inputs.pop("num_patches")
if text is None:
text = []
- for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
- if crop_to_patches:
- image_group = self.image_processor.crop_image_to_patches(
- image_group,
- patch_size=output_kwargs["images_kwargs"].get("size"),
- min_patches=min_patches,
- max_patches=max_patches,
- )
- images[index] = image_group
- num_images = len(image_group) if (multi_page or crop_to_patches) else 1
+ patch_indices = np.cumsum(num_pages_per_batch)
+ for index, (num_pages, box_single, color_single) in enumerate(zip(num_pages_per_batch, box, color)):
+ current_patch_index = patch_indices[index - 1] if index > 0 else 0
+ num_patches = sum(num_patches_array[current_patch_index : current_patch_index + num_pages])
if box_single[0] is not None:
- box_single = preprocess_box_annotation(box_single, image_group.size)
+ box_single = preprocess_box_annotation(box_single, image_sizes[index])
query = (
f"{f'[{color_single}] ' if color_single is not None else ''}"
f"{str(box_single) if box_single[0] is not None else ''} "
@@ -243,7 +240,7 @@ class GotOcr2Processor(ProcessorMixin):
+ self.message_start_token
+ "user\n"
+ self.img_start_token
- + self.img_pad_token * num_image_tokens * num_images
+ + self.img_pad_token * num_image_tokens * num_patches
+ self.img_end_token
+ "\n"
+ query
@@ -252,22 +249,8 @@ class GotOcr2Processor(ProcessorMixin):
+ "assistant\n"
)
text.append(prompt)
- elif crop_to_patches:
- for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)):
- image_group = self.image_processor.crop_image_to_patches(
- image_group,
- patch_size=output_kwargs["images_kwargs"].get("size"),
- min_patches=min_patches,
- max_patches=max_patches,
- )
- images[index] = image_group
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
- if multi_page or crop_to_patches:
- # flatten images
- images = [image for image_group in images for image in image_group]
- image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
-
return BatchFeature(data={**text_inputs, **image_inputs})
def batch_decode(self, *args, **kwargs):
diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py
index f393a8f126..de62c4ae7c 100644
--- a/src/transformers/utils/dummy_torchvision_objects.py
+++ b/src/transformers/utils/dummy_torchvision_objects.py
@@ -58,6 +58,13 @@ class DetrImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])
+class GotOcr2ImageProcessorFast(metaclass=DummyObject):
+ _backends = ["torchvision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torchvision"])
+
+
class LlavaImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
diff --git a/tests/models/got_ocr2/test_image_processing_got_ocr2.py b/tests/models/got_ocr2/test_image_processing_got_ocr2.py
index c4e75feee6..93cd347dea 100644
--- a/tests/models/got_ocr2/test_image_processing_got_ocr2.py
+++ b/tests/models/got_ocr2/test_image_processing_got_ocr2.py
@@ -16,15 +16,22 @@
import unittest
+from transformers.image_utils import SizeDict
from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_vision_available
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
+if is_torch_available():
+ import torch
+
if is_vision_available():
from transformers import GotOcr2ImageProcessor
+ if is_torchvision_available():
+ from transformers import GotOcr2ImageProcessorFast
+
class GotOcr2ImageProcessingTester(unittest.TestCase):
def __init__(
@@ -89,6 +96,7 @@ class GotOcr2ImageProcessingTester(unittest.TestCase):
@require_vision
class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = GotOcr2ImageProcessor if is_vision_available() else None
+ fast_image_processing_class = GotOcr2ImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@@ -99,17 +107,72 @@ class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
- image_processor = self.image_processing_class(**self.image_processor_dict)
- self.assertTrue(hasattr(image_processor, "do_resize"))
- self.assertTrue(hasattr(image_processor, "size"))
- self.assertTrue(hasattr(image_processor, "do_normalize"))
- self.assertTrue(hasattr(image_processor, "image_mean"))
- self.assertTrue(hasattr(image_processor, "image_std"))
- self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
+ for image_processing_class in self.image_processor_list:
+ image_processor = image_processing_class(**self.image_processor_dict)
+ self.assertTrue(hasattr(image_processor, "do_resize"))
+ self.assertTrue(hasattr(image_processor, "size"))
+ self.assertTrue(hasattr(image_processor, "do_normalize"))
+ self.assertTrue(hasattr(image_processor, "image_mean"))
+ self.assertTrue(hasattr(image_processor, "image_std"))
+ self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
+
+ def test_slow_fast_equivalence_crop_to_patches(self):
+ dummy_image = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)[0]
+
+ image_processor_slow = self.image_processing_class(**self.image_processor_dict, crop_to_patches=True)
+ image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, crop_to_patches=True)
+
+ encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
+ encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
+
+ torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
+ self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
+ self.assertLessEqual(
+ torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
+ )
+
+ def test_slow_fast_equivalence_batched_crop_to_patches(self):
+ # Prepare image inputs so that we have two groups of images with equal resolution with a group of images with
+ # different resolutions in between
+ dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
+ dummy_images += self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+ dummy_images += self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
+
+ image_processor_slow = self.image_processing_class(**self.image_processor_dict, crop_to_patches=True)
+ image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, crop_to_patches=True)
+
+ encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
+ encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
+
+ torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
+ self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
+ self.assertLessEqual(
+ torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
+ )
def test_crop_to_patches(self):
- image_processor = self.image_processing_class(**self.image_processor_dict)
- image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)[0]
- processed_images = image_processor.crop_image_to_patches(image, 1, 6, use_thumbnail=True)
+ # test slow image processor
+ image_processor = self.image_processor_list[0](**self.image_processor_dict)
+ image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)[0]
+ processed_images = image_processor.crop_image_to_patches(
+ image,
+ min_patches=1,
+ max_patches=6,
+ use_thumbnail=True,
+ patch_size={"height": 20, "width": 20},
+ )
self.assertEqual(len(processed_images), 5)
- self.assertEqual(processed_images[0].size, (20, 20))
+ self.assertEqual(processed_images[0].shape[:2], (20, 20))
+
+ # test fast image processor (process batch)
+ image_processor = self.image_processor_list[1](**self.image_processor_dict)
+ image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)[0]
+ processed_images = image_processor.crop_image_to_patches(
+ image.unsqueeze(0),
+ min_patches=1,
+ max_patches=6,
+ use_thumbnail=True,
+ patch_size=SizeDict(height=20, width=20),
+ )
+ self.assertEqual(len(processed_images[0]), 5)
+ self.assertEqual(processed_images.shape[-2:], (20, 20))