[DETR] Add num_channels attribute (#18714)
* Add num_channels attribute * Fix code quality Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -416,6 +416,26 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
def test_greyscale_images(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# use greyscale pixel values
|
||||
inputs_dict["pixel_values"] = floats_tensor(
|
||||
[self.model_tester.batch_size, 1, self.model_tester.min_size, self.model_tester.max_size]
|
||||
)
|
||||
|
||||
# let's set num_channels to 1
|
||||
config.num_channels = 1
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user