Fix Automatic Download of Pretrained Weights in DETR (#17712)
* added use_backbone_pretrained * style fixes * update * Update detr.mdx * Update detr.mdx * Update detr.mdx * update using doc py * Update detr.mdx * Update src/transformers/models/detr/configuration_detr.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -113,6 +113,28 @@ Tips:
|
|||||||
- The size of the images will determine the amount of memory being used, and will thus determine the `batch_size`.
|
- The size of the images will determine the amount of memory being used, and will thus determine the `batch_size`.
|
||||||
It is advised to use a batch size of 2 per GPU. See [this Github thread](https://github.com/facebookresearch/detr/issues/150) for more info.
|
It is advised to use a batch size of 2 per GPU. See [this Github thread](https://github.com/facebookresearch/detr/issues/150) for more info.
|
||||||
|
|
||||||
|
There are three ways to instantiate a DETR model (depending on what you prefer):
|
||||||
|
|
||||||
|
Option 1: Instantiate DETR with pre-trained weights for entire model
|
||||||
|
```py
|
||||||
|
>>> from transformers import DetrForObjectDetection
|
||||||
|
|
||||||
|
>>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
|
||||||
|
```
|
||||||
|
|
||||||
|
Option 2: Instantiate DETR with randomly initialized weights for Transformer, but pre-trained weights for backbone
|
||||||
|
```py
|
||||||
|
>>> from transformers import DetrConfig, DetrForObjectDetection
|
||||||
|
|
||||||
|
>>> config = DetrConfig()
|
||||||
|
>>> model = DetrForObjectDetection(config)
|
||||||
|
```
|
||||||
|
Option 3: Instantiate DETR with randomly initialized weights for backbone + Transformer
|
||||||
|
```py
|
||||||
|
>>> config = DetrConfig(use_pretrained_backbone=False)
|
||||||
|
>>> model = DetrForObjectDetection(config)
|
||||||
|
```
|
||||||
|
|
||||||
As a summary, consider the following table:
|
As a summary, consider the following table:
|
||||||
|
|
||||||
| Task | Object detection | Instance segmentation | Panoptic segmentation |
|
| Task | Object detection | Instance segmentation | Panoptic segmentation |
|
||||||
|
|||||||
@@ -82,6 +82,8 @@ class DetrConfig(PretrainedConfig):
|
|||||||
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
|
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
|
||||||
list of all available models, see [this
|
list of all available models, see [this
|
||||||
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
|
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
|
||||||
|
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use pretrained weights for the backbone.
|
||||||
dilation (`bool`, *optional*, defaults to `False`):
|
dilation (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to replace stride with dilation in the last convolutional block (DC5).
|
Whether to replace stride with dilation in the last convolutional block (DC5).
|
||||||
class_cost (`float`, *optional*, defaults to 1):
|
class_cost (`float`, *optional*, defaults to 1):
|
||||||
@@ -147,6 +149,7 @@ class DetrConfig(PretrainedConfig):
|
|||||||
auxiliary_loss=False,
|
auxiliary_loss=False,
|
||||||
position_embedding_type="sine",
|
position_embedding_type="sine",
|
||||||
backbone="resnet50",
|
backbone="resnet50",
|
||||||
|
use_pretrained_backbone=True,
|
||||||
dilation=False,
|
dilation=False,
|
||||||
class_cost=1,
|
class_cost=1,
|
||||||
bbox_cost=5,
|
bbox_cost=5,
|
||||||
@@ -180,6 +183,7 @@ class DetrConfig(PretrainedConfig):
|
|||||||
self.auxiliary_loss = auxiliary_loss
|
self.auxiliary_loss = auxiliary_loss
|
||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
# Hungarian matcher
|
# Hungarian matcher
|
||||||
self.class_cost = class_cost
|
self.class_cost = class_cost
|
||||||
|
|||||||
@@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name: str, dilation: bool):
|
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
@@ -335,7 +335,9 @@ class DetrTimmConvEncoder(nn.Module):
|
|||||||
|
|
||||||
requires_backends(self, ["timm"])
|
requires_backends(self, ["timm"])
|
||||||
|
|
||||||
backbone = create_model(name, pretrained=True, features_only=True, out_indices=(1, 2, 3, 4), **kwargs)
|
backbone = create_model(
|
||||||
|
name, pretrained=use_pretrained_backbone, features_only=True, out_indices=(1, 2, 3, 4), **kwargs
|
||||||
|
)
|
||||||
# replace batch norm by frozen batch norm
|
# replace batch norm by frozen batch norm
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
replace_batch_norm(backbone)
|
replace_batch_norm(backbone)
|
||||||
@@ -1177,7 +1179,7 @@ class DetrModel(DetrPreTrainedModel):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
# Create backbone + positional encoding
|
# Create backbone + positional encoding
|
||||||
backbone = DetrTimmConvEncoder(config.backbone, config.dilation)
|
backbone = DetrTimmConvEncoder(config.backbone, config.dilation, config.use_pretrained_backbone)
|
||||||
position_embeddings = build_position_encoding(config)
|
position_embeddings = build_position_encoding(config)
|
||||||
self.backbone = DetrConvModel(backbone, position_embeddings)
|
self.backbone = DetrConvModel(backbone, position_embeddings)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user