Fix Automatic Download of Pretrained Weights in DETR (#17712)
* added use_backbone_pretrained * style fixes * update * Update detr.mdx * Update detr.mdx * Update detr.mdx * update using doc py * Update detr.mdx * Update src/transformers/models/detr/configuration_detr.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -113,6 +113,28 @@ Tips:
|
||||
- The size of the images will determine the amount of memory being used, and will thus determine the `batch_size`.
|
||||
It is advised to use a batch size of 2 per GPU. See [this Github thread](https://github.com/facebookresearch/detr/issues/150) for more info.
|
||||
|
||||
There are three ways to instantiate a DETR model (depending on what you prefer):
|
||||
|
||||
Option 1: Instantiate DETR with pre-trained weights for entire model
|
||||
```py
|
||||
>>> from transformers import DetrForObjectDetection
|
||||
|
||||
>>> model = DetrForObjectDetection.from_pretrained("facebook/resnet-50")
|
||||
```
|
||||
|
||||
Option 2: Instantiate DETR with randomly initialized weights for Transformer, but pre-trained weights for backbone
|
||||
```py
|
||||
>>> from transformers import DetrConfig, DetrForObjectDetection
|
||||
|
||||
>>> config = DetrConfig()
|
||||
>>> model = DetrForObjectDetection(config)
|
||||
```
|
||||
Option 3: Instantiate DETR with randomly initialized weights for backbone + Transformer
|
||||
```py
|
||||
>>> config = DetrConfig(use_pretrained_backbone=False)
|
||||
>>> model = DetrForObjectDetection(config)
|
||||
```
|
||||
|
||||
As a summary, consider the following table:
|
||||
|
||||
| Task | Object detection | Instance segmentation | Panoptic segmentation |
|
||||
@@ -166,4 +188,4 @@ mean Average Precision (mAP) and Panoptic Quality (PQ). The latter objects are i
|
||||
## DetrForSegmentation
|
||||
|
||||
[[autodoc]] DetrForSegmentation
|
||||
- forward
|
||||
- forward
|
||||
|
||||
@@ -82,6 +82,8 @@ class DetrConfig(PretrainedConfig):
|
||||
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
|
||||
list of all available models, see [this
|
||||
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
|
||||
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use pretrained weights for the backbone.
|
||||
dilation (`bool`, *optional*, defaults to `False`):
|
||||
Whether to replace stride with dilation in the last convolutional block (DC5).
|
||||
class_cost (`float`, *optional*, defaults to 1):
|
||||
@@ -147,6 +149,7 @@ class DetrConfig(PretrainedConfig):
|
||||
auxiliary_loss=False,
|
||||
position_embedding_type="sine",
|
||||
backbone="resnet50",
|
||||
use_pretrained_backbone=True,
|
||||
dilation=False,
|
||||
class_cost=1,
|
||||
bbox_cost=5,
|
||||
@@ -180,6 +183,7 @@ class DetrConfig(PretrainedConfig):
|
||||
self.auxiliary_loss = auxiliary_loss
|
||||
self.position_embedding_type = position_embedding_type
|
||||
self.backbone = backbone
|
||||
self.use_pretrained_backbone = use_pretrained_backbone
|
||||
self.dilation = dilation
|
||||
# Hungarian matcher
|
||||
self.class_cost = class_cost
|
||||
|
||||
@@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, dilation: bool):
|
||||
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool):
|
||||
super().__init__()
|
||||
|
||||
kwargs = {}
|
||||
@@ -335,7 +335,9 @@ class DetrTimmConvEncoder(nn.Module):
|
||||
|
||||
requires_backends(self, ["timm"])
|
||||
|
||||
backbone = create_model(name, pretrained=True, features_only=True, out_indices=(1, 2, 3, 4), **kwargs)
|
||||
backbone = create_model(
|
||||
name, pretrained=use_pretrained_backbone, features_only=True, out_indices=(1, 2, 3, 4), **kwargs
|
||||
)
|
||||
# replace batch norm by frozen batch norm
|
||||
with torch.no_grad():
|
||||
replace_batch_norm(backbone)
|
||||
@@ -1177,7 +1179,7 @@ class DetrModel(DetrPreTrainedModel):
|
||||
super().__init__(config)
|
||||
|
||||
# Create backbone + positional encoding
|
||||
backbone = DetrTimmConvEncoder(config.backbone, config.dilation)
|
||||
backbone = DetrTimmConvEncoder(config.backbone, config.dilation, config.use_pretrained_backbone)
|
||||
position_embeddings = build_position_encoding(config)
|
||||
self.backbone = DetrConvModel(backbone, position_embeddings)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user