Enable passing number of channels when inferring data format (#25412)
This commit is contained in:
@@ -144,17 +144,24 @@ def to_numpy_array(img) -> np.ndarray:
|
|||||||
return to_numpy(img)
|
return to_numpy(img)
|
||||||
|
|
||||||
|
|
||||||
def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
|
def infer_channel_dimension_format(
|
||||||
|
image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
|
||||||
|
) -> ChannelDimension:
|
||||||
"""
|
"""
|
||||||
Infers the channel dimension format of `image`.
|
Infers the channel dimension format of `image`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image (`np.ndarray`):
|
image (`np.ndarray`):
|
||||||
The image to infer the channel dimension of.
|
The image to infer the channel dimension of.
|
||||||
|
num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
|
||||||
|
The number of channels of the image.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The channel dimension of the image.
|
The channel dimension of the image.
|
||||||
"""
|
"""
|
||||||
|
num_channels = num_channels if num_channels is not None else (1, 3)
|
||||||
|
num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
|
||||||
|
|
||||||
if image.ndim == 3:
|
if image.ndim == 3:
|
||||||
first_dim, last_dim = 0, 2
|
first_dim, last_dim = 0, 2
|
||||||
elif image.ndim == 4:
|
elif image.ndim == 4:
|
||||||
@@ -162,9 +169,9 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
|
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
|
||||||
|
|
||||||
if image.shape[first_dim] in (1, 3):
|
if image.shape[first_dim] in num_channels:
|
||||||
return ChannelDimension.FIRST
|
return ChannelDimension.FIRST
|
||||||
elif image.shape[last_dim] in (1, 3):
|
elif image.shape[last_dim] in num_channels:
|
||||||
return ChannelDimension.LAST
|
return ChannelDimension.LAST
|
||||||
raise ValueError("Unable to infer channel dimension format")
|
raise ValueError("Unable to infer channel dimension format")
|
||||||
|
|
||||||
|
|||||||
@@ -578,6 +578,10 @@ class UtilFunctionTester(unittest.TestCase):
|
|||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
|
infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
|
||||||
|
|
||||||
|
# But if we explicitly set one of the number of channels to 50 it works
|
||||||
|
inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
|
||||||
|
self.assertEqual(inferred_dim, ChannelDimension.LAST)
|
||||||
|
|
||||||
# Test we correctly identify the channel dimension
|
# Test we correctly identify the channel dimension
|
||||||
image = np.random.randint(0, 256, (3, 4, 5))
|
image = np.random.randint(0, 256, (3, 4, 5))
|
||||||
inferred_dim = infer_channel_dimension_format(image)
|
inferred_dim = infer_channel_dimension_format(image)
|
||||||
|
|||||||
Reference in New Issue
Block a user