Enable passing number of channels when inferring data format (#25412)

This commit is contained in:
amyeroberts
2023-08-09 17:41:21 +01:00
committed by GitHub
parent cb3c821cb7
commit 944ddce8bf
2 changed files with 14 additions and 3 deletions

View File

@@ -578,6 +578,10 @@ class UtilFunctionTester(unittest.TestCase):
with pytest.raises(ValueError):
infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
# But if we explicitly set one of the number of channels to 50 it works
inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
self.assertEqual(inferred_dim, ChannelDimension.LAST)
# Test we correctly identify the channel dimension
image = np.random.randint(0, 256, (3, 4, 5))
inferred_dim = infer_channel_dimension_format(image)