Uses Collection in transformers.image_transforms.normalize (#36301)
* Uses Collection instead of Sequence in transformers.image_transforms.normalize * Uses collections.abc.Collection in lieu of deprecated typing one
This commit is contained in:
@@ -14,8 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Collection
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -389,8 +390,8 @@ def resize(
|
|||||||
|
|
||||||
def normalize(
|
def normalize(
|
||||||
image: np.ndarray,
|
image: np.ndarray,
|
||||||
mean: Union[float, Sequence[float]],
|
mean: Union[float, Collection[float]],
|
||||||
std: Union[float, Sequence[float]],
|
std: Union[float, Collection[float]],
|
||||||
data_format: Optional[ChannelDimension] = None,
|
data_format: Optional[ChannelDimension] = None,
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
@@ -402,9 +403,9 @@ def normalize(
|
|||||||
Args:
|
Args:
|
||||||
image (`np.ndarray`):
|
image (`np.ndarray`):
|
||||||
The image to normalize.
|
The image to normalize.
|
||||||
mean (`float` or `Sequence[float]`):
|
mean (`float` or `Collection[float]`):
|
||||||
The mean to use for normalization.
|
The mean to use for normalization.
|
||||||
std (`float` or `Sequence[float]`):
|
std (`float` or `Collection[float]`):
|
||||||
The standard deviation to use for normalization.
|
The standard deviation to use for normalization.
|
||||||
data_format (`ChannelDimension`, *optional*):
|
data_format (`ChannelDimension`, *optional*):
|
||||||
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
||||||
@@ -425,14 +426,14 @@ def normalize(
|
|||||||
if not np.issubdtype(image.dtype, np.floating):
|
if not np.issubdtype(image.dtype, np.floating):
|
||||||
image = image.astype(np.float32)
|
image = image.astype(np.float32)
|
||||||
|
|
||||||
if isinstance(mean, Sequence):
|
if isinstance(mean, Collection):
|
||||||
if len(mean) != num_channels:
|
if len(mean) != num_channels:
|
||||||
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
|
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
|
||||||
else:
|
else:
|
||||||
mean = [mean] * num_channels
|
mean = [mean] * num_channels
|
||||||
mean = np.array(mean, dtype=image.dtype)
|
mean = np.array(mean, dtype=image.dtype)
|
||||||
|
|
||||||
if isinstance(std, Sequence):
|
if isinstance(std, Collection):
|
||||||
if len(std) != num_channels:
|
if len(std) != num_channels:
|
||||||
raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
|
raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -326,8 +326,8 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
|
|
||||||
# Test float16 image input keeps float16 dtype
|
# Test float16 image input keeps float16 dtype
|
||||||
image = np.random.randint(0, 256, (224, 224, 3)).astype(np.float16) / 255
|
image = np.random.randint(0, 256, (224, 224, 3)).astype(np.float16) / 255
|
||||||
mean = (0.5, 0.6, 0.7)
|
mean = np.array((0.5, 0.6, 0.7))
|
||||||
std = (0.1, 0.2, 0.3)
|
std = np.array((0.1, 0.2, 0.3))
|
||||||
|
|
||||||
# The mean and std are cast to match the dtype of the input image
|
# The mean and std are cast to match the dtype of the input image
|
||||||
cast_mean = np.array(mean, dtype=np.float16)
|
cast_mean = np.array(mean, dtype=np.float16)
|
||||||
|
|||||||
Reference in New Issue
Block a user