[DETR] Add num_channels attribute (#18714)
* Add num_channels attribute * Fix code quality Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -42,8 +42,9 @@ class DetrConfig(PretrainedConfig):
|
|||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
num_channels (`int`, *optional*, defaults to 3):
|
||||||
|
The number of input channels.
|
||||||
num_queries (`int`, *optional*, defaults to 100):
|
num_queries (`int`, *optional*, defaults to 100):
|
||||||
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
|
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
|
||||||
detect in a single image. For COCO, we recommend 100 queries.
|
detect in a single image. For COCO, we recommend 100 queries.
|
||||||
@@ -132,6 +133,7 @@ class DetrConfig(PretrainedConfig):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
num_channels=3,
|
||||||
num_queries=100,
|
num_queries=100,
|
||||||
max_position_embeddings=1024,
|
max_position_embeddings=1024,
|
||||||
encoder_layers=6,
|
encoder_layers=6,
|
||||||
@@ -167,6 +169,7 @@ class DetrConfig(PretrainedConfig):
|
|||||||
eos_coefficient=0.1,
|
eos_coefficient=0.1,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
self.num_channels = num_channels
|
||||||
self.num_queries = num_queries
|
self.num_queries = num_queries
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
|
|||||||
@@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool):
|
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
@@ -336,7 +336,12 @@ class DetrTimmConvEncoder(nn.Module):
|
|||||||
requires_backends(self, ["timm"])
|
requires_backends(self, ["timm"])
|
||||||
|
|
||||||
backbone = create_model(
|
backbone = create_model(
|
||||||
name, pretrained=use_pretrained_backbone, features_only=True, out_indices=(1, 2, 3, 4), **kwargs
|
name,
|
||||||
|
pretrained=use_pretrained_backbone,
|
||||||
|
features_only=True,
|
||||||
|
out_indices=(1, 2, 3, 4),
|
||||||
|
in_chans=num_channels,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
# replace batch norm by frozen batch norm
|
# replace batch norm by frozen batch norm
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -1179,7 +1184,9 @@ class DetrModel(DetrPreTrainedModel):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
# Create backbone + positional encoding
|
# Create backbone + positional encoding
|
||||||
backbone = DetrTimmConvEncoder(config.backbone, config.dilation, config.use_pretrained_backbone)
|
backbone = DetrTimmConvEncoder(
|
||||||
|
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
|
||||||
|
)
|
||||||
position_embeddings = build_position_encoding(config)
|
position_embeddings = build_position_encoding(config)
|
||||||
self.backbone = DetrConvModel(backbone, position_embeddings)
|
self.backbone = DetrConvModel(backbone, position_embeddings)
|
||||||
|
|
||||||
|
|||||||
@@ -416,6 +416,26 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(outputs)
|
self.assertTrue(outputs)
|
||||||
|
|
||||||
|
def test_greyscale_images(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
# use greyscale pixel values
|
||||||
|
inputs_dict["pixel_values"] = floats_tensor(
|
||||||
|
[self.model_tester.batch_size, 1, self.model_tester.min_size, self.model_tester.max_size]
|
||||||
|
)
|
||||||
|
|
||||||
|
# let's set num_channels to 1
|
||||||
|
config.num_channels = 1
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
self.assertTrue(outputs)
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user