Fast image processor (#28847)
* Draft fast image processors * Draft working fast version * py3.8 compatible cache * Enable loading fast image processors through auto * Tidy up; rescale behaviour based on input type * Enable tests for fast image processors * Smarter rescaling * Don't default to Fast * Safer imports * Add necessary Pillow requirement * Woops * Add AutoImageProcessor test * Fix up * Fix test for imagegpt * Fix test * Review comments * Add warning for TF and JAX input types * Rearrange * Return transforms * NumpyToTensor transformation * Rebase - include changes from upstream in ImageProcessingMixin * Safe typing * Fix up * convert mean/std to tesnor to rescale * Don't store transforms in state * Fix up * Update src/transformers/image_processing_utils_fast.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/auto/image_processing_auto.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/auto/image_processing_auto.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/auto/image_processing_auto.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Warn if fast image processor available * Update src/transformers/models/vit/image_processing_vit_fast.py * Transpose incoming numpy images to be in CHW format * Update mapping names based on packages, auto set fast to None * Fix up * Fix * Add AutoImageProcessor.from_pretrained(checkpoint, use_fast=True) test * Update src/transformers/models/vit/image_processing_vit_fast.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Add equivalence and speed tests * Fix up --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
63
src/transformers/image_processing_utils_fast.py
Normal file
63
src/transformers/image_processing_utils_fast.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .image_processing_utils import BaseImageProcessor
|
||||
from .utils.import_utils import is_torchvision_available
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SizeDict:
|
||||
"""
|
||||
Hashable dictionary to store image size information.
|
||||
"""
|
||||
|
||||
height: int = None
|
||||
width: int = None
|
||||
longest_edge: int = None
|
||||
shortest_edge: int = None
|
||||
max_height: int = None
|
||||
max_width: int = None
|
||||
|
||||
def __getitem__(self, key):
|
||||
if hasattr(self, key):
|
||||
return getattr(self, key)
|
||||
raise KeyError(f"Key {key} not found in SizeDict.")
|
||||
|
||||
|
||||
class BaseImageProcessorFast(BaseImageProcessor):
|
||||
_transform_params = None
|
||||
|
||||
def _build_transforms(self, **kwargs) -> "Compose":
|
||||
"""
|
||||
Given the input settings e.g. do_resize, build the image transforms.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _validate_params(self, **kwargs) -> None:
|
||||
for k, v in kwargs.items():
|
||||
if k not in self._transform_params:
|
||||
raise ValueError(f"Invalid transform parameter {k}={v}.")
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def get_transforms(self, **kwargs) -> "Compose":
|
||||
self._validate_params(**kwargs)
|
||||
return self._build_transforms(**kwargs)
|
||||
Reference in New Issue
Block a user