[DETR] Add num_channels attribute (#18714)

* Add num_channels attribute

* Fix code quality

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-08-31 18:04:42 +02:00
committed by GitHub
parent 811c4c9f79
commit 3b6943e7a3
3 changed files with 34 additions and 4 deletions

View File

@@ -416,6 +416,26 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
self.assertTrue(outputs)
def test_greyscale_images(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# use greyscale pixel values
inputs_dict["pixel_values"] = floats_tensor(
[self.model_tester.batch_size, 1, self.model_tester.min_size, self.model_tester.max_size]
)
# let's set num_channels to 1
config.num_channels = 1
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertTrue(outputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()