Enable passing number of channels when inferring data format (#25412)
This commit is contained in:
@@ -578,6 +578,10 @@ class UtilFunctionTester(unittest.TestCase):
|
||||
with pytest.raises(ValueError):
|
||||
infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
|
||||
|
||||
# But if we explicitly set one of the number of channels to 50 it works
|
||||
inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
|
||||
self.assertEqual(inferred_dim, ChannelDimension.LAST)
|
||||
|
||||
# Test we correctly identify the channel dimension
|
||||
image = np.random.randint(0, 256, (3, 4, 5))
|
||||
inferred_dim = infer_channel_dimension_format(image)
|
||||
|
||||
Reference in New Issue
Block a user