[DETR and friends] Remove is_timm_available (#21814)
* First draft
* Fix to_dict
* Improve conversion script
* Update config
* Remove timm dependency
* Fix dummies
* Fix typo, add integration test
* Upload 101 model as well
* Remove timm dummies
* Fix style

---------

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -866,52 +866,6 @@ else:
|
|||||||
_import_structure["models.vit_hybrid"].extend(["ViTHybridImageProcessor"])
|
_import_structure["models.vit_hybrid"].extend(["ViTHybridImageProcessor"])
|
||||||
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
|
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
|
||||||
|
|
||||||
# Timm-backed objects
|
|
||||||
try:
|
|
||||||
if not (is_timm_available() and is_vision_available()):
|
|
||||||
raise OptionalDependencyNotAvailable()
|
|
||||||
except OptionalDependencyNotAvailable:
|
|
||||||
from .utils import dummy_timm_and_vision_objects
|
|
||||||
|
|
||||||
_import_structure["utils.dummy_timm_and_vision_objects"] = [
|
|
||||||
name for name in dir(dummy_timm_and_vision_objects) if not name.startswith("_")
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
_import_structure["models.deformable_detr"].extend(
|
|
||||||
[
|
|
||||||
"DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
|
|
||||||
"DeformableDetrForObjectDetection",
|
|
||||||
"DeformableDetrModel",
|
|
||||||
"DeformableDetrPreTrainedModel",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
_import_structure["models.detr"].extend(
|
|
||||||
[
|
|
||||||
"DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
|
|
||||||
"DetrForObjectDetection",
|
|
||||||
"DetrForSegmentation",
|
|
||||||
"DetrModel",
|
|
||||||
"DetrPreTrainedModel",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
_import_structure["models.table_transformer"].extend(
|
|
||||||
[
|
|
||||||
"TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
|
||||||
"TableTransformerForObjectDetection",
|
|
||||||
"TableTransformerModel",
|
|
||||||
"TableTransformerPreTrainedModel",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
_import_structure["models.conditional_detr"].extend(
|
|
||||||
[
|
|
||||||
"CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
|
|
||||||
"ConditionalDetrForObjectDetection",
|
|
||||||
"ConditionalDetrForSegmentation",
|
|
||||||
"ConditionalDetrModel",
|
|
||||||
"ConditionalDetrPreTrainedModel",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# PyTorch-backed objects
|
# PyTorch-backed objects
|
||||||
try:
|
try:
|
||||||
@@ -1309,6 +1263,15 @@ else:
|
|||||||
"CodeGenPreTrainedModel",
|
"CodeGenPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.conditional_detr"].extend(
|
||||||
|
[
|
||||||
|
"CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"ConditionalDetrForObjectDetection",
|
||||||
|
"ConditionalDetrForSegmentation",
|
||||||
|
"ConditionalDetrModel",
|
||||||
|
"ConditionalDetrPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.convbert"].extend(
|
_import_structure["models.convbert"].extend(
|
||||||
[
|
[
|
||||||
"CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -1406,6 +1369,14 @@ else:
|
|||||||
"DecisionTransformerPreTrainedModel",
|
"DecisionTransformerPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.deformable_detr"].extend(
|
||||||
|
[
|
||||||
|
"DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"DeformableDetrForObjectDetection",
|
||||||
|
"DeformableDetrModel",
|
||||||
|
"DeformableDetrPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.deit"].extend(
|
_import_structure["models.deit"].extend(
|
||||||
[
|
[
|
||||||
"DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -1424,6 +1395,15 @@ else:
|
|||||||
"DetaPreTrainedModel",
|
"DetaPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.detr"].extend(
|
||||||
|
[
|
||||||
|
"DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"DetrForObjectDetection",
|
||||||
|
"DetrForSegmentation",
|
||||||
|
"DetrModel",
|
||||||
|
"DetrPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.dinat"].extend(
|
_import_structure["models.dinat"].extend(
|
||||||
[
|
[
|
||||||
"DINAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"DINAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -2372,6 +2352,14 @@ else:
|
|||||||
"load_tf_weights_in_t5",
|
"load_tf_weights_in_t5",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.table_transformer"].extend(
|
||||||
|
[
|
||||||
|
"TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"TableTransformerForObjectDetection",
|
||||||
|
"TableTransformerModel",
|
||||||
|
"TableTransformerPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.tapas"].extend(
|
_import_structure["models.tapas"].extend(
|
||||||
[
|
[
|
||||||
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -4398,39 +4386,6 @@ if TYPE_CHECKING:
|
|||||||
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
|
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
|
||||||
|
|
||||||
# Modeling
|
# Modeling
|
||||||
try:
|
|
||||||
if not (is_timm_available() and is_vision_available()):
|
|
||||||
raise OptionalDependencyNotAvailable()
|
|
||||||
except OptionalDependencyNotAvailable:
|
|
||||||
from .utils.dummy_timm_and_vision_objects import *
|
|
||||||
else:
|
|
||||||
from .models.conditional_detr import (
|
|
||||||
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
|
|
||||||
ConditionalDetrForObjectDetection,
|
|
||||||
ConditionalDetrForSegmentation,
|
|
||||||
ConditionalDetrModel,
|
|
||||||
ConditionalDetrPreTrainedModel,
|
|
||||||
)
|
|
||||||
from .models.deformable_detr import (
|
|
||||||
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
|
|
||||||
DeformableDetrForObjectDetection,
|
|
||||||
DeformableDetrModel,
|
|
||||||
DeformableDetrPreTrainedModel,
|
|
||||||
)
|
|
||||||
from .models.detr import (
|
|
||||||
DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
|
|
||||||
DetrForObjectDetection,
|
|
||||||
DetrForSegmentation,
|
|
||||||
DetrModel,
|
|
||||||
DetrPreTrainedModel,
|
|
||||||
)
|
|
||||||
from .models.table_transformer import (
|
|
||||||
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
|
||||||
TableTransformerForObjectDetection,
|
|
||||||
TableTransformerModel,
|
|
||||||
TableTransformerPreTrainedModel,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
@@ -4767,6 +4722,13 @@ if TYPE_CHECKING:
|
|||||||
CodeGenModel,
|
CodeGenModel,
|
||||||
CodeGenPreTrainedModel,
|
CodeGenPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.conditional_detr import (
|
||||||
|
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
ConditionalDetrForObjectDetection,
|
||||||
|
ConditionalDetrForSegmentation,
|
||||||
|
ConditionalDetrModel,
|
||||||
|
ConditionalDetrPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.convbert import (
|
from .models.convbert import (
|
||||||
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
ConvBertForMaskedLM,
|
ConvBertForMaskedLM,
|
||||||
@@ -4848,6 +4810,12 @@ if TYPE_CHECKING:
|
|||||||
DecisionTransformerModel,
|
DecisionTransformerModel,
|
||||||
DecisionTransformerPreTrainedModel,
|
DecisionTransformerPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.deformable_detr import (
|
||||||
|
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
DeformableDetrForObjectDetection,
|
||||||
|
DeformableDetrModel,
|
||||||
|
DeformableDetrPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.deit import (
|
from .models.deit import (
|
||||||
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
DeiTForImageClassification,
|
DeiTForImageClassification,
|
||||||
@@ -4862,6 +4830,13 @@ if TYPE_CHECKING:
|
|||||||
DetaModel,
|
DetaModel,
|
||||||
DetaPreTrainedModel,
|
DetaPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.detr import (
|
||||||
|
DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
DetrForObjectDetection,
|
||||||
|
DetrForSegmentation,
|
||||||
|
DetrModel,
|
||||||
|
DetrPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.dinat import (
|
from .models.dinat import (
|
||||||
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
DinatBackbone,
|
DinatBackbone,
|
||||||
@@ -5626,6 +5601,12 @@ if TYPE_CHECKING:
|
|||||||
T5PreTrainedModel,
|
T5PreTrainedModel,
|
||||||
load_tf_weights_in_t5,
|
load_tf_weights_in_t5,
|
||||||
)
|
)
|
||||||
|
from .models.table_transformer import (
|
||||||
|
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
TableTransformerForObjectDetection,
|
||||||
|
TableTransformerModel,
|
||||||
|
TableTransformerPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.tapas import (
|
from .models.tapas import (
|
||||||
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TapasForMaskedLM,
|
TapasForMaskedLM,
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
|
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
@@ -35,7 +35,7 @@ else:
|
|||||||
_import_structure["image_processing_conditional_detr"] = ["ConditionalDetrImageProcessor"]
|
_import_structure["image_processing_conditional_detr"] = ["ConditionalDetrImageProcessor"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_timm_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
@@ -66,7 +66,7 @@ if TYPE_CHECKING:
|
|||||||
from .image_processing_conditional_detr import ConditionalDetrImageProcessor
|
from .image_processing_conditional_detr import ConditionalDetrImageProcessor
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_timm_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1101,12 +1101,12 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
|
|||||||
images (`ImageInput`):
|
images (`ImageInput`):
|
||||||
Image or batch of images to preprocess.
|
Image or batch of images to preprocess.
|
||||||
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
|
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
|
||||||
List of annotations associated with the image or batch of images. If annotionation is for object
|
List of annotations associated with the image or batch of images. If annotation is for object
|
||||||
detection, the annotations should be a dictionary with the following keys:
|
detection, the annotations should be a dictionary with the following keys:
|
||||||
- "image_id" (`int`): The image id.
|
- "image_id" (`int`): The image id.
|
||||||
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
|
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
|
||||||
dictionary. An image can have no annotations, in which case the list should be empty.
|
dictionary. An image can have no annotations, in which case the list should be empty.
|
||||||
If annotionation is for segmentation, the annotations should be a dictionary with the following keys:
|
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
|
||||||
- "image_id" (`int`): The image id.
|
- "image_id" (`int`): The image id.
|
||||||
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
||||||
An image can have no segments, in which case the list should be empty.
|
An image can have no segments, in which case the list should be empty.
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
|
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
@@ -31,7 +31,7 @@ else:
|
|||||||
_import_structure["image_processing_deformable_detr"] = ["DeformableDetrImageProcessor"]
|
_import_structure["image_processing_deformable_detr"] = ["DeformableDetrImageProcessor"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_timm_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
|
|||||||
from .image_processing_deformable_detr import DeformableDetrImageProcessor
|
from .image_processing_deformable_detr import DeformableDetrImageProcessor
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_timm_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1099,12 +1099,12 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
|
|||||||
images (`ImageInput`):
|
images (`ImageInput`):
|
||||||
Image or batch of images to preprocess.
|
Image or batch of images to preprocess.
|
||||||
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
|
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
|
||||||
List of annotations associated with the image or batch of images. If annotionation is for object
|
List of annotations associated with the image or batch of images. If annotation is for object
|
||||||
detection, the annotations should be a dictionary with the following keys:
|
detection, the annotations should be a dictionary with the following keys:
|
||||||
- "image_id" (`int`): The image id.
|
- "image_id" (`int`): The image id.
|
||||||
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
|
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
|
||||||
dictionary. An image can have no annotations, in which case the list should be empty.
|
dictionary. An image can have no annotations, in which case the list should be empty.
|
||||||
If annotionation is for segmentation, the annotations should be a dictionary with the following keys:
|
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
|
||||||
- "image_id" (`int`): The image id.
|
- "image_id" (`int`): The image id.
|
||||||
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
||||||
An image can have no segments, in which case the list should be empty.
|
An image can have no segments, in which case the list should be empty.
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
|
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrOnnxConfig"]}
|
_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrOnnxConfig"]}
|
||||||
@@ -29,7 +29,7 @@ else:
|
|||||||
_import_structure["image_processing_detr"] = ["DetrImageProcessor"]
|
_import_structure["image_processing_detr"] = ["DetrImageProcessor"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_timm_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
@@ -56,7 +56,7 @@ if TYPE_CHECKING:
|
|||||||
from .image_processing_detr import DetrImageProcessor
|
from .image_processing_detr import DetrImageProcessor
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_timm_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -14,8 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" DETR model configuration"""
|
""" DETR model configuration"""
|
||||||
|
|
||||||
|
import copy
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Mapping
|
from typing import Dict, Mapping
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@@ -187,6 +188,8 @@ class DetrConfig(PretrainedConfig):
|
|||||||
backbone_model_type = backbone_config.get("model_type")
|
backbone_model_type = backbone_config.get("model_type")
|
||||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||||
backbone_config = config_class.from_dict(backbone_config)
|
backbone_config = config_class.from_dict(backbone_config)
|
||||||
|
# set timm attributes to None
|
||||||
|
dilation, backbone, use_pretrained_backbone = None, None, None
|
||||||
|
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
self.backbone_config = backbone_config
|
self.backbone_config = backbone_config
|
||||||
@@ -233,6 +236,28 @@ class DetrConfig(PretrainedConfig):
|
|||||||
def hidden_size(self) -> int:
|
def hidden_size(self) -> int:
|
||||||
return self.d_model
|
return self.d_model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
|
||||||
|
"""Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration.
|
||||||
|
Args:
|
||||||
|
backbone_config ([`PretrainedConfig`]):
|
||||||
|
The backbone configuration.
|
||||||
|
Returns:
|
||||||
|
[`DetrConfig`]: An instance of a configuration object
|
||||||
|
"""
|
||||||
|
return cls(backbone_config=backbone_config, **kwargs)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, any]:
|
||||||
|
"""
|
||||||
|
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
|
||||||
|
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||||
|
"""
|
||||||
|
output = copy.deepcopy(self.__dict__)
|
||||||
|
if output["backbone_config"] is not None:
|
||||||
|
output["backbone_config"] = self.backbone_config.to_dict()
|
||||||
|
output["model_type"] = self.__class__.model_type
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class DetrOnnxConfig(OnnxConfig):
|
class DetrOnnxConfig(OnnxConfig):
|
||||||
torch_onnx_minimum_version = version.parse("1.11")
|
torch_onnx_minimum_version = version.parse("1.11")
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2022 The HuggingFace Inc. team.
|
# Copyright 2023 The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -33,16 +33,16 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def get_detr_config(model_name):
|
def get_detr_config(model_name):
|
||||||
config = DetrConfig(use_timm_backbone=False)
|
# initialize config
|
||||||
|
if "resnet-50" in model_name:
|
||||||
# set backbone attributes
|
backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-50")
|
||||||
if "resnet50" in model_name:
|
elif "resnet-101" in model_name:
|
||||||
pass
|
backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
|
||||||
elif "resnet101" in model_name:
|
|
||||||
config.backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Model name should include either resnet50 or resnet101")
|
raise ValueError("Model name should include either resnet50 or resnet101")
|
||||||
|
|
||||||
|
config = DetrConfig(use_timm_backbone=False, backbone_config=backbone_config)
|
||||||
|
|
||||||
# set label attributes
|
# set label attributes
|
||||||
is_panoptic = "panoptic" in model_name
|
is_panoptic = "panoptic" in model_name
|
||||||
if is_panoptic:
|
if is_panoptic:
|
||||||
@@ -286,7 +286,7 @@ def prepare_img():
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
|
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
|
||||||
"""
|
"""
|
||||||
Copy/paste/tweak model's weights to our DETR structure.
|
Copy/paste/tweak model's weights to our DETR structure.
|
||||||
"""
|
"""
|
||||||
@@ -295,8 +295,12 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
|
|||||||
config, is_panoptic = get_detr_config(model_name)
|
config, is_panoptic = get_detr_config(model_name)
|
||||||
|
|
||||||
# load original model from torch hub
|
# load original model from torch hub
|
||||||
|
model_name_to_original_name = {
|
||||||
|
"detr-resnet-50": "detr_resnet50",
|
||||||
|
"detr-resnet-101": "detr_resnet101",
|
||||||
|
}
|
||||||
logger.info(f"Converting model {model_name}...")
|
logger.info(f"Converting model {model_name}...")
|
||||||
detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval()
|
detr = torch.hub.load("facebookresearch/detr", model_name_to_original_name[model_name], pretrained=True).eval()
|
||||||
state_dict = detr.state_dict()
|
state_dict = detr.state_dict()
|
||||||
# rename keys
|
# rename keys
|
||||||
for src, dest in create_rename_keys(config):
|
for src, dest in create_rename_keys(config):
|
||||||
@@ -344,9 +348,6 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
|
|||||||
original_outputs = detr(pixel_values)
|
original_outputs = detr(pixel_values)
|
||||||
outputs = model(pixel_values)
|
outputs = model(pixel_values)
|
||||||
|
|
||||||
print("Logits:", outputs.logits[0, :3, :3])
|
|
||||||
print("Original logits:", original_outputs["pred_logits"][0, :3, :3])
|
|
||||||
|
|
||||||
assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-3)
|
assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-3)
|
||||||
assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-3)
|
assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-3)
|
||||||
if is_panoptic:
|
if is_panoptic:
|
||||||
@@ -360,15 +361,26 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
|
|||||||
model.save_pretrained(pytorch_dump_folder_path)
|
model.save_pretrained(pytorch_dump_folder_path)
|
||||||
processor.save_pretrained(pytorch_dump_folder_path)
|
processor.save_pretrained(pytorch_dump_folder_path)
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
# Upload model and image processor to the hub
|
||||||
|
logger.info("Uploading PyTorch model and image processor to the hub...")
|
||||||
|
model.push_to_hub(f"nielsr/{model_name}")
|
||||||
|
processor.push_to_hub(f"nielsr/{model_name}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name", default="detr_resnet50", type=str, help="Name of the DETR model you'd like to convert."
|
"--model_name",
|
||||||
|
default="detr-resnet-50",
|
||||||
|
type=str,
|
||||||
|
choices=["detr-resnet-50", "detr-resnet-101"],
|
||||||
|
help="Name of the DETR model you'd like to convert.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
|
convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
||||||
|
|||||||
@@ -1065,12 +1065,12 @@ class DetrImageProcessor(BaseImageProcessor):
|
|||||||
images (`ImageInput`):
|
images (`ImageInput`):
|
||||||
Image or batch of images to preprocess.
|
Image or batch of images to preprocess.
|
||||||
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
|
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
|
||||||
List of annotations associated with the image or batch of images. If annotionation is for object
|
List of annotations associated with the image or batch of images. If annotation is for object
|
||||||
detection, the annotations should be a dictionary with the following keys:
|
detection, the annotations should be a dictionary with the following keys:
|
||||||
- "image_id" (`int`): The image id.
|
- "image_id" (`int`): The image id.
|
||||||
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
|
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
|
||||||
dictionary. An image can have no annotations, in which case the list should be empty.
|
dictionary. An image can have no annotations, in which case the list should be empty.
|
||||||
If annotionation is for segmentation, the annotations should be a dictionary with the following keys:
|
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
|
||||||
- "image_id" (`int`): The image id.
|
- "image_id" (`int`): The image id.
|
||||||
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
|
||||||
An image can have no segments, in which case the list should be empty.
|
An image can have no segments, in which case the list should be empty.
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available
|
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
@@ -26,7 +26,7 @@ _import_structure = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_timm_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
@@ -47,7 +47,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_timm_available():
|
if not is_torch_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -189,6 +189,8 @@ class TableTransformerConfig(PretrainedConfig):
|
|||||||
backbone_model_type = backbone_config.get("model_type")
|
backbone_model_type = backbone_config.get("model_type")
|
||||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||||
backbone_config = config_class.from_dict(backbone_config)
|
backbone_config = config_class.from_dict(backbone_config)
|
||||||
|
# set timm attributes to None
|
||||||
|
dilation, backbone, use_pretrained_backbone = None, None, None
|
||||||
|
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
self.backbone_config = backbone_config
|
self.backbone_config = backbone_config
|
||||||
|
|||||||
@@ -1661,6 +1661,37 @@ class CodeGenPreTrainedModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionalDetrForObjectDetection(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionalDetrForSegmentation(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionalDetrModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionalDetrPreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -2073,6 +2104,30 @@ class DecisionTransformerPreTrainedModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class DeformableDetrForObjectDetection(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DeformableDetrModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DeformableDetrPreTrainedModel(metaclass=DummyObject):
    """Dummy stand-in for `DeformableDetrPreTrainedModel`, tagged with the `torch` backend.

    Instantiation hands the new object to `requires_backends`; that helper is
    defined outside this excerpt and presumably reports the missing dependency.
    """

    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but never read.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -2135,6 +2190,37 @@ class DetaPreTrainedModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class DetrForObjectDetection(metaclass=DummyObject):
    """Dummy stand-in for `DetrForObjectDetection`, tagged with the `torch` backend.

    Instantiation hands the new object to `requires_backends`; that helper is
    defined outside this excerpt and presumably reports the missing dependency.
    """

    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but never read.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DetrForSegmentation(metaclass=DummyObject):
    """Dummy stand-in for `DetrForSegmentation`, tagged with the `torch` backend.

    Instantiation hands the new object to `requires_backends`; that helper is
    defined outside this excerpt and presumably reports the missing dependency.
    """

    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but never read.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DetrModel(metaclass=DummyObject):
    """Dummy stand-in for `DetrModel`, tagged with the `torch` backend.

    Instantiation hands the new object to `requires_backends`; that helper is
    defined outside this excerpt and presumably reports the missing dependency.
    """

    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but never read.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DetrPreTrainedModel(metaclass=DummyObject):
    """Dummy stand-in for `DetrPreTrainedModel`, tagged with the `torch` backend.

    Instantiation hands the new object to `requires_backends`; that helper is
    defined outside this excerpt and presumably reports the missing dependency.
    """

    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but never read.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -6040,6 +6126,30 @@ def load_tf_weights_in_t5(*args, **kwargs):
|
|||||||
requires_backends(load_tf_weights_in_t5, ["torch"])
|
requires_backends(load_tf_weights_in_t5, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class TableTransformerForObjectDetection(metaclass=DummyObject):
    """Dummy stand-in for `TableTransformerForObjectDetection`, tagged with the `torch` backend.

    Instantiation hands the new object to `requires_backends`; that helper is
    defined outside this excerpt and presumably reports the missing dependency.
    """

    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but never read.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class TableTransformerModel(metaclass=DummyObject):
    """Dummy stand-in for `TableTransformerModel`, tagged with the `torch` backend.

    Instantiation hands the new object to `requires_backends`; that helper is
    defined outside this excerpt and presumably reports the missing dependency.
    """

    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but never read.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class TableTransformerPreTrainedModel(metaclass=DummyObject):
    """Dummy stand-in for `TableTransformerPreTrainedModel`, tagged with the `torch` backend.

    Instantiation hands the new object to `requires_backends`; that helper is
    defined outside this excerpt and presumably reports the missing dependency.
    """

    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but never read.
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,112 +0,0 @@
|
|||||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
|
||||||
from ..utils import DummyObject, requires_backends
|
|
||||||
|
|
||||||
|
|
||||||
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionalDetrForObjectDetection(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionalDetrForSegmentation(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionalDetrModel(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionalDetrPreTrainedModel(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
|
||||||
|
|
||||||
|
|
||||||
class DeformableDetrForObjectDetection(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
class DeformableDetrModel(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
class DeformableDetrPreTrainedModel(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
|
||||||
|
|
||||||
|
|
||||||
class DetrForObjectDetection(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
class DetrForSegmentation(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
class DetrModel(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
class DetrPreTrainedModel(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
|
||||||
|
|
||||||
|
|
||||||
class TableTransformerForObjectDetection(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
class TableTransformerModel(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
|
|
||||||
|
|
||||||
class TableTransformerPreTrainedModel(metaclass=DummyObject):
|
|
||||||
_backends = ["timm", "vision"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["timm", "vision"])
|
|
||||||
@@ -20,7 +20,7 @@ import math
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import DetrConfig, is_timm_available, is_vision_available
|
from transformers import DetrConfig, is_timm_available, is_vision_available
|
||||||
from transformers.testing_utils import require_timm, require_vision, slow, torch_device
|
from transformers.testing_utils import require_timm, require_torch, require_vision, slow, torch_device
|
||||||
from transformers.utils import cached_property
|
from transformers.utils import cached_property
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
@@ -510,7 +510,7 @@ def prepare_img():
|
|||||||
@require_timm
|
@require_timm
|
||||||
@require_vision
|
@require_vision
|
||||||
@slow
|
@slow
|
||||||
class DetrModelIntegrationTests(unittest.TestCase):
|
class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
|
||||||
@cached_property
|
@cached_property
|
||||||
def default_feature_extractor(self):
|
def default_feature_extractor(self):
|
||||||
return DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50") if is_vision_available() else None
|
return DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50") if is_vision_available() else None
|
||||||
@@ -626,3 +626,33 @@ class DetrModelIntegrationTests(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(results["segmentation"][:3, :3], expected_slice_segmentation, atol=1e-4))
|
self.assertTrue(torch.allclose(results["segmentation"][:3, :3], expected_slice_segmentation, atol=1e-4))
|
||||||
self.assertTrue(len(results["segments_info"]), expected_number_of_segments)
|
self.assertTrue(len(results["segments_info"]), expected_number_of_segments)
|
||||||
self.assertDictEqual(results["segments_info"][0], expected_first_segment)
|
self.assertDictEqual(results["segments_info"][0], expected_first_segment)
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
@require_torch
@slow
class DetrModelIntegrationTests(unittest.TestCase):
    """Integration check for the `no_timm` revision of `facebook/detr-resnet-50`.

    Decorated with `require_vision`, `require_torch` and `slow`.
    """

    @cached_property
    def default_feature_extractor(self):
        """Return the feature extractor loaded from the `no_timm` revision.

        Yields `None` when `is_vision_available()` is false.
        """
        return (
            DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
            if is_vision_available()
            else None
        )

    def test_inference_no_head(self):
        """Run the base model on the sample image and compare shape and a 3x3 slice to reference values."""
        model = DetrModel.from_pretrained("facebook/detr-resnet-50", revision="no_timm").to(torch_device)

        feature_extractor = self.default_feature_extractor
        image = prepare_img()
        encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**encoding)

        expected_shape = torch.Size((1, 100, 256))
        # Use a unittest assertion instead of a bare `assert`: it is not stripped
        # under `python -O` and reports both values on failure.
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
        expected_slice = torch.tensor(
            [[0.0616, -0.5146, -0.4032], [-0.7629, -0.4934, -1.7153], [-0.4768, -0.6403, -0.7826]]
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|||||||
Reference in New Issue
Block a user