[NAT, DiNAT] Add backbone class (#20654)
* Add first draft * Add out_features attribute to config * Add corresponding test * Add Dinat backbone * Add BackboneMixin * Add Backbone mixin, improve tests * Fix embeddings * Fix bug * Improve backbones * Fix Nat backbone tests * Fix Dinat backbone tests * Apply suggestions Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -1274,6 +1274,7 @@ else:
|
|||||||
_import_structure["models.dinat"].extend(
|
_import_structure["models.dinat"].extend(
|
||||||
[
|
[
|
||||||
"DINAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"DINAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"DinatBackbone",
|
||||||
"DinatForImageClassification",
|
"DinatForImageClassification",
|
||||||
"DinatModel",
|
"DinatModel",
|
||||||
"DinatPreTrainedModel",
|
"DinatPreTrainedModel",
|
||||||
@@ -1769,6 +1770,7 @@ else:
|
|||||||
_import_structure["models.nat"].extend(
|
_import_structure["models.nat"].extend(
|
||||||
[
|
[
|
||||||
"NAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"NAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"NatBackbone",
|
||||||
"NatForImageClassification",
|
"NatForImageClassification",
|
||||||
"NatModel",
|
"NatModel",
|
||||||
"NatPreTrainedModel",
|
"NatPreTrainedModel",
|
||||||
@@ -4388,6 +4390,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.dinat import (
|
from .models.dinat import (
|
||||||
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
DinatBackbone,
|
||||||
DinatForImageClassification,
|
DinatForImageClassification,
|
||||||
DinatModel,
|
DinatModel,
|
||||||
DinatPreTrainedModel,
|
DinatPreTrainedModel,
|
||||||
@@ -4784,6 +4787,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.nat import (
|
from .models.nat import (
|
||||||
NAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
NAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
NatBackbone,
|
||||||
NatForImageClassification,
|
NatForImageClassification,
|
||||||
NatModel,
|
NatModel,
|
||||||
NatPreTrainedModel,
|
NatPreTrainedModel,
|
||||||
|
|||||||
@@ -865,7 +865,9 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
|
|||||||
[
|
[
|
||||||
# Backbone mapping
|
# Backbone mapping
|
||||||
("bit", "BitBackbone"),
|
("bit", "BitBackbone"),
|
||||||
|
("dinat", "DinatBackbone"),
|
||||||
("maskformer-swin", "MaskFormerSwinBackbone"),
|
("maskformer-swin", "MaskFormerSwinBackbone"),
|
||||||
|
("nat", "NatBackbone"),
|
||||||
("resnet", "ResNetBackbone"),
|
("resnet", "ResNetBackbone"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ else:
|
|||||||
"DinatForImageClassification",
|
"DinatForImageClassification",
|
||||||
"DinatModel",
|
"DinatModel",
|
||||||
"DinatPreTrainedModel",
|
"DinatPreTrainedModel",
|
||||||
|
"DinatBackbone",
|
||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .modeling_dinat import (
|
from .modeling_dinat import (
|
||||||
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
DinatBackbone,
|
||||||
DinatForImageClassification,
|
DinatForImageClassification,
|
||||||
DinatModel,
|
DinatModel,
|
||||||
DinatPreTrainedModel,
|
DinatPreTrainedModel,
|
||||||
|
|||||||
@@ -72,6 +72,9 @@ class DinatConfig(PretrainedConfig):
|
|||||||
The epsilon used by the layer normalization layers.
|
The epsilon used by the layer normalization layers.
|
||||||
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
|
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
|
||||||
The initial value for the layer scale. Disabled if <=0.
|
The initial value for the layer scale. Disabled if <=0.
|
||||||
|
out_features (`List[str]`, *optional*):
|
||||||
|
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||||
|
(depending on how many stages the model has). Will default to the last stage if unset.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -113,6 +116,7 @@ class DinatConfig(PretrainedConfig):
|
|||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
layer_norm_eps=1e-5,
|
layer_norm_eps=1e-5,
|
||||||
layer_scale_init_value=0.0,
|
layer_scale_init_value=0.0,
|
||||||
|
out_features=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -138,3 +142,13 @@ class DinatConfig(PretrainedConfig):
|
|||||||
# this indicates the channel dimension after the last stage of the model
|
# this indicates the channel dimension after the last stage of the model
|
||||||
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
|
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
|
||||||
self.layer_scale_init_value = layer_scale_init_value
|
self.layer_scale_init_value = layer_scale_init_value
|
||||||
|
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
|
||||||
|
if out_features is not None:
|
||||||
|
if not isinstance(out_features, list):
|
||||||
|
raise ValueError("out_features should be a list")
|
||||||
|
for feature in out_features:
|
||||||
|
if feature not in self.stage_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
|
||||||
|
)
|
||||||
|
self.out_features = out_features
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_outputs import BackboneOutput
|
||||||
|
from ...modeling_utils import BackboneMixin, PreTrainedModel
|
||||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
@@ -35,6 +36,7 @@ from ...utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_natten_available,
|
is_natten_available,
|
||||||
logging,
|
logging,
|
||||||
|
replace_return_docstrings,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
)
|
)
|
||||||
from .configuration_dinat import DinatConfig
|
from .configuration_dinat import DinatConfig
|
||||||
@@ -555,14 +557,11 @@ class DinatStage(nn.Module):
|
|||||||
layer_outputs = layer_module(hidden_states, output_attentions)
|
layer_outputs = layer_module(hidden_states, output_attentions)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
hidden_states_before_downsampling = hidden_states
|
||||||
if self.downsample is not None:
|
if self.downsample is not None:
|
||||||
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
|
hidden_states = self.downsample(hidden_states_before_downsampling)
|
||||||
output_dimensions = (height, width, height_downsampled, width_downsampled)
|
|
||||||
hidden_states = self.downsample(layer_outputs[0])
|
|
||||||
else:
|
|
||||||
output_dimensions = (height, width, height, width)
|
|
||||||
|
|
||||||
stage_outputs = (hidden_states, output_dimensions)
|
stage_outputs = (hidden_states, hidden_states_before_downsampling)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
stage_outputs += layer_outputs[1:]
|
stage_outputs += layer_outputs[1:]
|
||||||
@@ -596,6 +595,7 @@ class DinatEncoder(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states: Optional[bool] = False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
|
output_hidden_states_before_downsampling: Optional[bool] = False,
|
||||||
return_dict: Optional[bool] = True,
|
return_dict: Optional[bool] = True,
|
||||||
) -> Union[Tuple, DinatEncoderOutput]:
|
) -> Union[Tuple, DinatEncoderOutput]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
@@ -612,8 +612,14 @@ class DinatEncoder(nn.Module):
|
|||||||
layer_outputs = layer_module(hidden_states, output_attentions)
|
layer_outputs = layer_module(hidden_states, output_attentions)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
hidden_states_before_downsampling = layer_outputs[1]
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states and output_hidden_states_before_downsampling:
|
||||||
|
# rearrange b h w c -> b c h w
|
||||||
|
reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
|
||||||
|
all_hidden_states += (hidden_states_before_downsampling,)
|
||||||
|
all_reshaped_hidden_states += (reshaped_hidden_state,)
|
||||||
|
elif output_hidden_states and not output_hidden_states_before_downsampling:
|
||||||
# rearrange b h w c -> b c h w
|
# rearrange b h w c -> b c h w
|
||||||
reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
|
reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
@@ -871,3 +877,120 @@ class DinatForImageClassification(DinatPreTrainedModel):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
reshaped_hidden_states=outputs.reshaped_hidden_states,
|
reshaped_hidden_states=outputs.reshaped_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"DiNAT backbone, to be used with frameworks like DETR and MaskFormer.",
|
||||||
|
DINAT_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
requires_backends(self, ["natten"])
|
||||||
|
|
||||||
|
self.stage_names = config.stage_names
|
||||||
|
|
||||||
|
self.embeddings = DinatEmbeddings(config)
|
||||||
|
self.encoder = DinatEncoder(config)
|
||||||
|
|
||||||
|
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
|
||||||
|
|
||||||
|
num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
|
||||||
|
self.out_feature_channels = {}
|
||||||
|
self.out_feature_channels["stem"] = config.embed_dim
|
||||||
|
for i, stage in enumerate(self.stage_names[1:]):
|
||||||
|
self.out_feature_channels[stage] = num_features[i]
|
||||||
|
|
||||||
|
# Add layer norms to hidden states of out_features
|
||||||
|
hidden_states_norms = dict()
|
||||||
|
for stage, num_channels in zip(self.out_features, self.channels):
|
||||||
|
hidden_states_norms[stage] = nn.LayerNorm(num_channels)
|
||||||
|
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.patch_embeddings
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self):
|
||||||
|
return [self.out_feature_channels[name] for name in self.out_features]
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> BackboneOutput:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoImageProcessor, AutoBackbone
|
||||||
|
>>> import torch
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
|
||||||
|
>>> model = AutoBackbone.from_pretrained(
|
||||||
|
... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> inputs = processor(image, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> feature_maps = outputs.feature_maps
|
||||||
|
>>> list(feature_maps[-1].shape)
|
||||||
|
[1, 512, 7, 7]
|
||||||
|
```"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
|
||||||
|
embedding_output = self.embeddings(pixel_values)
|
||||||
|
|
||||||
|
outputs = self.encoder(
|
||||||
|
embedding_output,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=True,
|
||||||
|
output_hidden_states_before_downsampling=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs.reshaped_hidden_states
|
||||||
|
|
||||||
|
feature_maps = ()
|
||||||
|
for stage, hidden_state in zip(self.stage_names, hidden_states):
|
||||||
|
if stage in self.out_features:
|
||||||
|
batch_size, num_channels, height, width = hidden_state.shape
|
||||||
|
hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
|
||||||
|
hidden_state = hidden_state.view(batch_size, height * width, num_channels)
|
||||||
|
hidden_state = self.hidden_states_norms[stage](hidden_state)
|
||||||
|
hidden_state = hidden_state.view(batch_size, height, width, num_channels)
|
||||||
|
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
|
||||||
|
feature_maps += (hidden_state,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (feature_maps,)
|
||||||
|
if output_hidden_states:
|
||||||
|
output += (outputs.hidden_states,)
|
||||||
|
return output
|
||||||
|
|
||||||
|
return BackboneOutput(
|
||||||
|
feature_maps=feature_maps,
|
||||||
|
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ else:
|
|||||||
"NatForImageClassification",
|
"NatForImageClassification",
|
||||||
"NatModel",
|
"NatModel",
|
||||||
"NatPreTrainedModel",
|
"NatPreTrainedModel",
|
||||||
|
"NatBackbone",
|
||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .modeling_nat import (
|
from .modeling_nat import (
|
||||||
NAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
NAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
NatBackbone,
|
||||||
NatForImageClassification,
|
NatForImageClassification,
|
||||||
NatModel,
|
NatModel,
|
||||||
NatPreTrainedModel,
|
NatPreTrainedModel,
|
||||||
|
|||||||
@@ -70,6 +70,9 @@ class NatConfig(PretrainedConfig):
|
|||||||
The epsilon used by the layer normalization layers.
|
The epsilon used by the layer normalization layers.
|
||||||
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
|
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
|
||||||
The initial value for the layer scale. Disabled if <=0.
|
The initial value for the layer scale. Disabled if <=0.
|
||||||
|
out_features (`List[str]`, *optional*):
|
||||||
|
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||||
|
(depending on how many stages the model has). Will default to the last stage if unset.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -110,6 +113,7 @@ class NatConfig(PretrainedConfig):
|
|||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
layer_norm_eps=1e-5,
|
layer_norm_eps=1e-5,
|
||||||
layer_scale_init_value=0.0,
|
layer_scale_init_value=0.0,
|
||||||
|
out_features=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -134,3 +138,13 @@ class NatConfig(PretrainedConfig):
|
|||||||
# this indicates the channel dimension after the last stage of the model
|
# this indicates the channel dimension after the last stage of the model
|
||||||
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
|
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
|
||||||
self.layer_scale_init_value = layer_scale_init_value
|
self.layer_scale_init_value = layer_scale_init_value
|
||||||
|
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
|
||||||
|
if out_features is not None:
|
||||||
|
if not isinstance(out_features, list):
|
||||||
|
raise ValueError("out_features should be a list")
|
||||||
|
for feature in out_features:
|
||||||
|
if feature not in self.stage_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
|
||||||
|
)
|
||||||
|
self.out_features = out_features
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_outputs import BackboneOutput
|
||||||
|
from ...modeling_utils import BackboneMixin, PreTrainedModel
|
||||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
@@ -35,6 +36,7 @@ from ...utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_natten_available,
|
is_natten_available,
|
||||||
logging,
|
logging,
|
||||||
|
replace_return_docstrings,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
)
|
)
|
||||||
from .configuration_nat import NatConfig
|
from .configuration_nat import NatConfig
|
||||||
@@ -536,14 +538,11 @@ class NatStage(nn.Module):
|
|||||||
layer_outputs = layer_module(hidden_states, output_attentions)
|
layer_outputs = layer_module(hidden_states, output_attentions)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
hidden_states_before_downsampling = hidden_states
|
||||||
if self.downsample is not None:
|
if self.downsample is not None:
|
||||||
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
|
hidden_states = self.downsample(hidden_states_before_downsampling)
|
||||||
output_dimensions = (height, width, height_downsampled, width_downsampled)
|
|
||||||
hidden_states = self.downsample(layer_outputs[0])
|
|
||||||
else:
|
|
||||||
output_dimensions = (height, width, height, width)
|
|
||||||
|
|
||||||
stage_outputs = (hidden_states, output_dimensions)
|
stage_outputs = (hidden_states, hidden_states_before_downsampling)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
stage_outputs += layer_outputs[1:]
|
stage_outputs += layer_outputs[1:]
|
||||||
@@ -575,6 +574,7 @@ class NatEncoder(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states: Optional[bool] = False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
|
output_hidden_states_before_downsampling: Optional[bool] = False,
|
||||||
return_dict: Optional[bool] = True,
|
return_dict: Optional[bool] = True,
|
||||||
) -> Union[Tuple, NatEncoderOutput]:
|
) -> Union[Tuple, NatEncoderOutput]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
@@ -591,8 +591,14 @@ class NatEncoder(nn.Module):
|
|||||||
layer_outputs = layer_module(hidden_states, output_attentions)
|
layer_outputs = layer_module(hidden_states, output_attentions)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
hidden_states_before_downsampling = layer_outputs[1]
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states and output_hidden_states_before_downsampling:
|
||||||
|
# rearrange b h w c -> b c h w
|
||||||
|
reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
|
||||||
|
all_hidden_states += (hidden_states_before_downsampling,)
|
||||||
|
all_reshaped_hidden_states += (reshaped_hidden_state,)
|
||||||
|
elif output_hidden_states and not output_hidden_states_before_downsampling:
|
||||||
# rearrange b h w c -> b c h w
|
# rearrange b h w c -> b c h w
|
||||||
reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
|
reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
@@ -849,3 +855,121 @@ class NatForImageClassification(NatPreTrainedModel):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
reshaped_hidden_states=outputs.reshaped_hidden_states,
|
reshaped_hidden_states=outputs.reshaped_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"NAT backbone, to be used with frameworks like DETR and MaskFormer.",
|
||||||
|
NAT_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class NatBackbone(NatPreTrainedModel, BackboneMixin):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
requires_backends(self, ["natten"])
|
||||||
|
|
||||||
|
self.stage_names = config.stage_names
|
||||||
|
|
||||||
|
self.embeddings = NatEmbeddings(config)
|
||||||
|
self.encoder = NatEncoder(config)
|
||||||
|
|
||||||
|
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
|
||||||
|
|
||||||
|
num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
|
||||||
|
self.out_feature_channels = {}
|
||||||
|
self.out_feature_channels["stem"] = config.embed_dim
|
||||||
|
for i, stage in enumerate(self.stage_names[1:]):
|
||||||
|
self.out_feature_channels[stage] = num_features[i]
|
||||||
|
|
||||||
|
# Add layer norms to hidden states of out_features
|
||||||
|
hidden_states_norms = dict()
|
||||||
|
for stage, num_channels in zip(self.out_features, self.channels):
|
||||||
|
hidden_states_norms[stage] = nn.LayerNorm(num_channels)
|
||||||
|
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.patch_embeddings
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self):
|
||||||
|
return [self.out_feature_channels[name] for name in self.out_features]
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> BackboneOutput:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoImageProcessor, AutoBackbone
|
||||||
|
>>> import torch
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
|
||||||
|
>>> model = AutoBackbone.from_pretrained(
|
||||||
|
... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> inputs = processor(image, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> feature_maps = outputs.feature_maps
|
||||||
|
>>> list(feature_maps[-1].shape)
|
||||||
|
[1, 512, 7, 7]
|
||||||
|
```"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
|
||||||
|
embedding_output = self.embeddings(pixel_values)
|
||||||
|
|
||||||
|
outputs = self.encoder(
|
||||||
|
embedding_output,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=True,
|
||||||
|
output_hidden_states_before_downsampling=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs.reshaped_hidden_states
|
||||||
|
|
||||||
|
feature_maps = ()
|
||||||
|
for stage, hidden_state in zip(self.stage_names, hidden_states):
|
||||||
|
if stage in self.out_features:
|
||||||
|
# TODO can we simplify this?
|
||||||
|
batch_size, num_channels, height, width = hidden_state.shape
|
||||||
|
hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
|
||||||
|
hidden_state = hidden_state.view(batch_size, height * width, num_channels)
|
||||||
|
hidden_state = self.hidden_states_norms[stage](hidden_state)
|
||||||
|
hidden_state = hidden_state.view(batch_size, height, width, num_channels)
|
||||||
|
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
|
||||||
|
feature_maps += (hidden_state,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (feature_maps,)
|
||||||
|
if output_hidden_states:
|
||||||
|
output += (outputs.hidden_states,)
|
||||||
|
return output
|
||||||
|
|
||||||
|
return BackboneOutput(
|
||||||
|
feature_maps=feature_maps,
|
||||||
|
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1856,6 +1856,13 @@ class DeiTPreTrainedModel(metaclass=DummyObject):
|
|||||||
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class DinatBackbone(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class DinatForImageClassification(metaclass=DummyObject):
|
class DinatForImageClassification(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -3913,6 +3920,13 @@ class MvpPreTrainedModel(metaclass=DummyObject):
|
|||||||
NAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
NAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class NatBackbone(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class NatForImageClassification(metaclass=DummyObject):
|
class NatForImageClassification(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import DinatForImageClassification, DinatModel
|
from transformers import DinatBackbone, DinatForImageClassification, DinatModel
|
||||||
from transformers.models.dinat.modeling_dinat import DINAT_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.dinat.modeling_dinat import DINAT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -64,8 +64,8 @@ class DinatModelTester:
|
|||||||
is_training=True,
|
is_training=True,
|
||||||
scope=None,
|
scope=None,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
type_sequence_label_size=10,
|
num_labels=10,
|
||||||
encoder_stride=8,
|
out_features=["stage1", "stage2"],
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -89,15 +89,15 @@ class DinatModelTester:
|
|||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
self.use_labels = use_labels
|
self.use_labels = use_labels
|
||||||
self.type_sequence_label_size = type_sequence_label_size
|
self.num_labels = num_labels
|
||||||
self.encoder_stride = encoder_stride
|
self.out_features = out_features
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
|
||||||
labels = None
|
labels = None
|
||||||
if self.use_labels:
|
if self.use_labels:
|
||||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
labels = ids_tensor([self.batch_size], self.num_labels)
|
||||||
|
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
|
|
||||||
@@ -105,6 +105,7 @@ class DinatModelTester:
|
|||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return DinatConfig(
|
return DinatConfig(
|
||||||
|
num_labels=self.num_labels,
|
||||||
image_size=self.image_size,
|
image_size=self.image_size,
|
||||||
patch_size=self.patch_size,
|
patch_size=self.patch_size,
|
||||||
num_channels=self.num_channels,
|
num_channels=self.num_channels,
|
||||||
@@ -122,7 +123,7 @@ class DinatModelTester:
|
|||||||
patch_norm=self.patch_norm,
|
patch_norm=self.patch_norm,
|
||||||
layer_norm_eps=self.layer_norm_eps,
|
layer_norm_eps=self.layer_norm_eps,
|
||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
encoder_stride=self.encoder_stride,
|
out_features=self.out_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values, labels):
|
def create_and_check_model(self, config, pixel_values, labels):
|
||||||
@@ -139,12 +140,11 @@ class DinatModelTester:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
config.num_labels = self.type_sequence_label_size
|
|
||||||
model = DinatForImageClassification(config)
|
model = DinatForImageClassification(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values, labels=labels)
|
result = model(pixel_values, labels=labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||||
|
|
||||||
# test greyscale images
|
# test greyscale images
|
||||||
config.num_channels = 1
|
config.num_channels = 1
|
||||||
@@ -154,7 +154,34 @@ class DinatModelTester:
|
|||||||
|
|
||||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||||
|
|
||||||
|
def create_and_check_backbone(self, config, pixel_values, labels):
|
||||||
|
model = DinatBackbone(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
|
# verify hidden states
|
||||||
|
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||||
|
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])
|
||||||
|
|
||||||
|
# verify channels
|
||||||
|
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||||
|
|
||||||
|
# verify backbone works with out_features=None
|
||||||
|
config.out_features = None
|
||||||
|
model = DinatBackbone(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
|
# verify feature maps
|
||||||
|
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||||
|
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])
|
||||||
|
|
||||||
|
# verify channels
|
||||||
|
self.parent.assertEqual(len(model.channels), 1)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
@@ -167,7 +194,15 @@ class DinatModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class DinatModelTest(ModelTesterMixin, unittest.TestCase):
|
class DinatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (DinatModel, DinatForImageClassification) if is_torch_available() else ()
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
DinatModel,
|
||||||
|
DinatForImageClassification,
|
||||||
|
DinatBackbone,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
|
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
@@ -199,8 +234,16 @@ class DinatModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_backbone(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_backbone(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Dinat does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
# Dinat does not use inputs_embeds
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Dinat does not use feedforward chunking")
|
||||||
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
@@ -257,17 +300,18 @@ class DinatModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
[height, width, self.model_tester.embed_dim],
|
[height, width, self.model_tester.embed_dim],
|
||||||
)
|
)
|
||||||
|
|
||||||
reshaped_hidden_states = outputs.reshaped_hidden_states
|
if model_class.__name__ != "DinatBackbone":
|
||||||
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
|
reshaped_hidden_states = outputs.reshaped_hidden_states
|
||||||
|
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
|
||||||
|
|
||||||
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
|
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
|
||||||
reshaped_hidden_states = (
|
reshaped_hidden_states = (
|
||||||
reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)
|
reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)
|
||||||
)
|
)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(reshaped_hidden_states.shape[-3:]),
|
list(reshaped_hidden_states.shape[-3:]),
|
||||||
[height, width, self.model_tester.embed_dim],
|
[height, width, self.model_tester.embed_dim],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import NatForImageClassification, NatModel
|
from transformers import NatBackbone, NatForImageClassification, NatModel
|
||||||
from transformers.models.nat.modeling_nat import NAT_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.nat.modeling_nat import NAT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -63,8 +63,8 @@ class NatModelTester:
|
|||||||
is_training=True,
|
is_training=True,
|
||||||
scope=None,
|
scope=None,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
type_sequence_label_size=10,
|
num_labels=10,
|
||||||
encoder_stride=8,
|
out_features=["stage1", "stage2"],
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -87,15 +87,15 @@ class NatModelTester:
|
|||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
self.use_labels = use_labels
|
self.use_labels = use_labels
|
||||||
self.type_sequence_label_size = type_sequence_label_size
|
self.num_labels = num_labels
|
||||||
self.encoder_stride = encoder_stride
|
self.out_features = out_features
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
|
||||||
labels = None
|
labels = None
|
||||||
if self.use_labels:
|
if self.use_labels:
|
||||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
labels = ids_tensor([self.batch_size], self.num_labels)
|
||||||
|
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
|
|
||||||
@@ -103,6 +103,7 @@ class NatModelTester:
|
|||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return NatConfig(
|
return NatConfig(
|
||||||
|
num_labels=self.num_labels,
|
||||||
image_size=self.image_size,
|
image_size=self.image_size,
|
||||||
patch_size=self.patch_size,
|
patch_size=self.patch_size,
|
||||||
num_channels=self.num_channels,
|
num_channels=self.num_channels,
|
||||||
@@ -119,7 +120,7 @@ class NatModelTester:
|
|||||||
patch_norm=self.patch_norm,
|
patch_norm=self.patch_norm,
|
||||||
layer_norm_eps=self.layer_norm_eps,
|
layer_norm_eps=self.layer_norm_eps,
|
||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
encoder_stride=self.encoder_stride,
|
out_features=self.out_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values, labels):
|
def create_and_check_model(self, config, pixel_values, labels):
|
||||||
@@ -136,12 +137,11 @@ class NatModelTester:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
config.num_labels = self.type_sequence_label_size
|
|
||||||
model = NatForImageClassification(config)
|
model = NatForImageClassification(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values, labels=labels)
|
result = model(pixel_values, labels=labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||||
|
|
||||||
# test greyscale images
|
# test greyscale images
|
||||||
config.num_channels = 1
|
config.num_channels = 1
|
||||||
@@ -151,7 +151,34 @@ class NatModelTester:
|
|||||||
|
|
||||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||||
|
|
||||||
|
def create_and_check_backbone(self, config, pixel_values, labels):
|
||||||
|
model = NatBackbone(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
|
# verify hidden states
|
||||||
|
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||||
|
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])
|
||||||
|
|
||||||
|
# verify channels
|
||||||
|
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||||
|
|
||||||
|
# verify backbone works with out_features=None
|
||||||
|
config.out_features = None
|
||||||
|
model = NatBackbone(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
|
# verify feature maps
|
||||||
|
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||||
|
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])
|
||||||
|
|
||||||
|
# verify channels
|
||||||
|
self.parent.assertEqual(len(model.channels), 1)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
@@ -164,7 +191,15 @@ class NatModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class NatModelTest(ModelTesterMixin, unittest.TestCase):
|
class NatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (NatModel, NatForImageClassification) if is_torch_available() else ()
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
NatModel,
|
||||||
|
NatForImageClassification,
|
||||||
|
NatBackbone,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
|
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
@@ -196,8 +231,16 @@ class NatModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_backbone(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_backbone(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Nat does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
# Nat does not use inputs_embeds
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Nat does not use feedforward chunking")
|
||||||
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
@@ -254,17 +297,18 @@ class NatModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
[height, width, self.model_tester.embed_dim],
|
[height, width, self.model_tester.embed_dim],
|
||||||
)
|
)
|
||||||
|
|
||||||
reshaped_hidden_states = outputs.reshaped_hidden_states
|
if model_class.__name__ != "NatBackbone":
|
||||||
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
|
reshaped_hidden_states = outputs.reshaped_hidden_states
|
||||||
|
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
|
||||||
|
|
||||||
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
|
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
|
||||||
reshaped_hidden_states = (
|
reshaped_hidden_states = (
|
||||||
reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)
|
reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)
|
||||||
)
|
)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(reshaped_hidden_states.shape[-3:]),
|
list(reshaped_hidden_states.shape[-3:]),
|
||||||
[height, width, self.model_tester.embed_dim],
|
[height, width, self.model_tester.embed_dim],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -677,6 +677,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
|
|||||||
"MaskFormerSwinBackbone",
|
"MaskFormerSwinBackbone",
|
||||||
"ResNetBackbone",
|
"ResNetBackbone",
|
||||||
"AutoBackbone",
|
"AutoBackbone",
|
||||||
|
"DinatBackbone",
|
||||||
|
"NatBackbone",
|
||||||
"MaskFormerSwinConfig",
|
"MaskFormerSwinConfig",
|
||||||
"MaskFormerSwinModel",
|
"MaskFormerSwinModel",
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user