Uses Collection in transformers.image_transforms.normalize (#36301)

* Uses Collection instead of Sequence in transformers.image_transforms.normalize

* Uses collections.abc.Collection in lieu of deprecated typing one
This commit is contained in:
CalOmnie
2025-02-21 18:38:41 +01:00
committed by GitHub
parent 7c5bd24ffa
commit 547911e727
2 changed files with 10 additions and 9 deletions

View File

@@ -14,8 +14,9 @@
# limitations under the License.
import warnings
from collections.abc import Collection
from math import ceil
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
@@ -389,8 +390,8 @@ def resize(
def normalize(
image: np.ndarray,
mean: Union[float, Sequence[float]],
std: Union[float, Sequence[float]],
mean: Union[float, Collection[float]],
std: Union[float, Collection[float]],
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
@@ -402,9 +403,9 @@ def normalize(
Args:
image (`np.ndarray`):
The image to normalize.
mean (`float` or `Sequence[float]`):
mean (`float` or `Collection[float]`):
The mean to use for normalization.
std (`float` or `Sequence[float]`):
std (`float` or `Collection[float]`):
The standard deviation to use for normalization.
data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If unset, will use the inferred format from the input.
@@ -425,14 +426,14 @@ def normalize(
if not np.issubdtype(image.dtype, np.floating):
image = image.astype(np.float32)
if isinstance(mean, Sequence):
if isinstance(mean, Collection):
if len(mean) != num_channels:
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
else:
mean = [mean] * num_channels
mean = np.array(mean, dtype=image.dtype)
if isinstance(std, Sequence):
if isinstance(std, Collection):
if len(std) != num_channels:
raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
else:

View File

@@ -326,8 +326,8 @@ class ImageTransformsTester(unittest.TestCase):
# Test float16 image input keeps float16 dtype
image = np.random.randint(0, 256, (224, 224, 3)).astype(np.float16) / 255
mean = (0.5, 0.6, 0.7)
std = (0.1, 0.2, 0.3)
mean = np.array((0.5, 0.6, 0.7))
std = np.array((0.1, 0.2, 0.3))
# The mean and std are cast to match the dtype of the input image
cast_mean = np.array(mean, dtype=np.float16)