diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index eaaadbf242..aaadcb4458 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -14,8 +14,9 @@ # limitations under the License. import warnings +from collections.abc import Collection from math import ceil -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -389,8 +390,8 @@ def resize( def normalize( image: np.ndarray, - mean: Union[float, Sequence[float]], - std: Union[float, Sequence[float]], + mean: Union[float, Collection[float]], + std: Union[float, Collection[float]], data_format: Optional[ChannelDimension] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: @@ -402,9 +403,9 @@ def normalize( Args: image (`np.ndarray`): The image to normalize. - mean (`float` or `Sequence[float]`): + mean (`float` or `Collection[float]`): The mean to use for normalization. - std (`float` or `Sequence[float]`): + std (`float` or `Collection[float]`): The standard deviation to use for normalization. data_format (`ChannelDimension`, *optional*): The channel dimension format of the output image. If unset, will use the inferred format from the input. @@ -425,14 +426,14 @@ def normalize( if not np.issubdtype(image.dtype, np.floating): image = image.astype(np.float32) - if isinstance(mean, Sequence): + if isinstance(mean, Collection): if len(mean) != num_channels: raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}") else: mean = [mean] * num_channels mean = np.array(mean, dtype=image.dtype) - if isinstance(std, Sequence): + if isinstance(std, Collection): if len(std) != num_channels: raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}") else: diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 25775d787e..560ea6a36b 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -326,8 +326,8 @@ class ImageTransformsTester(unittest.TestCase): # Test float16 image input keeps float16 dtype image = np.random.randint(0, 256, (224, 224, 3)).astype(np.float16) / 255 - mean = (0.5, 0.6, 0.7) - std = (0.1, 0.2, 0.3) + mean = np.array((0.5, 0.6, 0.7)) + std = np.array((0.1, 0.2, 0.3)) # The mean and std are cast to match the dtype of the input image cast_mean = np.array(mean, dtype=np.float16)